feat: add distributed mode (#9124)

* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood with health checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-30 00:47:27 +02:00
committed by GitHub
parent 4c870288d9
commit 59108fbe32
389 changed files with 276305 additions and 246521 deletions

View File

@@ -406,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string {
}
// Parse the response to get avatar URL
var userInfo map[string]interface{}
var userInfo map[string]any
body, err := io.ReadAll(resp.Body)
if err != nil {
return ""

View File

@@ -3,7 +3,7 @@ package main
import (
"context"
"fmt"
"math/rand"
"math/rand/v2"
"strings"
"time"
)
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
generator := NewSyntheticDataGenerator()
// Generate a random number of synthetic models (1-3)
numModels := generator.rand.Intn(3) + 1
numModels := generator.rand.IntN(3) + 1
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
var models []ProcessedModel
for i := 0; i < numModels; i++ {
for i := range numModels {
model := generator.GenerateProcessedModel()
models = append(models, model)
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
// NewSyntheticDataGenerator creates a new synthetic data generator
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
return &SyntheticDataGenerator{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
}
}
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
fileTypes := []string{"model", "readme", "other"}
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
var path string
var isReadme bool
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
return ProcessedModelFile{
Path: path,
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
SHA256: g.randomSHA256(),
IsReadme: isReadme,
FileType: fileType,
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
author := authors[g.rand.Intn(len(authors))]
modelName := modelNames[g.rand.Intn(len(modelNames))]
author := authors[g.rand.IntN(len(authors))]
modelName := modelNames[g.rand.IntN(len(modelNames))]
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
// Generate files
numFiles := g.rand.Intn(5) + 2 // 2-6 files
numFiles := g.rand.IntN(5) + 2 // 2-6 files
files := make([]ProcessedModelFile, numFiles)
// Ensure at least one model file and one readme
hasModelFile := false
hasReadme := false
for i := 0; i < numFiles; i++ {
for i := range numFiles {
files[i] = g.GenerateProcessedModelFile()
if files[i].FileType == "model" {
hasModelFile = true
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
// Generate sample metadata
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
license := licenses[g.rand.Intn(len(licenses))]
license := licenses[g.rand.IntN(len(licenses))]
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
numTags := g.rand.Intn(4) + 3 // 3-6 tags
numTags := g.rand.IntN(4) + 3 // 3-6 tags
tags := make([]string, numTags)
for i := 0; i < numTags; i++ {
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
for i := range numTags {
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
}
// Remove duplicates
tags = g.removeDuplicates(tags)
// Optionally include icon (50% chance)
icon := ""
if g.rand.Intn(2) == 0 {
if g.rand.IntN(2) == 0 {
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
}
return ProcessedModel{
ModelID: modelID,
Author: author,
Downloads: g.rand.Intn(1000000) + 1000,
Downloads: g.rand.IntN(1000000) + 1000,
LastModified: g.randomDate(),
Files: files,
PreferredModelFile: preferredModelFile,
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[g.rand.Intn(len(charset))]
b[i] = charset[g.rand.IntN(len(charset))]
}
return string(b)
}
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
const charset = "0123456789abcdef"
b := make([]byte, 64)
for i := range b {
b[i] = charset[g.rand.Intn(len(charset))]
b[i] = charset[g.rand.IntN(len(charset))]
}
return string(b)
}
func (g *SyntheticDataGenerator) randomDate() string {
now := time.Now()
daysAgo := g.rand.Intn(365) // Random date within last year
daysAgo := g.rand.IntN(365) // Random date within last year
pastDate := now.AddDate(0, 0, -daysAgo)
return pastDate.Format("2006-01-02T15:04:05.000Z")
}
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
}
return templates[g.rand.Intn(len(templates))]
return templates[g.rand.IntN(len(templates))]
}

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.25.x']
go-version: ['1.26.x']
steps:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
@@ -179,7 +179,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
go-version: ['1.25.x']
go-version: ['1.26.x']
steps:
- name: Clone
uses: actions/checkout@v6

View File

@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
FROM requirements-drivers AS build-requirements
ARG GO_VERSION=1.25.4
ARG GO_VERSION=1.26.0
ARG CMAKE_VERSION=3.31.10
ARG CMAKE_FROM_SOURCE=false
ARG TARGETARCH
@@ -319,7 +319,6 @@ COPY ./.git ./.git
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
COPY ./pkg/grpc ./pkg/grpc
COPY ./pkg/utils ./pkg/utils
COPY ./pkg/langchain ./pkg/langchain
RUN ls -l ./
RUN make protogen-go

View File

@@ -154,6 +154,7 @@ For older news and full release notes, see [GitHub Releases](https://github.com/
- [Object Detection](https://localai.io/features/object-detection/)
- [Reranker API](https://localai.io/features/reranker/)
- [P2P Inferencing](https://localai.io/features/distribute/)
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images

View File

@@ -51,6 +51,7 @@ service Backend {
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
}
// Define the empty request
@@ -676,3 +677,4 @@ message QuantizationProgressUpdate {
message QuantizationStopRequest {
string job_id = 1;
}

View File

@@ -22,8 +22,10 @@
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/security/server_credentials.h>
#include <regex>
#include <atomic>
#include <cstdlib>
#include <mutex>
#include <signal.h>
#include <thread>
@@ -37,6 +39,47 @@ using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
// requests without a matching "authorization: Bearer <token>" metadata header.
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
public:
  explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}

  // Non-blocking: validation is a cheap in-memory compare, so it is safe
  // to run inline on the gRPC event thread.
  bool IsBlocking() const override { return false; }

  // Checks the "authorization" metadata entry against "Bearer <token>".
  // Returns OK on an exact match, UNAUTHENTICATED otherwise (including
  // when the header is absent entirely).
  grpc::Status Process(const InputMetadata& auth_metadata,
                       grpc::AuthContext* /*context*/,
                       OutputMetadata* /*consumed_auth_metadata*/,
                       OutputMetadata* /*response_metadata*/) override {
    auto it = auth_metadata.find("authorization");
    if (it != auth_metadata.end()) {
      std::string expected = "Bearer " + token_;
      std::string got(it->second.data(), it->second.size());
      // Constant-time comparison; the early size check leaks only the
      // token length, never its contents.
      if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
        return grpc::Status::OK;
      }
    }
    return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
  }

private:
  std::string token_;

  // Minimal constant-time comparison (avoids OpenSSL dependency).
  // XORs each byte pair and ORs the results so the running time does not
  // depend on the position of the first mismatch. Returns 0 iff all n
  // bytes are equal.
  static int ct_memcmp(const void* a, const void* b, size_t n) {
    const unsigned char* pa = static_cast<const unsigned char*>(a);
    const unsigned char* pb = static_cast<const unsigned char*>(b);
    unsigned char result = 0;
    for (size_t i = 0; i < n; i++) {
      result |= pa[i] ^ pb[i];
    }
    return result;
  }
};
// END LocalAI
@@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
BackendServiceImpl service(ctx_server);
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
std::shared_ptr<grpc::ServerCredentials> creds;
if (auth_token != nullptr && auth_token[0] != '\0') {
creds = grpc::InsecureServerCredentials();
creds->SetAuthMetadataProcessor(
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
} else {
creds = grpc::InsecureServerCredentials();
}
builder.AddListeningPort(server_address, creds);
builder.RegisterService(&service);
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
std::unique_ptr<Server> server(builder.BuildAndStart());
// run the HTTP server in a thread - see comment below
std::thread t([&]()

View File

@@ -106,10 +106,10 @@ func TestLoadModel(t *testing.T) {
defer conn.Close()
client := pb.NewBackendClient(conn)
// Get base directory from main model file for relative paths
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: mainModelPath,
ModelPath: modelDir,
@@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { os.RemoveAll(tmpDir) })
outputFile := filepath.Join(tmpDir, "output.wav")

View File

@@ -11,7 +11,7 @@ import (
)
var (
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
)
@@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
var textEncoderModel, ditModel, vaeModel string
for _, oo := range opts.Options {
parts := strings.SplitN(oo, ":", 2)
if len(parts) != 2 {
key, value, found := strings.Cut(oo, ":")
if !found {
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
continue
}
switch parts[0] {
switch key {
case "text_encoder_model":
textEncoderModel = parts[1]
textEncoderModel = value
case "dit_model":
ditModel = parts[1]
ditModel = value
case "vae_model":
vaeModel = parts[1]
vaeModel = value
default:
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
}

View File

@@ -18,7 +18,6 @@ type LLM struct {
draftModel *llama.LLama
}
// Free releases GPU resources and frees the llama model
// This should be called when the model is being unloaded to properly release VRAM
func (llm *LLM) Free() error {

View File

@@ -1,5 +1,4 @@
//go:build debug
// +build debug
package main

View File

@@ -1,5 +1,4 @@
//go:build !debug
// +build !debug
package main

View File

@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot float32
for i := 0; i < len(k1); i++ {
for i := range len(k1) {
dot += k1[i] * k2[i]
}
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot, mag2 float64
for i := 0; i < len(k1); i++ {
for i := range len(k1) {
dot += float64(k1[i] * k2[i])
mag2 += float64(k2[i] * k2[i])
}

View File

@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
// to one-shot (only difference is resampler batch boundaries).
var maxDiff float64
var sumDiffSq float64
for i := 0; i < minLen; i++ {
for i := range minLen {
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
if diff > maxDiff {
maxDiff = diff
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
var persistentMaxDiff, freshMaxDiff float64
for i := 0; i < minLen; i++ {
for i := range minLen {
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
if pd > persistentMaxDiff {
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
Expect(stddev / mean).To(BeNumerically("<", 0.15),
Expect(stddev/mean).To(BeNumerically("<", 0.15),
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
// Also check frequency is correct
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
// Every sample must be identical — the resampler is deterministic
var maxDiff float64
for i := 0; i < len(oneShot); i++ {
for i := range len(oneShot) {
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
if diff > maxDiff {
maxDiff = diff
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
copy(hdr[8:12], "WAVE")
copy(hdr[12:16], "fmt ")
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
copy(hdr[36:40], "data")
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
)
pcm := make([]byte, toneNumSamples*2)
for i := 0; i < toneNumSamples; i++ {
for i := range toneNumSamples {
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
}

View File

@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { os.RemoveAll(tmpDir) })
// Download sample audio — JFK "ask not what your country can do for you" clip
audioFile := filepath.Join(tmpDir, "sample.wav")

View File

@@ -19,6 +19,10 @@ import tempfile
import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from acestep.inference import (
GenerationParams,
GenerationConfig,
@@ -444,6 +448,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024),
("grpc.max_receive_message_length", 50 * 1024 * 1024),
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import tempfile
def is_float(s):
@@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -0,0 +1,78 @@
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
a valid Bearer token in the 'authorization' metadata header are rejected with
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
performed (backward compatible).
"""
import hmac
import os
import grpc
class _AbortHandler(grpc.RpcMethodHandler):
    """Handler that terminates any invocation with UNAUTHENTICATED.

    Substituted for the real method handler by the auth interceptors when a
    request carries a missing or invalid bearer token.
    """

    def __init__(self):
        # Present the call as plain unary-unary with no (de)serialization;
        # the only thing this handler ever does is abort the RPC.
        self.unary_unary = self._abort
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        self.request_deserializer = None
        self.response_serializer = None
        self.request_streaming = False
        self.response_streaming = False

    @staticmethod
    def _abort(request, context):
        # context.abort raises, so this never returns normally.
        context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
class TokenAuthInterceptor(grpc.ServerInterceptor):
    """Sync gRPC server interceptor enforcing bearer-token authentication.

    The 'authorization' metadata header must equal "Bearer <token>";
    anything else is answered with an aborting handler that raises
    UNAUTHENTICATED.
    """

    def __init__(self, token: str):
        self._token = token
        self._abort_handler = _AbortHandler()

    def intercept_service(self, continuation, handler_call_details):
        supplied = dict(handler_call_details.invocation_metadata).get("authorization", "")
        # hmac.compare_digest keeps the comparison constant-time so the
        # token cannot be recovered via timing.
        if hmac.compare_digest(supplied, "Bearer " + self._token):
            return continuation(handler_call_details)
        return self._abort_handler
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
    """grpc.aio counterpart of TokenAuthInterceptor.

    Validates the 'authorization' metadata header against
    "Bearer <token>" and rejects non-matching requests with
    UNAUTHENTICATED.
    """

    def __init__(self, token: str):
        self._token = token

    async def intercept_service(self, continuation, handler_call_details):
        supplied = dict(handler_call_details.invocation_metadata).get("authorization", "")
        # Constant-time compare, same rationale as the sync interceptor.
        if hmac.compare_digest(supplied, "Bearer " + self._token):
            return await continuation(handler_call_details)
        return _AbortHandler()
def get_auth_interceptors(*, aio: bool = False):
    """Build the gRPC interceptor list for bearer-token authentication.

    Args:
        aio: When True, return interceptors compatible with
            grpc.aio.server(); when False (default), return sync
            interceptors for grpc.server().

    Returns:
        A one-element list holding the appropriate interceptor when
        LOCALAI_GRPC_AUTH_TOKEN is set to a non-empty value; otherwise an
        empty list, which disables authentication (backward compatible).
    """
    token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
    if not token:
        return []
    interceptor = AsyncTokenAuthInterceptor(token) if aio else TokenAuthInterceptor(token)
    return [interceptor]

View File

@@ -15,6 +15,10 @@ import torch
from TTS.api import TTS
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -93,7 +97,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -22,6 +22,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
# Import dynamic loader for pipeline discovery
from diffusers_dynamic_loader import (
@@ -1042,7 +1046,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -15,6 +15,10 @@ import torch
import soundfile as sf
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
@@ -165,6 +169,8 @@ def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
]
,
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -14,6 +14,10 @@ import torch
from faster_whisper import WhisperModel
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -70,7 +74,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -19,6 +19,10 @@ import numpy as np
import json
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
@@ -424,6 +428,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ from kittentts import KittenTTS
import soundfile as sf
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -77,7 +81,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -16,6 +16,10 @@ from kokoro import KPipeline
import soundfile as sf
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -84,7 +88,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -17,6 +17,10 @@ import time
from concurrent import futures
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2
import backend_pb2_grpc
@@ -398,7 +402,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -15,6 +15,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_audio.tts.utils import load_model
import soundfile as sf
import numpy as np
@@ -436,7 +440,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address

View File

@@ -23,6 +23,10 @@ import tempfile
from typing import List
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2
import backend_pb2_grpc
@@ -468,6 +472,8 @@ async def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
interceptors=get_auth_interceptors(aio=True),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_vlm import load, generate, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, load_image
@@ -446,7 +450,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
@@ -421,7 +425,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address

View File

@@ -17,6 +17,10 @@ from moonshine_voice import (
)
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -128,7 +132,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -14,6 +14,10 @@ import torch
import nemo.collections.asr as nemo_asr
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
@@ -119,7 +123,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
import soundfile as sf
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
"""Check if a string can be converted to float."""
@@ -130,7 +134,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import outetts
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -116,7 +120,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
])
],
interceptors=get_auth_interceptors(aio=True),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ import torch
from pocket_tts import TTSModel
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
"""Check if a string can be converted to float."""
@@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -14,6 +14,10 @@ import torch
from qwen_asr import Qwen3ASRModel
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
@@ -184,7 +188,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -23,6 +23,10 @@ import hashlib
import pickle
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
@@ -900,6 +904,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from rerankers import Reranker
@@ -97,7 +101,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -13,6 +13,10 @@ import base64
import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import requests
@@ -139,7 +143,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -16,6 +16,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import torch
import torch.cuda
@@ -532,7 +536,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address

View File

@@ -17,6 +17,10 @@ import uuid
from concurrent import futures
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2
import backend_pb2_grpc
@@ -832,6 +836,8 @@ def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)

View File

@@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
"""Check if a string can be converted to float."""
@@ -724,7 +728,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -27,6 +27,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.outputs import OmniRequestOutput
@@ -650,7 +654,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
@@ -338,7 +342,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address

View File

@@ -18,6 +18,10 @@ import backend_pb2_grpc
import torch
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
"""Check if a string can be converted to float."""
@@ -297,7 +301,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -13,6 +13,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -137,7 +141,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()

View File

@@ -3,7 +3,7 @@ package application
import (
"time"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/xlog"
)
@@ -22,13 +22,23 @@ func (a *Application) RestartAgentJobService() error {
}
// Create new service instance
agentJobService := services.NewAgentJobService(
agentJobService := agentpool.NewAgentJobService(
a.ApplicationConfig(),
a.ModelLoader(),
a.ModelConfigLoader(),
a.TemplatesEvaluator(),
)
// Re-apply distributed wiring if available (matches startup.go logic)
if d := a.Distributed(); d != nil {
if d.Dispatcher != nil {
agentJobService.SetDistributedBackends(d.Dispatcher)
}
if d.JobStore != nil {
agentJobService.SetDistributedJobStore(d.JobStore)
}
}
// Start the service
err := agentJobService.Start(a.ApplicationConfig().Context)
if err != nil {

View File

@@ -2,12 +2,16 @@ package application
import (
"context"
"math/rand/v2"
"sync"
"sync/atomic"
"time"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
@@ -20,9 +24,9 @@ type Application struct {
applicationConfig *config.ApplicationConfig
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
templatesEvaluator *templates.Evaluator
galleryService *services.GalleryService
agentJobService *services.AgentJobService
agentPoolService atomic.Pointer[services.AgentPoolService]
galleryService *galleryop.GalleryService
agentJobService *agentpool.AgentJobService
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
authDB *gorm.DB
watchdogMutex sync.Mutex
watchdogStop chan bool
@@ -30,6 +34,9 @@ type Application struct {
p2pCtx context.Context
p2pCancel context.CancelFunc
agentJobMutex sync.Mutex
// Distributed mode services (nil when not in distributed mode)
distributed *DistributedServices
}
func newApplication(appConfig *config.ApplicationConfig) *Application {
@@ -64,15 +71,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator {
return a.templatesEvaluator
}
func (a *Application) GalleryService() *services.GalleryService {
func (a *Application) GalleryService() *galleryop.GalleryService {
return a.galleryService
}
func (a *Application) AgentJobService() *services.AgentJobService {
func (a *Application) AgentJobService() *agentpool.AgentJobService {
return a.agentJobService
}
func (a *Application) AgentPoolService() *services.AgentPoolService {
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
return a.agentPoolService.Load()
}
@@ -86,8 +93,53 @@ func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig
}
// Distributed returns the distributed services, or nil if not in distributed mode.
func (a *Application) Distributed() *DistributedServices {
return a.distributed
}
// IsDistributed returns true if the application is running in distributed mode.
func (a *Application) IsDistributed() bool {
return a.distributed != nil
}
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
// This prevents the agent pool from failing during startup when workers haven't connected yet.
func (a *Application) waitForHealthyWorker() {
maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault()
const basePoll = 2 * time.Second
xlog.Info("Waiting for at least one healthy backend worker before starting agent pool")
deadline := time.Now().Add(maxWait)
for time.Now().Before(deadline) {
registered, err := a.distributed.Registry.List(context.Background())
if err == nil {
for _, n := range registered {
if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy {
xlog.Info("Healthy backend worker found", "node", n.Name)
return
}
}
}
// Add 0-1s jitter to prevent thundering-herd on the node registry
jitter := time.Duration(rand.Int64N(int64(time.Second)))
select {
case <-a.applicationConfig.Context.Done():
return
case <-time.After(basePoll + jitter):
}
}
xlog.Warn("No healthy backend worker found after waiting, proceeding anyway")
}
// InstanceID returns the unique identifier for this frontend instance.
func (a *Application) InstanceID() string {
return a.applicationConfig.Distributed.InstanceID
}
func (a *Application) start() error {
galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
if err != nil {
return err
@@ -95,19 +147,14 @@ func (a *Application) start() error {
a.galleryService = galleryService
// Initialize agent job service
agentJobService := services.NewAgentJobService(
// Initialize agent job service (Start() is deferred to after distributed wiring)
agentJobService := agentpool.NewAgentJobService(
a.ApplicationConfig(),
a.ModelLoader(),
a.ModelConfigLoader(),
a.TemplatesEvaluator(),
)
err = agentJobService.Start(a.ApplicationConfig().Context)
if err != nil {
return err
}
a.agentJobService = agentJobService
return nil
@@ -120,27 +167,56 @@ func (a *Application) StartAgentPool() {
if !a.applicationConfig.AgentPool.Enabled {
return
}
aps, err := services.NewAgentPoolService(a.applicationConfig)
// Build options struct from available dependencies
opts := agentpool.AgentPoolOptions{
AuthDB: a.authDB,
}
if d := a.Distributed(); d != nil {
if d.DistStores != nil && d.DistStores.Skills != nil {
opts.SkillStore = d.DistStores.Skills
}
opts.NATSClient = d.Nats
opts.EventBridge = d.AgentBridge
opts.AgentStore = d.AgentStore
}
aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts)
if err != nil {
xlog.Error("Failed to create agent pool service", "error", err)
return
}
if a.authDB != nil {
aps.SetAuthDB(a.authDB)
// Wire distributed mode components
if d := a.Distributed(); d != nil {
// Wait for at least one healthy backend worker before starting the agent pool.
// Collections initialization calls embeddings which require a worker.
if d.Registry != nil {
a.waitForHealthyWorker()
}
}
if err := aps.Start(a.applicationConfig.Context); err != nil {
xlog.Error("Failed to start agent pool", "error", err)
return
}
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
usm := services.NewUserServicesManager(
usm := agentpool.NewUserServicesManager(
aps.UserStorage(),
a.applicationConfig,
a.modelLoader,
a.backendLoader,
a.templatesEvaluator,
)
// Wire distributed backends to per-user job services
if a.agentJobService != nil {
if d := a.agentJobService.Dispatcher(); d != nil {
usm.SetJobDispatcher(d)
}
if s := a.agentJobService.DBStore(); s != nil {
usm.SetJobDBStore(s)
}
}
aps.SetUserServicesManager(usm)
a.agentPoolService.Store(aps)

View File

@@ -0,0 +1,267 @@
package application
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/LocalAI/core/services/distributed"
"github.com/mudler/LocalAI/core/services/jobs"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/storage"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/xlog"
"gorm.io/gorm"
)
// DistributedServices holds all services initialized for distributed mode.
type DistributedServices struct {
Nats *messaging.Client
Store storage.ObjectStore
Registry *nodes.NodeRegistry
Router *nodes.SmartRouter
Health *nodes.HealthMonitor
JobStore *jobs.JobStore
Dispatcher *jobs.Dispatcher
AgentStore *agents.AgentStore
AgentBridge *agents.EventBridge
DistStores *distributed.Stores
FileMgr *storage.FileManager
FileStager nodes.FileStager
ModelAdapter *nodes.ModelRouterAdapter
Unloader *nodes.RemoteUnloaderAdapter
shutdownOnce sync.Once
}
// Shutdown stops all distributed services in reverse initialization order.
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
func (ds *DistributedServices) Shutdown() {
if ds == nil {
return
}
ds.shutdownOnce.Do(func() {
if ds.Health != nil {
ds.Health.Stop()
}
if ds.Dispatcher != nil {
ds.Dispatcher.Stop()
}
if closer, ok := ds.Store.(io.Closer); ok {
closer.Close()
}
// AgentBridge has no Close method — its NATS subscriptions are cleaned up
// when the NATS client is closed below.
if ds.Nats != nil {
ds.Nats.Close()
}
xlog.Info("Distributed services shut down")
})
}
// initDistributed validates distributed mode prerequisites and initializes
// NATS, object storage, node registry, and instance identity.
// Returns nil if distributed mode is not enabled.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
if !cfg.Distributed.Enabled {
return nil, nil
}
xlog.Info("Distributed mode enabled — validating prerequisites")
// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
if err := cfg.Distributed.Validate(); err != nil {
return nil, err
}
// Validate PostgreSQL is configured (auth DB must be PostgreSQL for distributed mode)
if !cfg.Auth.Enabled {
return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
}
if !isPostgresURL(cfg.Auth.DatabaseURL) {
return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
}
// Generate instance ID if not set
if cfg.Distributed.InstanceID == "" {
cfg.Distributed.InstanceID = uuid.New().String()
}
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
// Connect to NATS
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
if err != nil {
return nil, fmt.Errorf("connecting to NATS: %w", err)
}
xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))
// Ensure NATS is closed if any subsequent initialization step fails.
success := false
defer func() {
if !success {
natsClient.Close()
}
}()
// Initialize object storage
var store storage.ObjectStore
if cfg.Distributed.StorageURL != "" {
if cfg.Distributed.StorageBucket == "" {
return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
}
s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
Endpoint: cfg.Distributed.StorageURL,
Region: cfg.Distributed.StorageRegion,
Bucket: cfg.Distributed.StorageBucket,
AccessKeyID: cfg.Distributed.StorageAccessKey,
SecretAccessKey: cfg.Distributed.StorageSecretKey,
ForcePathStyle: true, // required for MinIO
})
if err != nil {
return nil, fmt.Errorf("initializing S3 storage: %w", err)
}
xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
store = s3Store
} else {
// Fallback to filesystem storage in distributed mode (useful for single-node testing)
fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore")
if err != nil {
return nil, fmt.Errorf("initializing filesystem storage: %w", err)
}
xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore")
store = fsStore
}
// Initialize node registry (requires the auth DB which is PostgreSQL)
if authDB == nil {
return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
}
registry, err := nodes.NewNodeRegistry(authDB)
if err != nil {
return nil, fmt.Errorf("initializing node registry: %w", err)
}
xlog.Info("Node registry initialized")
// Collect SmartRouter option values; the router itself is created after all
// dependencies (including FileStager and Unloader) are ready.
var routerAuthToken string
if cfg.Distributed.RegistrationToken != "" {
routerAuthToken = cfg.Distributed.RegistrationToken
}
var routerGalleriesJSON string
if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
routerGalleriesJSON = string(galleriesJSON)
}
healthMon := nodes.NewHealthMonitor(registry, authDB,
cfg.Distributed.HealthCheckIntervalOrDefault(),
cfg.Distributed.StaleNodeThresholdOrDefault(),
routerAuthToken,
cfg.Distributed.PerModelHealthCheck,
)
// Initialize job store
jobStore, err := jobs.NewJobStore(authDB)
if err != nil {
return nil, fmt.Errorf("initializing job store: %w", err)
}
xlog.Info("Distributed job store initialized")
// Initialize job dispatcher
dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)
// Initialize agent store
agentStore, err := agents.NewAgentStore(authDB)
if err != nil {
return nil, fmt.Errorf("initializing agent store: %w", err)
}
xlog.Info("Distributed agent store initialized")
// Initialize agent event bridge
agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)
// Start observable persister — captures observable_update events from workers
// (which have no DB access) and persists them to PostgreSQL.
if err := agentBridge.StartObservablePersister(); err != nil {
xlog.Warn("Failed to start observable persister", "error", err)
} else {
xlog.Info("Observable persister started")
}
// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
distStores, err := distributed.InitStores(authDB)
if err != nil {
return nil, fmt.Errorf("initializing distributed stores: %w", err)
}
// Initialize file manager with local cache
cacheDir := cfg.DataPath + "/cache"
fileMgr, err := storage.NewFileManager(store, cacheDir)
if err != nil {
return nil, fmt.Errorf("initializing file manager: %w", err)
}
xlog.Info("File manager initialized", "cacheDir", cacheDir)
// Create FileStager for distributed file transfer
var fileStager nodes.FileStager
if cfg.Distributed.StorageURL != "" {
fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
xlog.Info("File stager initialized (S3+NATS)")
} else {
fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
node, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
if node.HTTPAddress == "" {
return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
}
return node.HTTPAddress, nil
}, cfg.Distributed.RegistrationToken)
xlog.Info("File stager initialized (HTTP direct transfer)")
}
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
// All dependencies ready — build SmartRouter with all options at once
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
Unloader: remoteUnloader,
FileStager: fileStager,
GalleriesJSON: routerGalleriesJSON,
AuthToken: routerAuthToken,
DB: authDB,
})
// Create ModelRouterAdapter to wire into ModelLoader
modelAdapter := nodes.NewModelRouterAdapter(router)
success = true
return &DistributedServices{
Nats: natsClient,
Store: store,
Registry: registry,
Router: router,
Health: healthMon,
JobStore: jobStore,
Dispatcher: dispatcher,
AgentStore: agentStore,
AgentBridge: agentBridge,
DistStores: distStores,
FileMgr: fileMgr,
FileStager: fileStager,
ModelAdapter: modelAdapter,
Unloader: remoteUnloader,
}, nil
}
func isPostgresURL(url string) bool {
return strings.HasPrefix(url, "postgres://") || strings.HasPrefix(url, "postgresql://")
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/edgevpn/pkg/node"
"github.com/mudler/xlog"
@@ -146,22 +146,14 @@ func (a *Application) RestartP2P() error {
return fmt.Errorf("P2P token is not set")
}
// Create new context for P2P
ctx, cancel := context.WithCancel(appConfig.Context)
a.p2pCtx = ctx
a.p2pCancel = cancel
// Get API address from config
address := appConfig.APIAddress
if address == "" {
address = "127.0.0.1:8080" // default
}
// Start P2P stack in a goroutine
// Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel
go func() {
if err := a.StartP2P(); err != nil {
xlog.Error("Failed to start P2P stack", "error", err)
cancel() // Cancel context on error
if a.p2pCancel != nil {
a.p2pCancel()
}
}
}()
xlog.Info("P2P stack restarted with new settings")
@@ -228,7 +220,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error {
continue
}
app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uuid.String(),
GalleryElementName: model,
Galleries: app.ApplicationConfig().Galleries,

View File

@@ -13,11 +13,15 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/jobs"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/storage"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/mudler/xlog"
)
@@ -101,7 +105,7 @@ func New(opts ...config.AppOption) (*Application, error) {
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
}
application.authDB = authDB
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL))
// Start session and expired API key cleanup goroutine
go func() {
@@ -123,12 +127,92 @@ func New(opts ...config.AppOption) (*Application, error) {
}()
}
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
if application.authDB != nil && application.agentJobService != nil {
dbJobStore, err := jobs.NewJobStore(application.authDB)
if err != nil {
xlog.Error("Failed to create job store for auth DB", "error", err)
} else {
application.agentJobService.SetDistributedJobStore(dbJobStore)
}
}
// Initialize distributed mode services (NATS, object storage, node registry)
distSvc, err := initDistributed(options, application.authDB)
if err != nil {
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
}
if distSvc != nil {
application.distributed = distSvc
// Wire remote model unloader so ShutdownModel works for remote nodes
// Uses NATS to tell serve-backend nodes to Free + kill their backend process
application.modelLoader.SetRemoteUnloader(distSvc.Unloader)
// Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode
application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter())
// Wire DistributedModelStore so shutdown/list/watchdog can find remote models
distStore := nodes.NewDistributedModelStore(
model.NewInMemoryModelStore(),
distSvc.Registry,
)
application.modelLoader.SetModelStore(distStore)
// Start health monitor
distSvc.Health.Start(options.Context)
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
// but does NOT set a workerFn — agent workers consume jobs from the same NATS queue.
// Wire model config loader so job events include model config for agent workers
distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader)
// Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched
if err := distSvc.Dispatcher.Start(options.Context); err != nil {
return nil, fmt.Errorf("starting job dispatcher: %w", err)
}
// Start ephemeral file cleanup
storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0)
// Wire distributed backends into AgentJobService (before Start)
if application.agentJobService != nil {
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
}
// Wire skill store into AgentPoolService (wired at pool start time via closure)
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
// Wire NATS and gallery store into GalleryService for cross-instance progress/cancel
if application.galleryService != nil {
application.galleryService.SetNATSClient(distSvc.Nats)
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
// Clean up stale in-progress operations from previous crashed instances
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
xlog.Warn("Failed to clean stale gallery operations", "error", err)
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
}
// Wire distributed model/backend managers so delete propagates to workers
application.galleryService.SetModelManager(
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
)
application.galleryService.SetBackendManager(
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
)
}
}
// Start AgentJobService (after distributed wiring so it knows whether to use local or NATS)
if application.agentJobService != nil {
if err := application.agentJobService.Start(options.Context); err != nil {
return nil, fmt.Errorf("starting agent job service: %w", err)
}
}
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
xlog.Error("error installing models", "error", err)
}
for _, backend := range options.ExternalBackends {
if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
xlog.Error("error installing external backend", "error", err)
}
}
@@ -154,13 +238,13 @@ func New(opts ...config.AppOption) (*Application, error) {
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
return nil, err
}
}
@@ -184,6 +268,7 @@ func New(opts ...config.AppOption) (*Application, error) {
go func() {
<-options.Context.Done()
xlog.Debug("Context canceled, shutting down")
application.distributed.Shutdown()
err := application.ModelLoader().StopAllGRPC()
if err != nil {
xlog.Error("error while stopping all grpc backends", "error", err)
@@ -207,7 +292,7 @@ func New(opts ...config.AppOption) (*Application, error) {
var backendErr error
_, backendErr = application.ModelLoader().Load(o...)
if backendErr != nil {
return nil, err
return nil, backendErr
}
}
}

View File

@@ -13,9 +13,9 @@ import (
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -27,7 +27,7 @@ type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
Logprobs *schema.Logprobs // Logprobs from the backend response
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
}
@@ -47,14 +47,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS)
if err != nil {
return nil, err
}
if !slices.Contains(modelNames, c.Name) {
modelName := c.Name
if modelName == "" {
modelName = c.Model
}
if !slices.Contains(modelNames, modelName) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
//return nil, err
@@ -252,12 +256,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
traceData := map[string]any{
"chat_template": c.TemplateConfig.Chat,
"chat_template": c.TemplateConfig.Chat,
"function_template": c.TemplateConfig.Functions,
"streaming": tokenCallback != nil,
"images_count": len(images),
"videos_count": len(videos),
"audios_count": len(audios),
"streaming": tokenCallback != nil,
"images_count": len(images),
"videos_count": len(videos),
"audios_count": len(audios),
}
if len(messages) > 0 {

View File

@@ -1,7 +1,7 @@
package backend
import (
"math/rand"
"math/rand/v2"
"os"
"path/filepath"
"strings"
@@ -86,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 {
}
if seed == config.RAND_SEED {
seed = rand.Int31()
seed = rand.Int32()
}
return seed

View File

@@ -4,8 +4,8 @@ import (
"time"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/model"
)

View File

@@ -1,40 +1,40 @@
package backend
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
// VAD runs voice-activity detection on the audio carried in the request,
// using the backend model selected by modelConfig. On a model-load failure
// the error is recorded via recordModelLoadFailure before being returned.
// The returned response always carries a non-nil (possibly empty) segment list.
func VAD(request *schema.VADRequest,
	ctx context.Context,
	ml *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	modelConfig config.ModelConfig) (*schema.VADResponse, error) {
	backend, err := ml.Load(ModelOptions(modelConfig, appConfig)...)
	if err != nil {
		recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
		return nil, err
	}

	res, err := backend.VAD(ctx, &proto.VADRequest{Audio: request.Audio})
	if err != nil {
		return nil, err
	}

	// Convert proto segments into the API schema representation.
	out := make([]schema.VADSegment, 0, len(res.Segments))
	for _, seg := range res.Segments {
		out = append(out, schema.VADSegment{Start: seg.Start, End: seg.End})
	}
	return &schema.VADResponse{Segments: out}, nil
}
package backend
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
// VAD runs voice-activity detection on the audio carried in the request,
// using the backend model selected by modelConfig. On a model-load failure
// the error is recorded via recordModelLoadFailure before being returned.
// The returned response always carries a non-nil (possibly empty) segment list.
func VAD(request *schema.VADRequest,
	ctx context.Context,
	ml *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	modelConfig config.ModelConfig) (*schema.VADResponse, error) {
	backend, err := ml.Load(ModelOptions(modelConfig, appConfig)...)
	if err != nil {
		recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
		return nil, err
	}

	res, err := backend.VAD(ctx, &proto.VADRequest{Audio: request.Audio})
	if err != nil {
		return nil, err
	}

	// Convert proto segments into the API schema representation.
	out := make([]schema.VADSegment, 0, len(res.Segments))
	for _, seg := range res.Segments {
		out = append(out, schema.VADSegment{Start: seg.Start, End: seg.End})
	}
	return &schema.VADResponse{Segments: out}, nil
}

View File

@@ -8,11 +8,11 @@ import (
"os/signal"
"syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAGI/core/state"
coreTypes "github.com/mudler/LocalAGI/core/types"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/xlog"
)
@@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
appConfig := r.buildAppConfig()
poolService, err := services.NewAgentPoolService(appConfig)
poolService, err := agentpool.NewAgentPoolService(appConfig)
if err != nil {
return fmt.Errorf("failed to create agent pool service: %w", err)
}

463
core/cli/agent_worker.go Normal file
View File

@@ -0,0 +1,463 @@
package cli
import (
"cmp"
"context"
"encoding/json"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/workerregistry"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/LocalAI/core/services/jobs"
mcpRemote "github.com/mudler/LocalAI/core/services/mcp"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/cogito"
"github.com/mudler/cogito/clients"
"github.com/mudler/xlog"
)
// AgentWorkerCMD starts a dedicated agent worker process for distributed mode.
// It registers with the frontend, subscribes to the NATS agent execution queue,
// and executes agent chats using cogito. The worker is a pure executor — it
// receives the full agent config and skills in the NATS job payload, so it
// does not need direct database access.
//
// Usage:
//
//	localai agent-worker --nats-url nats://... --register-to http://localai:8080
type AgentWorkerCMD struct {
	// NATS (required) — the worker cannot operate without a reachable NATS server.
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`

	// Registration (required) — the worker announces itself to the frontend
	// and keeps the registration alive via heartbeats.
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`

	// API access — used by job handlers to call back into LocalAI for inference.
	APIURL   string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"`
	APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"`

	// NATS subjects — subject/queue pair for load-balanced agent execution.
	Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
	Queue   string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`

	// Timeouts
	MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
// Run is the entrypoint for the agent-worker command. Startup sequence:
//
//  1. register with the frontend (with retry), optionally receiving a
//     provisioned API token,
//  2. start the heartbeat loop,
//  3. connect to NATS and wire the agent dispatcher plus the MCP tool,
//     discovery, CI-job, and backend-stop subscriptions,
//  4. block until SIGINT/SIGTERM, then shut down and deregister gracefully.
//
// Subscription failures abort startup; a failed cancel listener is only a
// warning (the worker still functions, it just cannot be cancelled remotely).
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)

	// Resolve API URL: explicit override wins, otherwise reuse the frontend URL.
	apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/"))

	// Register with frontend
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}
	nodeName := cmd.NodeName
	if nodeName == "" {
		// Hostname error is deliberately ignored; an empty hostname still
		// yields a usable (if generic) node name.
		hostname, _ := os.Hostname()
		nodeName = "agent-" + hostname
	}
	registrationBody := map[string]any{
		"name":      nodeName,
		"node_type": "agent",
	}
	if cmd.RegistrationToken != "" {
		registrationBody["token"] = cmd.RegistrationToken
	}
	// Up to 10 attempts; registration is mandatory — abort on failure.
	nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}
	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)

	// Use provisioned API token if none was set
	if cmd.APIToken == "" {
		cmd.APIToken = apiToken
	}

	// Start heartbeat. A parse failure leaves heartbeatInterval at zero,
	// which cmp.Or below replaces with the 10s default.
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)

	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()
	go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })

	// Connect to NATS
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()

	// Create event bridge for publishing results back via NATS
	eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID)

	// Start cancel listener; best-effort — failure is logged, not fatal.
	cancelSub, err := eventBridge.StartCancelListener()
	if err != nil {
		xlog.Warn("Failed to start cancel listener", "error", err)
	} else {
		defer cancelSub.Unsubscribe()
	}

	// Create and start the NATS dispatcher.
	// No ConfigProvider or SkillStore needed — config and skills arrive in the job payload.
	dispatcher := agents.NewNATSDispatcher(
		natsClient,
		eventBridge,
		nil, // no ConfigProvider: config comes in the enriched NATS payload
		apiURL, cmd.APIToken,
		cmd.Subject, cmd.Queue,
		0, // no concurrency limit (CLI worker)
	)
	if err := dispatcher.Start(shutdownCtx); err != nil {
		return fmt.Errorf("starting dispatcher: %w", err)
	}

	// Subscribe to MCP tool execution requests (load-balanced across workers).
	// The frontend routes model-level MCP tool calls here via NATS request-reply.
	if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
		handleMCPToolRequest(data, reply)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err)
	}

	// Subscribe to MCP discovery requests (load-balanced across workers).
	if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
		handleMCPDiscoveryRequest(data, reply)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err)
	}

	// Subscribe to MCP CI job execution (load-balanced across agent workers).
	// In distributed mode, MCP CI jobs are routed here because the frontend
	// cannot create MCP sessions (e.g., stdio servers using docker).
	mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout)
	if err != nil && cmd.MCPCIJobTimeout != "" {
		xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err)
	}
	mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout)
	if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) {
		handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err)
	}

	// Subscribe to backend stop events to clean up cached MCP sessions.
	// In the main application this is done via ml.OnModelUnload, but the agent
	// worker has no model loader — we listen for the NATS stop event instead.
	if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) {
		var req struct {
			Backend string `json:"backend"`
		}
		if json.Unmarshal(data, &req) == nil && req.Backend != "" {
			mcpTools.CloseMCPSessions(req.Backend)
		}
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err)
	}

	xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)

	// Wait for shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	// Ordered teardown: stop background goroutines, stop the dispatcher,
	// close MCP sessions, then tell the frontend we are leaving.
	xlog.Info("Shutting down agent worker")
	shutdownCancel() // stop heartbeat loop immediately
	dispatcher.Stop()
	mcpTools.CloseAllMCPSessions()
	regClient.GracefulDeregister(nodeID)
	return nil
}
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
// Sessions are created (or reused from the worker-local cache) from the
// serialized MCP config carried in the request, the available tools are
// discovered, and the named tool is executed. Every outcome — success or
// failure — is reported back through sendMCPToolReply.
func handleMCPToolRequest(data []byte, reply func([]byte)) {
	var toolReq mcpRemote.MCPToolRequest
	if err := json.Unmarshal(data, &toolReq); err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err))
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout)
	defer cancel()

	// Create/cache named MCP sessions from the config in the request.
	sessions, err := mcpTools.NamedSessionsFromMCPConfig(toolReq.ModelName, toolReq.RemoteServers, toolReq.StdioServers, nil)
	if err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err))
		return
	}

	// Discover the available tools so the call is routed to the right session.
	available, err := mcpTools.DiscoverMCPTools(ctx, sessions)
	if err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err))
		return
	}

	// Execute the tool; arguments are re-serialized to JSON for the call.
	encodedArgs, _ := json.Marshal(toolReq.Arguments)
	result, err := mcpTools.ExecuteMCPToolCall(ctx, available, toolReq.ToolName, string(encodedArgs))
	if err != nil {
		sendMCPToolReply(reply, "", err.Error())
		return
	}
	sendMCPToolReply(reply, result, "")
}
// sendMCPToolReply serializes an MCPToolResponse carrying either a result or
// an error message and hands it to the NATS reply callback. The marshal error
// is ignored: the response struct contains only strings and cannot fail to encode.
func sendMCPToolReply(reply func([]byte), result, errMsg string) {
	payload, _ := json.Marshal(mcpRemote.MCPToolResponse{Result: result, Error: errMsg})
	reply(payload)
}
// handleMCPDiscoveryRequest handles a NATS request-reply for MCP
// tool/prompt/resource discovery. It builds (or reuses cached) MCP sessions
// from the serialized config in the request, then replies with the server
// inventory plus the tool function schemas the frontend needs.
func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) {
	var discReq mcpRemote.MCPDiscoveryRequest
	if err := json.Unmarshal(data, &discReq); err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err))
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout)
	defer cancel()

	// Create/cache named MCP sessions from the provided config.
	sessions, err := mcpTools.NamedSessionsFromMCPConfig(discReq.ModelName, discReq.RemoteServers, discReq.StdioServers, nil)
	if err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err))
		return
	}

	// Enumerate servers together with their tools/prompts/resources.
	infos, err := mcpTools.ListMCPServers(ctx, sessions)
	if err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err))
		return
	}

	// Tool function schemas for the frontend; a discovery error here is
	// tolerated (best-effort) and simply yields an empty tool list.
	discovered, _ := mcpTools.DiscoverMCPTools(ctx, sessions)
	var toolDefs []mcpRemote.MCPToolDef
	for _, tool := range discovered {
		toolDefs = append(toolDefs, mcpRemote.MCPToolDef{
			ServerName: tool.ServerName,
			ToolName:   tool.ToolName,
			Function:   tool.Function,
		})
	}

	// Convert the internal server infos into the wire representation.
	var servers []mcpRemote.MCPServerInfo
	for _, info := range infos {
		servers = append(servers, mcpRemote.MCPServerInfo{
			Name:      info.Name,
			Type:      info.Type,
			Tools:     info.Tools,
			Prompts:   info.Prompts,
			Resources: info.Resources,
		})
	}

	sendMCPDiscoveryReply(reply, servers, toolDefs, "")
}
// sendMCPDiscoveryReply serializes an MCPDiscoveryResponse (servers, tool
// definitions, optional error message) and hands it to the NATS reply
// callback. The marshal error is ignored; the payload types are plain
// data structures that encode without failure.
func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) {
	payload, _ := json.Marshal(mcpRemote.MCPDiscoveryResponse{Servers: servers, Tools: tools, Error: errMsg})
	reply(payload)
}
// handleMCPCIJob processes an MCP CI job on the agent worker.
// The agent worker can create MCP sessions (has docker) and call the LocalAI API for inference.
//
// Flow: unmarshal the enriched job event → validate job/task/model config →
// publish "running" → create MCP sessions → expand the prompt template with
// cron/job parameters → run cogito tool execution, streaming progress over
// NATS → publish the final result ("completed" or "failed").
func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) {
	var evt jobs.JobEvent
	if err := json.Unmarshal(data, &evt); err != nil {
		xlog.Error("Failed to unmarshal job event", "error", err)
		return
	}
	job := evt.Job
	task := evt.Task
	// The event must be enriched by the dispatcher with the full job and task
	// payloads — the worker has no database access to look them up itself.
	if job == nil || task == nil {
		xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID)
		publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event")
		return
	}
	modelCfg := evt.ModelConfig
	if modelCfg == nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event")
		return
	}
	xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model)
	// Publish running status
	natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
		JobID: evt.JobID, Status: "running", Message: "Job started on agent worker",
	})
	// Parse MCP config — a model without any MCP servers cannot run a CI job.
	if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" {
		publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model")
		return
	}
	remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML()
	if err != nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err))
		return
	}
	// Create MCP sessions locally (agent worker has docker)
	sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio)
	if err != nil || len(sessions) == 0 {
		errMsg := "no working MCP servers found"
		if err != nil {
			errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err)
		}
		publishJobResult(natsClient, evt.JobID, "failed", "", errMsg)
		return
	}
	// Build prompt from template: substitute {{.key}} placeholders, first from
	// the task's cron parameters, then from the job's own parameters. Parse
	// failures are logged but not fatal (substitution is best-effort).
	prompt := task.Prompt
	if task.CronParametersJSON != "" {
		var params map[string]string
		if err := json.Unmarshal([]byte(task.CronParametersJSON), &params); err != nil {
			xlog.Warn("Failed to unmarshal parameters", "error", err)
		}
		for k, v := range params {
			prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
		}
	}
	if job.ParametersJSON != "" {
		var params map[string]string
		if err := json.Unmarshal([]byte(job.ParametersJSON), &params); err != nil {
			xlog.Warn("Failed to unmarshal parameters", "error", err)
		}
		for k, v := range params {
			prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
		}
	}
	// Create LLM client pointing back to the frontend API
	llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL)
	// Build cogito options; the job is bounded by jobTimeout and also
	// cancelled if the worker itself shuts down (shutdownCtx).
	ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout)
	defer cancel()
	// Update job status to running in DB
	publishJobStatus(natsClient, evt.JobID, "running", "")
	// Buffer stream tokens and flush as complete blocks — avoids publishing a
	// NATS progress event per token.
	var reasoningBuf, contentBuf strings.Builder
	var lastStreamType cogito.StreamEventType
	flushStreamBuf := func() {
		if reasoningBuf.Len() > 0 {
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(),
			})
			reasoningBuf.Reset()
		}
		if contentBuf.Len() > 0 {
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(),
			})
			contentBuf.Reset()
		}
	}
	cogitoOpts := modelCfg.BuildCogitoOptions()
	cogitoOpts = append(cogitoOpts,
		cogito.WithContext(ctx),
		cogito.WithMCPs(sessions...),
		cogito.WithStatusCallback(func(status string) {
			flushStreamBuf()
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "status", TraceContent: status,
			})
		}),
		cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
			flushStreamBuf()
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result),
			})
		}),
		cogito.WithStreamCallback(func(ev cogito.StreamEvent) {
			// Flush if stream type changed (e.g., reasoning → content)
			if ev.Type != lastStreamType {
				flushStreamBuf()
				lastStreamType = ev.Type
			}
			switch ev.Type {
			case cogito.StreamEventReasoning:
				reasoningBuf.WriteString(ev.Content)
			case cogito.StreamEventContent:
				contentBuf.WriteString(ev.Content)
			case cogito.StreamEventToolCall:
				// Tool calls are published immediately, not buffered.
				natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
					JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs),
				})
			}
		}),
	)
	// Execute via cogito
	fragment := cogito.NewEmptyFragment()
	fragment = fragment.AddMessage("user", prompt)
	f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...)
	flushStreamBuf() // flush any remaining buffered tokens
	if err != nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err))
		return
	}
	// The job result is the content of the conversation's last message, if any.
	result := ""
	if msg := f.LastMessage(); msg != nil {
		result = msg.Content
	}
	publishJobResult(natsClient, evt.JobID, "completed", result, "")
	xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result))
}
// publishJobStatus publishes a job status update over NATS via the shared
// jobs helper. Thin wrapper kept for call-site readability.
func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) {
	jobs.PublishJobProgress(nc, jobID, status, message)
}
// publishJobResult publishes a terminal job result (status, result payload,
// optional error message) over NATS via the shared jobs helper.
func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) {
	jobs.PublishJobResult(nc, jobID, status, result, errMsg)
}

View File

@@ -8,7 +8,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
@@ -103,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState)
err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -15,7 +15,9 @@ var CLI struct {
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
P2PWorker worker.Worker `cmd:"" name:"p2p-worker" help:"Run workers to distribute workload via p2p (llama.cpp-only)"`
Worker WorkerCMD `cmd:"" help:"Start a worker for distributed mode (generic, backend-agnostic)"`
AgentWorker AgentWorkerCMD `cmd:"" name:"agent-worker" help:"Start an agent worker for distributed mode (executes agent chats via NATS)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`

View File

@@ -186,9 +186,9 @@ _local_ai_completions()
}
subcmds := []string{}
for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2)
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
subcmds = append(subcmds, parts[1])
parent, child, found := strings.Cut(sub.fullName, " ")
if found && parent == cmd.name && !strings.Contains(child, " ") {
subcmds = append(subcmds, child)
}
}
if len(subcmds) > 0 {
@@ -279,8 +279,8 @@ _local_ai() {
// Check for subcommands
subcmds := []commandInfo{}
for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2)
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
parent, child, found := strings.Cut(sub.fullName, " ")
if found && parent == cmd.name && !strings.Contains(child, " ") {
subcmds = append(subcmds, sub)
}
}
@@ -289,11 +289,11 @@ _local_ai() {
sb.WriteString(" local -a subcmds\n")
sb.WriteString(" subcmds=(\n")
for _, sub := range subcmds {
parts := strings.SplitN(sub.fullName, " ", 2)
_, child, _ := strings.Cut(sub.fullName, " ")
help := strings.ReplaceAll(sub.help, "'", "'\\''")
help = strings.ReplaceAll(help, "[", "\\[")
help = strings.ReplaceAll(help, "]", "\\]")
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help))
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", child, help))
}
sb.WriteString(" )\n")
sb.WriteString(" _describe -t commands 'subcommands' subcmds\n")
@@ -372,10 +372,10 @@ func generateFishCompletion(app *kong.Application) string {
// Subcommands
for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2)
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
parent, child, found := strings.Cut(sub.fullName, " ")
if found && parent == cmd.name && !strings.Contains(child, " ") {
help := strings.ReplaceAll(sub.help, "'", "\\'")
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help))
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, child, help))
}
}

View File

@@ -9,8 +9,8 @@ import (
func getTestApp() *kong.Application {
var testCLI struct {
Run struct{} `cmd:"" help:"Run the server"`
Models struct {
Run struct{} `cmd:"" help:"Run the server"`
Models struct {
List struct{} `cmd:"" help:"List models"`
Install struct{} `cmd:"" help:"Install a model"`
} `cmd:"" help:"Manage models"`

View File

@@ -8,7 +8,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
@@ -80,7 +80,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err
}
galleryService := services.NewGalleryService(&config.ApplicationConfig{
galleryService := galleryop.NewGalleryService(&config.ApplicationConfig{
SystemState: systemState,
}, model.NewModelLoader(systemState))
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)

View File

@@ -44,9 +44,9 @@ type RunCMD struct {
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
@@ -100,7 +100,7 @@ type RunCMD struct {
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
// Agent Pool (LocalAGI)
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"`
AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"`
AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"`
@@ -109,17 +109,17 @@ type RunCMD struct {
AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"`
AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"`
AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"`
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
// Authentication
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
@@ -136,6 +136,18 @@ type RunCMD struct {
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
// Distributed / Horizontal Scaling
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
Version bool
}
@@ -210,6 +222,38 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}),
}
// Distributed mode
if r.Distributed {
opts = append(opts, config.EnableDistributed)
}
if r.InstanceID != "" {
opts = append(opts, config.WithDistributedInstanceID(r.InstanceID))
}
if r.NatsURL != "" {
opts = append(opts, config.WithNatsURL(r.NatsURL))
}
if r.StorageURL != "" {
opts = append(opts, config.WithStorageURL(r.StorageURL))
}
if r.StorageBucket != "" {
opts = append(opts, config.WithStorageBucket(r.StorageBucket))
}
if r.StorageRegion != "" {
opts = append(opts, config.WithStorageRegion(r.StorageRegion))
}
if r.StorageAccessKey != "" {
opts = append(opts, config.WithStorageAccessKey(r.StorageAccessKey))
}
if r.StorageSecretKey != "" {
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
}
if r.RegistrationToken != "" {
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
}
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}
if r.DisableMetricsEndpoint {
opts = append(opts, config.DisableMetricsEndpoint)
}
@@ -218,10 +262,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.DisableRuntimeSettings)
}
if r.EnableTracing {
opts = append(opts, config.EnableTracing)
}
if r.EnableTracing {
opts = append(opts, config.EnableTracing)
}
@@ -479,6 +519,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if err := app.ModelLoader().StopAllGRPC(); err != nil {
xlog.Error("error while stopping all grpc backends", "error", err)
}
// Clean up distributed services (idempotent — safe if already called)
if d := app.Distributed(); d != nil {
d.Shutdown()
}
})
// Start the agent pool after the HTTP server is listening, because

View File

@@ -12,7 +12,6 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/format"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog"
@@ -80,7 +79,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
switch t.ResponseFormat {
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat))
case schema.TranscriptionResponseFormatJson:
tr.Segments = nil
fallthrough

897
core/cli/worker.go Normal file
View File

@@ -0,0 +1,897 @@
package cli
import (
"cmp"
"context"
"encoding/json"
"fmt"
"maps"
"net"
"os"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"syscall"
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/workerregistry"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/storage"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsysinfo"
process "github.com/mudler/go-processmanager"
"github.com/mudler/xlog"
)
// isPathAllowed reports whether path is located inside (or equal to) one of
// the allowed directories.
//
// Both the candidate path and each allowed directory are converted to
// absolute paths and have symlinks evaluated when they exist. Resolving only
// one side (as a naive implementation would) breaks when an allowed root is
// itself a symlink — e.g. /tmp on macOS — because the resolved candidate can
// never share a prefix with the unresolved root. Paths that do not exist yet
// fall back to their cleaned absolute form.
func isPathAllowed(path string, allowedDirs []string) bool {
	resolved, ok := resolveForCompare(path)
	if !ok {
		return false
	}
	for _, dir := range allowedDirs {
		// An empty entry would resolve to the current working directory
		// and silently allow everything beneath it — skip it.
		if dir == "" {
			continue
		}
		absDir, ok := resolveForCompare(dir)
		if !ok {
			continue
		}
		if resolved == absDir || strings.HasPrefix(resolved, absDir+string(filepath.Separator)) {
			return true
		}
	}
	return false
}

// resolveForCompare converts p to an absolute path and evaluates symlinks
// when possible. It returns ok=false only when the absolute path cannot be
// computed at all.
func resolveForCompare(p string) (string, bool) {
	abs, err := filepath.Abs(p)
	if err != nil {
		return "", false
	}
	if real, err := filepath.EvalSymlinks(abs); err == nil {
		return real, true
	}
	// Path may not exist yet; use the absolute path.
	return abs, true
}
// WorkerCMD starts a generic worker process for distributed mode.
// Workers are backend-agnostic — they wait for backend.install NATS events
// from the SmartRouter to install and start the required backend.
//
// NATS is required. The worker acts as a process supervisor:
//   - Receives backend.install → installs backend from gallery, starts gRPC process, replies success
//   - Receives backend.stop → stops the gRPC process
//   - Receives stop → full shutdown (deregister + exit)
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
//
// Field semantics are carried in the kong `help` tags below; only
// non-obvious relationships are annotated here.
type WorkerCMD struct {
	// Addr's port also seeds the supervisor's dynamic port allocation:
	// each backend process gets its own port growing upward from it.
	Addr               string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
	BackendsPath       string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
	BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
	BackendGalleries   string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
	ModelsPath         string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
	// HTTP file transfer
	HTTPAddr          string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
	AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`
	// Registration (required)
	AdvertiseAddr     string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	// RegistrationToken is reused in three roles: frontend auth, the HTTP
	// file-transfer server token, and the gRPC auth token exported to
	// backend processes via grpc.AuthTokenEnvVar.
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
	// NATS (required)
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
	// S3 storage for distributed file transfer; file-staging NATS subjects
	// are only subscribed when StorageURL is non-empty.
	StorageURL       string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
	StorageBucket    string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"`
	StorageRegion    string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"`
	StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
	StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
}
// Run is the worker entry point. Startup order: system state → model loader
// → frontend registration (with retry) → HTTP file-transfer server → NATS →
// heartbeat goroutine → backend supervisor subscriptions → optional S3 file
// staging. It then blocks on sigCh until SIGINT/SIGTERM (or a NATS-driven
// shutdown sent on the same channel) arrives.
//
// Shutdown order: cancel heartbeat → deregister from frontend → stop all
// backend processes → stop the HTTP file server.
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting worker", "addr", cmd.Addr)
	systemState, err := system.GetSystemState(
		system.WithModelPath(cmd.ModelsPath),
		system.WithBackendPath(cmd.BackendsPath),
		system.WithBackendSystemPath(cmd.BackendsSystemPath),
	)
	if err != nil {
		return fmt.Errorf("getting system state: %w", err)
	}
	ml := model.NewModelLoader(systemState)
	ml.SetBackendLoggingEnabled(true)
	// Register already-installed backends
	gallery.RegisterBackends(systemState, ml)
	// Parse galleries config; a parse failure is non-fatal because the
	// install request can carry its own gallery list (see installBackend).
	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil {
		xlog.Warn("Failed to parse backend galleries", "error", err)
	}
	// Self-registration with frontend (with retry)
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}
	// registrationBody is defined elsewhere in this file — presumably it
	// collects node name/addresses/capabilities; confirm at its definition.
	registrationBody := cmd.registrationBody()
	nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("failed to register with frontend: %w", err)
	}
	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	// On parse failure heartbeatInterval is zero; cmp.Or substitutes 10s.
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()
	// Start HTTP file transfer server. Staging and data dirs are siblings
	// of the models dir (ModelsPath/../staging, ModelsPath/../data).
	httpAddr := cmd.resolveHTTPAddr()
	stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging")
	dataDir := filepath.Join(cmd.ModelsPath, "..", "data")
	httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs())
	if err != nil {
		return fmt.Errorf("starting HTTP file transfer server: %w", err)
	}
	// Connect to NATS
	xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL))
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		// Tear down the already-started HTTP server before bailing.
		nodes.ShutdownFileTransferServer(httpServer)
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()
	// Start heartbeat goroutine (after NATS is connected so IsConnected check works)
	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()
		for {
			select {
			case <-shutdownCtx.Done():
				return
			case <-ticker.C:
				// Skip (don't fail) heartbeats while NATS is down.
				if !natsClient.IsConnected() {
					xlog.Warn("Skipping heartbeat: NATS disconnected")
					continue
				}
				body := cmd.heartbeatBody()
				if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil {
					xlog.Warn("Heartbeat failed", "error", err)
				}
			}
		}
	}()
	// Process supervisor — manages multiple backend gRPC processes on different ports
	basePort := 50051
	if cmd.Addr != "" {
		// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
		if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
			if p, err := strconv.Atoi(portStr); err == nil {
				basePort = p
			}
		}
	}
	// Buffered so NATS stop handler can send without blocking
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	// Set the registration token once before any backends are started
	if cmd.RegistrationToken != "" {
		os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken)
	}
	supervisor := &backendSupervisor{
		cmd:         cmd,
		ml:          ml,
		systemState: systemState,
		galleries:   galleries,
		nodeID:      nodeID,
		nats:        natsClient,
		sigCh:       sigCh,
		processes:   make(map[string]*backendProcess),
		nextPort:    basePort,
	}
	supervisor.subscribeLifecycleEvents()
	// Subscribe to file staging NATS subjects if S3 is configured
	if cmd.StorageURL != "" {
		if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil {
			// Non-fatal: the worker still serves backend lifecycle events.
			xlog.Error("Failed to subscribe to file staging subjects", "error", err)
		}
	}
	xlog.Info("Worker ready, waiting for backend.install events")
	<-sigCh
	xlog.Info("Shutting down worker")
	shutdownCancel() // stop heartbeat loop immediately
	regClient.GracefulDeregister(nodeID)
	supervisor.stopAllBackends()
	nodes.ShutdownFileTransferServer(httpServer)
	return nil
}
// subscribeFileStaging subscribes to NATS file staging subjects for this node.
//
// It wires four request-reply subjects backed by an S3 store plus a local
// cache directory (ModelsPath/../cache):
//   - files.ensure:  download an S3 key locally, reply {"local_path": ...}
//   - files.stage:   upload a local file to S3, reply {"key": ...}
//   - files.temp:    allocate an empty temp file, reply {"local_path": ...}
//   - files.listdir: list files under a key prefix, reply {"files": [...]}
//
// All replies are JSON; failures reply {"error": ...} rather than dropping
// the request. Returns an error only when the S3 store or file manager
// cannot be initialized.
func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error {
	// Create FileManager with same S3 config as the frontend
	// TODO: propagate a caller-provided context once WorkerCMD carries one
	s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
		Endpoint:        cmd.StorageURL,
		Region:          cmd.StorageRegion,
		Bucket:          cmd.StorageBucket,
		AccessKeyID:     cmd.StorageAccessKey,
		SecretAccessKey: cmd.StorageSecretKey,
		ForcePathStyle:  true,
	})
	if err != nil {
		return fmt.Errorf("initializing S3 store: %w", err)
	}
	cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache")
	fm, err := storage.NewFileManager(s3Store, cacheDir)
	if err != nil {
		return fmt.Errorf("initializing file manager: %w", err)
	}
	// Subscribe: files.ensure — download S3 key to local, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			Key string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}
		localPath, err := fm.Download(context.Background(), req.Key)
		if err != nil {
			xlog.Error("File ensure failed", "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}
		xlog.Debug("File ensured locally", "key", req.Key, "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})
	// Subscribe: files.stage — upload local path to S3, reply with key
	natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			LocalPath string `json:"local_path"`
			Key       string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}
		// Only allow uploads from the cache dir and (when set) the models
		// dir — the request's local_path is untrusted input.
		allowedDirs := []string{cacheDir}
		if cmd.ModelsPath != "" {
			allowedDirs = append(allowedDirs, cmd.ModelsPath)
		}
		if !isPathAllowed(req.LocalPath, allowedDirs) {
			replyJSON(reply, map[string]string{"error": "path outside allowed directories"})
			return
		}
		if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil {
			xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}
		xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key)
		replyJSON(reply, map[string]string{"key": req.Key})
	})
	// Subscribe: files.temp — allocate temp file, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) {
		tmpDir := filepath.Join(cacheDir, "staging-tmp")
		if err := os.MkdirAll(tmpDir, 0750); err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)})
			return
		}
		f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp")
		if err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)})
			return
		}
		// Close immediately — only the path is handed back; the requester
		// writes to it later (e.g. via files.stage).
		localPath := f.Name()
		f.Close()
		xlog.Debug("Allocated temp file", "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})
	// Subscribe: files.listdir — list files in a local directory, reply with relative paths
	natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			KeyPrefix string `json:"key_prefix"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]any{"error": "invalid request"})
			return
		}
		// Resolve key prefix to local directory: model-prefixed keys map
		// into ModelsPath, data-prefixed keys into ../data, everything
		// else into the cache dir.
		dirPath := filepath.Join(cacheDir, req.KeyPrefix)
		if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" {
			dirPath = filepath.Join(cmd.ModelsPath, rel)
		} else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok {
			dirPath = filepath.Join(cacheDir, "..", "data", rel)
		}
		// Sanitize to prevent directory traversal via crafted key_prefix
		dirPath = filepath.Clean(dirPath)
		cleanCache := filepath.Clean(cacheDir)
		cleanModels := filepath.Clean(cmd.ModelsPath)
		cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data"))
		// Note: Clean("") == "." — the cleanModels != "." guard below keeps
		// an unset ModelsPath from matching relative paths.
		if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) ||
			dirPath == cleanCache ||
			(cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) ||
			dirPath == cleanModels ||
			strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) ||
			dirPath == cleanData) {
			replyJSON(reply, map[string]any{"error": "invalid key prefix"})
			return
		}
		var files []string
		// Walk errors (including a missing dir) are swallowed: the reply is
		// simply an empty/partial file list.
		filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				return nil
			}
			if !d.IsDir() {
				rel, err := filepath.Rel(dirPath, path)
				if err == nil {
					files = append(files, rel)
				}
			}
			return nil
		})
		xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files))
		replyJSON(reply, map[string]any{"files": files})
	})
	xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID)
	return nil
}
// replyJSON serializes v to JSON and hands the payload to the reply
// callback. When marshalling fails, a fixed error payload is sent instead
// so the requester always receives a response.
func replyJSON(reply func([]byte), v any) {
	payload, err := json.Marshal(v)
	if err == nil {
		reply(payload)
		return
	}
	xlog.Error("Failed to marshal NATS reply", "error", err)
	reply([]byte(`{"error":"internal marshal error"}`))
}
// backendProcess represents a single gRPC backend process.
type backendProcess struct {
	proc    *process.Process
	backend string // supervisor map key (backend name, or model ID — see installBackend)
	addr    string // gRPC address (host:port)
}
// backendSupervisor manages multiple backend gRPC processes on different ports.
// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port.
//
// All mutable state (processes, nextPort, freePorts) is guarded by mu;
// the remaining fields are set once at construction and read-only after.
type backendSupervisor struct {
	cmd         *WorkerCMD
	ml          *model.ModelLoader
	systemState *system.SystemState
	galleries   []config.Gallery
	nodeID      string
	nats        messaging.MessagingClient
	sigCh       chan<- os.Signal // send shutdown signal instead of os.Exit
	mu          sync.Mutex
	processes   map[string]*backendProcess // key: backend name
	nextPort    int                        // next available port for new backends
	freePorts   []int                      // ports freed by stopBackend, reused before nextPort
}
// startBackend starts a gRPC backend process on a dynamically allocated port.
// Returns the gRPC address.
//
// If a live process already exists for this key, its address is returned
// without starting anything; a dead entry is discarded and restarted.
// After launching, it polls the process's gRPC health endpoint for up to
// ~4s (20 × 200ms); if the server never reports healthy it still returns
// the address (best-effort) rather than failing.
func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) {
	s.mu.Lock()
	// Already running?
	if bp, ok := s.processes[backend]; ok {
		if bp.proc != nil && bp.proc.IsAlive() {
			s.mu.Unlock()
			return bp.addr, nil
		}
		// Process died — clean up and restart
		xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend)
		delete(s.processes, backend)
	}
	// Allocate port — recycle freed ports first, then grow upward from basePort
	var port int
	if len(s.freePorts) > 0 {
		port = s.freePorts[len(s.freePorts)-1]
		s.freePorts = s.freePorts[:len(s.freePorts)-1]
	} else {
		port = s.nextPort
		s.nextPort++
	}
	// Bind on all interfaces; health-check locally via loopback.
	bindAddr := fmt.Sprintf("0.0.0.0:%d", port)
	clientAddr := fmt.Sprintf("127.0.0.1:%d", port)
	proc, err := s.ml.StartProcess(backendPath, backend, bindAddr)
	if err != nil {
		s.mu.Unlock()
		return "", fmt.Errorf("starting backend process: %w", err)
	}
	s.processes[backend] = &backendProcess{
		proc:    proc,
		backend: backend,
		addr:    clientAddr,
	}
	xlog.Info("Backend process started", "backend", backend, "addr", clientAddr)
	// Capture reference before unlocking for race-safe health check.
	// Another goroutine could stopBackend and recycle the port while we poll.
	bp := s.processes[backend]
	s.mu.Unlock()
	// Wait for the gRPC server to be ready
	client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken)
	for range 20 {
		time.Sleep(200 * time.Millisecond)
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		if ok, _ := client.HealthCheck(ctx); ok {
			cancel()
			// Verify the process wasn't stopped/replaced while health-checking
			// (pointer identity check against the entry we created above).
			s.mu.Lock()
			currentBP, exists := s.processes[backend]
			s.mu.Unlock()
			if !exists || currentBP != bp {
				return "", fmt.Errorf("backend %s was stopped during startup", backend)
			}
			xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr)
			return clientAddr, nil
		}
		cancel()
	}
	xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr)
	return clientAddr, nil
}
// stopBackend terminates the gRPC process associated with backend, if one
// exists. Bookkeeping (map removal, port recycling) happens under the
// mutex; the best-effort gRPC Free() call and the process stop are done
// afterwards so no network I/O is performed while holding the lock.
func (s *backendSupervisor) stopBackend(backend string) {
	s.mu.Lock()
	entry, found := s.processes[backend]
	if !found || entry.proc == nil {
		s.mu.Unlock()
		return
	}
	delete(s.processes, backend)
	// Recycle the port so future startBackend calls can reuse it.
	if _, portStr, splitErr := net.SplitHostPort(entry.addr); splitErr == nil {
		if port, convErr := strconv.Atoi(portStr); convErr == nil {
			s.freePorts = append(s.freePorts, port)
		}
	}
	s.mu.Unlock()

	// Ask the backend to release its resources before killing the process.
	client := grpc.NewClientWithToken(entry.addr, false, nil, false, s.cmd.RegistrationToken)
	if freer, ok := client.(interface{ Free(context.Context) error }); ok {
		xlog.Debug("Calling Free() before stopping backend", "backend", backend)
		if err := freer.Free(context.Background()); err != nil {
			xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
		}
	}

	xlog.Info("Stopping backend process", "backend", backend, "addr", entry.addr)
	if err := entry.proc.Stop(); err != nil {
		xlog.Error("Error stopping backend process", "backend", backend, "error", err)
	}
}
// stopAllBackends shuts down every managed backend process. The key set is
// snapshotted under the lock so stopBackend — which takes the mutex itself —
// can be invoked without holding it.
func (s *backendSupervisor) stopAllBackends() {
	s.mu.Lock()
	names := make([]string, 0, len(s.processes))
	for name := range s.processes {
		names = append(names, name)
	}
	s.mu.Unlock()
	for _, name := range names {
		s.stopBackend(name)
	}
}
// isRunning reports whether the named backend currently has a live process.
func (s *backendSupervisor) isRunning(backend string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.processes[backend]
	if !found || entry.proc == nil {
		return false
	}
	return entry.proc.IsAlive()
}
// getAddr returns the gRPC address registered for backend, or the empty
// string when no process entry exists.
func (s *backendSupervisor) getAddr(backend string) string {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.processes[backend]
	if !found {
		return ""
	}
	return entry.addr
}
// installBackend handles the backend.install flow:
// 1. If already running for this model, return existing address
// 2. Install backend from gallery (if not already installed)
// 3. Find backend binary
// 4. Start gRPC process on a new port
// Returns the gRPC address of the backend process.
func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) {
	// Process key: use ModelID if provided (per-model process), else backend name
	processKey := req.ModelID
	if processKey == "" {
		processKey = req.Backend
	}
	// If already running for this model, return its address
	if addr := s.getAddr(processKey); addr != "" {
		xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr)
		return addr, nil
	}
	// Parse galleries from request (override local config if provided).
	// A malformed request-supplied list is silently ignored in favour of
	// the locally configured galleries.
	galleries := s.galleries
	if req.BackendGalleries != "" {
		var reqGalleries []config.Gallery
		if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil {
			galleries = reqGalleries
		}
	}
	// Try to find the backend binary
	backendPath := s.findBackend(req.Backend)
	if backendPath == "" {
		// Backend not found locally — try auto-installing from gallery
		xlog.Info("Backend not found locally, attempting gallery install", "backend", req.Backend)
		if err := gallery.InstallBackendFromGallery(
			context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, false,
		); err != nil {
			return "", fmt.Errorf("installing backend from gallery: %w", err)
		}
		// Re-register after install and retry
		gallery.RegisterBackends(s.systemState, s.ml)
		backendPath = s.findBackend(req.Backend)
	}
	if backendPath == "" {
		return "", fmt.Errorf("backend %q not found after install attempt", req.Backend)
	}
	xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey)
	// Start the gRPC process on a new port (keyed by model, not just backend)
	return s.startBackend(processKey, backendPath)
}
// findBackend resolves the on-disk path of a backend binary. A backend
// registered with the model loader as an external backend takes priority;
// otherwise the user backends path and the system backends path are probed,
// each in both the flat (<root>/<backend>) and nested
// (<root>/<backend>/<backend>) layout. Returns "" when no regular file is
// found anywhere.
func (s *backendSupervisor) findBackend(backend string) string {
	isRegularFile := func(p string) bool {
		info, err := os.Stat(p)
		return err == nil && !info.IsDir()
	}
	if uri := s.ml.GetExternalBackend(backend); uri != "" && isRegularFile(uri) {
		return uri
	}
	for _, root := range []string{s.cmd.BackendsPath, s.cmd.BackendsSystemPath} {
		for _, candidate := range []string{
			filepath.Join(root, backend),
			filepath.Join(root, backend, backend),
		} {
			if isRegularFile(candidate) {
				return candidate
			}
		}
	}
	return ""
}
// subscribeLifecycleEvents subscribes to NATS backend lifecycle events.
func (s *backendSupervisor) subscribeLifecycleEvents() {
// backend.install — install backend + start gRPC process (request-reply)
s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) {
xlog.Info("Received NATS backend.install event")
var req messaging.BackendInstallRequest
if err := json.Unmarshal(data, &req); err != nil {
resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
replyJSON(reply, resp)
return
}
addr, err := s.installBackend(req)
if err != nil {
xlog.Error("Failed to install backend via NATS", "error", err)
resp := messaging.BackendInstallReply{Success: false, Error: err.Error()}
replyJSON(reply, resp)
return
}
// Return the gRPC address so the router knows which port to use
advertiseAddr := addr
if s.cmd.AdvertiseAddr != "" {
// Replace 0.0.0.0 with the advertised host but keep the dynamic port
_, port, _ := net.SplitHostPort(addr)
advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
advertiseAddr = net.JoinHostPort(advertiseHost, port)
}
resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
replyJSON(reply, resp)
})
// backend.stop — stop a specific backend process
s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), func(data []byte) {
// Try to parse backend name from payload; if empty, stop all
var req struct {
Backend string `json:"backend"`
}
if json.Unmarshal(data, &req) == nil && req.Backend != "" {
xlog.Info("Received NATS backend.stop event", "backend", req.Backend)
s.stopBackend(req.Backend)
} else {
xlog.Info("Received NATS backend.stop event (all)")
s.stopAllBackends()
}
})
// backend.delete — stop backend + delete files (request-reply)
s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
xlog.Info("Received NATS backend.delete event")
var req messaging.BackendDeleteRequest
if err := json.Unmarshal(data, &req); err != nil {
resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
replyJSON(reply, resp)
return
}
// Stop if running this backend
if s.isRunning(req.Backend) {
s.stopBackend(req.Backend)
}
// Delete the backend files
if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil {
xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err)
resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()}
replyJSON(reply, resp)
return
}
// Re-register backends after deletion
gallery.RegisterBackends(s.systemState, s.ml)
resp := messaging.BackendDeleteReply{Success: true}
replyJSON(reply, resp)
})
// backend.list — list installed backends (request-reply)
s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) {
xlog.Info("Received NATS backend.list event")
backends, err := gallery.ListSystemBackends(s.systemState)
if err != nil {
resp := messaging.BackendListReply{Error: err.Error()}
replyJSON(reply, resp)
return
}
var infos []messaging.NodeBackendInfo
for name, b := range backends {
info := messaging.NodeBackendInfo{
Name: name,
IsSystem: b.IsSystem,
IsMeta: b.IsMeta,
}
if b.Metadata != nil {
info.InstalledAt = b.Metadata.InstalledAt
info.GalleryURL = b.Metadata.GalleryURL
}
infos = append(infos, info)
}
resp := messaging.BackendListReply{Backends: infos}
replyJSON(reply, resp)
})
// model.unload — call gRPC Free() to release GPU memory (request-reply)
s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) {
xlog.Info("Received NATS model.unload event")
var req messaging.ModelUnloadRequest
if err := json.Unmarshal(data, &req); err != nil {
resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
replyJSON(reply, resp)
return
}
// Find the backend address for this model's backend type
// The request includes an Address field if the router knows which process to target
targetAddr := req.Address
if targetAddr == "" {
// Fallback: try all running backends
s.mu.Lock()
for _, bp := range s.processes {
targetAddr = bp.addr
break
}
s.mu.Unlock()
}
if targetAddr != "" {
// Best-effort gRPC Free()
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
if err := freeFunc.Free(context.Background()); err != nil {
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
}
}
}
resp := messaging.ModelUnloadReply{Success: true}
replyJSON(reply, resp)
})
// model.delete — remove model files from disk (request-reply)
s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) {
xlog.Info("Received NATS model.delete event")
var req messaging.ModelDeleteRequest
if err := json.Unmarshal(data, &req); err != nil {
replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"})
return
}
if err := gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil {
xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err)
replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()})
return
}
replyJSON(reply, messaging.ModelDeleteReply{Success: true})
})
// stop — trigger the normal shutdown path via sigCh so deferred cleanup runs
s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) {
xlog.Info("Received NATS stop event — signaling shutdown")
select {
case s.sigCh <- syscall.SIGTERM:
default:
xlog.Debug("Shutdown already signaled, ignoring duplicate stop")
}
})
}
// advertiseAddr returns the address the frontend should use to reach this node.
// An explicit --advertise-addr always wins. Otherwise, when the configured bind
// address uses a wildcard host (0.0.0.0, ::, or empty), the OS hostname is
// substituted so remote peers receive a routable address.
func (cmd *WorkerCMD) advertiseAddr() string {
	if cmd.AdvertiseAddr != "" {
		return cmd.AdvertiseAddr
	}
	// net.SplitHostPort correctly handles IPv6 literals such as "[::]:9000",
	// which a naive first-colon split would mangle.
	host, port, err := net.SplitHostPort(cmd.Addr)
	if err == nil && (host == "" || host == "0.0.0.0" || host == "::") {
		if hostname, herr := os.Hostname(); herr == nil {
			return net.JoinHostPort(hostname, port)
		}
	}
	return cmd.Addr
}
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports
// which grow upward from basePort.
func (cmd *WorkerCMD) resolveHTTPAddr() string {
	if cmd.HTTPAddr != "" {
		return cmd.HTTPAddr
	}
	host, port, ok := strings.Cut(cmd.Addr, ":")
	if !ok {
		return "0.0.0.0:50050"
	}
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum <= 1 {
		// A malformed or too-small port previously slipped through the ignored
		// Atoi error and produced an invalid address like "host:-1". Fall back
		// to the default file-transfer port on the same host instead.
		return host + ":50050"
	}
	return fmt.Sprintf("%s:%d", host, portNum-1)
}
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
// this node for file transfer.
func (cmd *WorkerCMD) advertiseHTTPAddr() string {
	if cmd.AdvertiseHTTPAddr != "" {
		return cmd.AdvertiseHTTPAddr
	}
	addr := cmd.resolveHTTPAddr()
	// When bound to a wildcard host, advertise the machine hostname instead.
	if host, port, ok := strings.Cut(addr, ":"); ok && (host == "" || host == "0.0.0.0") {
		if name, err := os.Hostname(); err == nil {
			return name + ":" + port
		}
	}
	return addr
}
// registrationBody builds the JSON body for node registration.
func (cmd *WorkerCMD) registrationBody() map[string]any {
	name := cmd.NodeName
	if name == "" {
		// Fall back to the hostname, or a PID-derived name if even that fails.
		if hostname, err := os.Hostname(); err == nil {
			name = hostname
		} else {
			name = fmt.Sprintf("node-%d", os.Getpid())
		}
	}

	// GPU capacity feeds the frontend's VRAM-aware scheduling.
	totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
	gpuVendor, _ := xsysinfo.DetectGPUVendor()

	body := map[string]any{
		"name":           name,
		"address":        cmd.advertiseAddr(),
		"http_address":   cmd.advertiseHTTPAddr(),
		"total_vram":     totalVRAM,
		"available_vram": totalVRAM, // initially all VRAM is available
		"gpu_vendor":     gpuVendor,
	}

	// Without a GPU, report system RAM so the scheduler/UI still has capacity info.
	if totalVRAM == 0 {
		if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
			body["total_ram"] = ramInfo.Total
			body["available_ram"] = ramInfo.Available
		}
	}
	if token := cmd.RegistrationToken; token != "" {
		body["token"] = token
	}
	return body
}
// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads.
func (cmd *WorkerCMD) heartbeatBody() map[string]any {
	aggregate := xsysinfo.GetGPUAggregateInfo()

	availVRAM := aggregate.FreeVRAM
	if aggregate.TotalVRAM == 0 {
		// No per-GPU usage tracking possible: report total as available.
		availVRAM, _ = xsysinfo.TotalAvailableVRAM()
	}

	body := map[string]any{"available_vram": availVRAM}

	// If no GPU, report system RAM usage instead.
	if aggregate.TotalVRAM == 0 {
		if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
			body["available_ram"] = ramInfo.Available
		}
	}
	return body
}

View File

@@ -0,0 +1,272 @@
// Package workerregistry provides a shared HTTP client for worker node
// registration, heartbeating, draining, and deregistration against a
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
package workerregistry
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/mudler/xlog"
)
// RegistrationClient talks to the frontend's /api/node/* endpoints.
// It embeds a sync.Once, so it must be used through a pointer and must
// not be copied after first use.
type RegistrationClient struct {
	FrontendURL       string        // base URL of the LocalAI frontend; trailing slashes are tolerated
	RegistrationToken string        // optional bearer token attached as an Authorization header
	HTTPTimeout       time.Duration // used for registration calls; defaults to 10s

	client     *http.Client // shared client, built lazily by httpClient()
	clientOnce sync.Once    // guards the one-time client construction
}
// httpTimeout returns the configured timeout, falling back to a 10s default
// when HTTPTimeout is unset or non-positive.
func (c *RegistrationClient) httpTimeout() time.Duration {
	if c.HTTPTimeout <= 0 {
		return 10 * time.Second
	}
	return c.HTTPTimeout
}
// httpClient lazily constructs the shared HTTP client exactly once and returns
// it. The timeout is captured at first use and never changes afterwards.
func (c *RegistrationClient) httpClient() *http.Client {
	build := func() {
		c.client = &http.Client{Timeout: c.httpTimeout()}
	}
	c.clientOnce.Do(build)
	return c.client
}
// baseURL returns FrontendURL with every trailing slash stripped, so path
// segments can be appended with a single "/".
func (c *RegistrationClient) baseURL() string {
	u := c.FrontendURL
	for strings.HasSuffix(u, "/") {
		u = strings.TrimSuffix(u, "/")
	}
	return u
}
// setAuth adds a bearer Authorization header when a registration token is
// configured; with no token the request is left untouched.
func (c *RegistrationClient) setAuth(req *http.Request) {
	token := c.RegistrationToken
	if token == "" {
		return
	}
	req.Header.Set("Authorization", "Bearer "+token)
}
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
	ID       string `json:"id"`                  // node ID assigned by the frontend
	APIToken string `json:"api_token,omitempty"` // auto-provisioned API token, if the frontend issued one
}
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		// Previously ignored; an unmarshalable body would have been sent empty.
		return "", "", fmt.Errorf("encoding registration body: %w", err)
	}
	url := c.baseURL() + "/api/node/register"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return "", "", fmt.Errorf("creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return "", "", fmt.Errorf("posting to %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Drain the body so the keep-alive connection can be reused.
		_, _ = io.Copy(io.Discard, resp.Body)
		return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
	}
	var result RegisterResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", "", fmt.Errorf("decoding response: %w", err)
	}
	return result.ID, result.APIToken, nil
}
// RegisterWithRetry retries registration with exponential backoff (2s doubling
// up to 30s) until it succeeds, the context is cancelled, or maxRetries
// attempts have failed.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
	if maxRetries < 1 {
		// Guard: zero or negative would otherwise skip the loop entirely and
		// return a nil error with an empty node ID, which callers would
		// mistake for a successful registration.
		maxRetries = 1
	}
	backoff := 2 * time.Second
	const maxBackoff = 30 * time.Second
	for attempt := 1; attempt <= maxRetries; attempt++ {
		nodeID, apiToken, err := c.Register(ctx, body)
		if err == nil {
			return nodeID, apiToken, nil
		}
		if attempt == maxRetries {
			return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
		}
		xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
		select {
		case <-ctx.Done():
			return "", "", ctx.Err()
		case <-time.After(backoff):
		}
		backoff = min(backoff*2, maxBackoff)
	}
	// Unreachable: the loop always returns on the final attempt.
	return "", "", nil
}
// Heartbeat sends a single heartbeat POST with the given body. A non-2xx
// status is reported as an error so callers (e.g. HeartbeatLoop) can surface
// server-side rejections instead of failing silently.
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		// Previously ignored; an unmarshalable body would have been sent empty.
		return fmt.Errorf("encoding heartbeat body: %w", err)
	}
	url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return fmt.Errorf("creating heartbeat request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body so the keep-alive connection can be reused.
	_, _ = io.Copy(io.Discard, resp.Body)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("heartbeat failed with status %d", resp.StatusCode)
	}
	return nil
}
// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled.
// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats).
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// Failures are logged and the loop keeps going; a single missed
			// heartbeat should not kill the node's liveness reporting.
			if err := c.Heartbeat(ctx, nodeID, bodyFn()); err != nil {
				xlog.Warn("Heartbeat failed", "error", err)
			}
		case <-ctx.Done():
			return
		}
	}
}
// Drain sets the node to draining status via POST /api/node/:id/drain.
// Returns an error when the request cannot be built/sent or the frontend
// replies with a non-200 status.
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
	url := c.baseURL() + "/api/node/" + nodeID + "/drain"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	if err != nil {
		return fmt.Errorf("creating drain request: %w", err)
	}
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body so the keep-alive connection can be reused.
	_, _ = io.Copy(io.Discard, resp.Body)
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("drain failed with status %d", resp.StatusCode)
	}
	return nil
}
// WaitForDrain polls GET /api/node/:id/models until all models report 0
// in-flight requests, or until timeout elapses. It is best-effort: transient
// poll/decode failures are logged and retried once per second; cancellation
// of ctx stops the wait immediately.
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
	url := c.baseURL() + "/api/node/" + nodeID + "/models"
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			xlog.Warn("Failed to create drain poll request", "error", err)
			return
		}
		c.setAuth(req)
		resp, err := c.httpClient().Do(req)
		if err != nil {
			xlog.Warn("Drain poll failed, will retry", "error", err)
			select {
			case <-ctx.Done():
				xlog.Warn("Drain wait cancelled")
				return
			case <-time.After(1 * time.Second):
			}
			continue
		}
		var models []struct {
			InFlight int `json:"in_flight"`
		}
		// Previously the Decode error was silently discarded, so a malformed
		// response looked like "0 in flight" and ended the wait early.
		decodeErr := json.NewDecoder(resp.Body).Decode(&models)
		_, _ = io.Copy(io.Discard, resp.Body) // drain remainder for connection reuse
		resp.Body.Close()
		if decodeErr != nil {
			xlog.Warn("Drain poll returned unparseable body, will retry", "error", decodeErr)
		} else {
			total := 0
			for _, m := range models {
				total += m.InFlight
			}
			if total == 0 {
				xlog.Info("All in-flight requests drained")
				return
			}
			xlog.Info("Waiting for in-flight requests", "count", total)
		}
		select {
		case <-ctx.Done():
			xlog.Warn("Drain wait cancelled")
			return
		case <-time.After(1 * time.Second):
		}
	}
	xlog.Warn("Drain timeout reached, proceeding with shutdown")
}
// Deregister marks the node as offline via POST /api/node/:id/deregister.
// The node row is preserved in the database so re-registration restores
// approval status. Returns an error on request failure or non-200 status.
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
	url := c.baseURL() + "/api/node/" + nodeID + "/deregister"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	if err != nil {
		return fmt.Errorf("creating deregister request: %w", err)
	}
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body so the keep-alive connection can be reused (consistent
	// with Drain above).
	_, _ = io.Copy(io.Discard, resp.Body)
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("deregistration failed with status %d", resp.StatusCode)
	}
	return nil
}
// GracefulDeregister performs drain -> wait -> deregister in sequence.
// This is the standard shutdown sequence for backend workers. It is a no-op
// when either the frontend URL or the node ID is missing, and the whole
// sequence is bounded by a 60s deadline.
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
	if nodeID == "" || c.FrontendURL == "" {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// Mark draining first; only wait on in-flight work if that succeeded.
	if err := c.Drain(ctx, nodeID); err != nil {
		xlog.Warn("Failed to set drain status", "error", err)
	} else {
		c.WaitForDrain(ctx, nodeID, 30*time.Second)
	}

	// Deregister regardless of the drain outcome so the node goes offline.
	if err := c.Deregister(ctx, nodeID); err != nil {
		xlog.Error("Failed to deregister", "error", err)
		return
	}
	xlog.Info("Deregistered from frontend")
}

View File

@@ -94,7 +94,7 @@ func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
}
// Helper function to perform a request without expecting a response body
func (c *StoreClient) doRequest(path string, data interface{}) error {
func (c *StoreClient) doRequest(path string, data any) error {
jsonData, err := json.Marshal(data)
if err != nil {
return err
@@ -120,7 +120,7 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
}
// Helper function to perform a request and parse the response body
func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) {
func (c *StoreClient) doRequestWithResponse(path string, data any) ([]byte, error) {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err

View File

@@ -83,8 +83,8 @@ type ApplicationConfig struct {
APIAddress string
LlamaCPPTunnelCallback func(tunnels []string)
MLXTunnelCallback func(tunnels []string)
LlamaCPPTunnelCallback func(tunnels []string)
MLXTunnelCallback func(tunnels []string)
DisableRuntimeSettings bool
@@ -99,47 +99,50 @@ type ApplicationConfig struct {
// Authentication & Authorization
Auth AuthConfig
// Distributed / Horizontal Scaling
Distributed DistributedConfig
}
// AuthConfig holds configuration for user authentication and authorization.
type AuthConfig struct {
Enabled bool
DatabaseURL string // "postgres://..." or file path for SQLite
GitHubClientID string
GitHubClientSecret string
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
OIDCClientID string
OIDCClientSecret string
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
AdminEmail string // auto-promote to admin on login
RegistrationMode string // "open", "approval" (default when empty), "invite"
DisableLocalAuth bool // disable local email/password registration and login
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
Enabled bool
DatabaseURL string // "postgres://..." or file path for SQLite
GitHubClientID string
GitHubClientSecret string
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
OIDCClientID string
OIDCClientSecret string
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
AdminEmail string // auto-promote to admin on login
RegistrationMode string // "open", "approval" (default when empty), "invite"
DisableLocalAuth bool // disable local email/password registration and login
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
}
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
type AgentPoolConfig struct {
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
APIKey string // default: first API key from LocalAI config
DefaultModel string
MultimodalModel string
TranscriptionModel string
TranscriptionLanguage string
TTSModel string
Timeout string // default: "5m"
EnableSkills bool
EnableLogs bool
CustomActionsDir string
CollectionDBPath string
VectorEngine string // default: "chromem"
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
MaxChunkingSize int // default: 400
ChunkOverlap int // default: 0
DatabaseURL string
AgentHubURL string // default: "https://agenthub.localai.io"
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
APIKey string // default: first API key from LocalAI config
DefaultModel string
MultimodalModel string
TranscriptionModel string
TranscriptionLanguage string
TTSModel string
Timeout string // default: "5m"
EnableSkills bool
EnableLogs bool
CustomActionsDir string
CollectionDBPath string
VectorEngine string // default: "chromem"
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
MaxChunkingSize int // default: 400
ChunkOverlap int // default: 0
DatabaseURL string
AgentHubURL string // default: "https://agenthub.localai.io"
}
type AppOption func(*ApplicationConfig)
@@ -155,12 +158,12 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
TracingMaxItems: 1024,
AgentPool: AgentPoolConfig{
Enabled: true,
Timeout: "5m",
VectorEngine: "chromem",
EmbeddingModel: "granite-embedding-107m-multilingual",
Enabled: true,
Timeout: "5m",
VectorEngine: "chromem",
EmbeddingModel: "granite-embedding-107m-multilingual",
MaxChunkingSize: 400,
AgentHubURL: "https://agenthub.localai.io",
AgentHubURL: "https://agenthub.localai.io",
},
PathWithoutAuth: []string{
"/static/",
@@ -904,40 +907,40 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath
return RuntimeSettings{
WatchdogEnabled: &watchdogEnabled,
WatchdogIdleEnabled: &watchdogIdle,
WatchdogBusyEnabled: &watchdogBusy,
WatchdogIdleTimeout: &idleTimeout,
WatchdogBusyTimeout: &busyTimeout,
WatchdogInterval: &watchdogInterval,
SingleBackend: &singleBackend,
MaxActiveBackends: &maxActiveBackends,
ParallelBackendRequests: &parallelBackendRequests,
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
Threads: &threads,
ContextSize: &contextSize,
F16: &f16,
Debug: &debug,
TracingMaxItems: &tracingMaxItems,
EnableTracing: &enableTracing,
EnableBackendLogging: &enableBackendLogging,
CORS: &cors,
CSRF: &csrf,
CORSAllowOrigins: &corsAllowOrigins,
P2PToken: &p2pToken,
P2PNetworkID: &p2pNetworkID,
Federated: &federated,
Galleries: &galleries,
BackendGalleries: &backendGalleries,
AutoloadGalleries: &autoloadGalleries,
AutoloadBackendGalleries: &autoloadBackendGalleries,
ApiKeys: &apiKeys,
AgentJobRetentionDays: &agentJobRetentionDays,
OpenResponsesStoreTTL: &openResponsesStoreTTL,
WatchdogEnabled: &watchdogEnabled,
WatchdogIdleEnabled: &watchdogIdle,
WatchdogBusyEnabled: &watchdogBusy,
WatchdogIdleTimeout: &idleTimeout,
WatchdogBusyTimeout: &busyTimeout,
WatchdogInterval: &watchdogInterval,
SingleBackend: &singleBackend,
MaxActiveBackends: &maxActiveBackends,
ParallelBackendRequests: &parallelBackendRequests,
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
Threads: &threads,
ContextSize: &contextSize,
F16: &f16,
Debug: &debug,
TracingMaxItems: &tracingMaxItems,
EnableTracing: &enableTracing,
EnableBackendLogging: &enableBackendLogging,
CORS: &cors,
CSRF: &csrf,
CORSAllowOrigins: &corsAllowOrigins,
P2PToken: &p2pToken,
P2PNetworkID: &p2pNetworkID,
Federated: &federated,
Galleries: &galleries,
BackendGalleries: &backendGalleries,
AutoloadGalleries: &autoloadGalleries,
AutoloadBackendGalleries: &autoloadBackendGalleries,
ApiKeys: &apiKeys,
AgentJobRetentionDays: &agentJobRetentionDays,
OpenResponsesStoreTTL: &openResponsesStoreTTL,
AgentPoolEnabled: &agentPoolEnabled,
AgentPoolDefaultModel: &agentPoolDefaultModel,
AgentPoolEmbeddingModel: &agentPoolEmbeddingModel,

View File

@@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true,
Debug: true,
CORS: true,
DisableCSRF: true,
DisableCSRF: true,
CORSAllowOrigins: "https://example.com",
P2PToken: "test-token",
P2PNetworkID: "test-network",
@@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true,
Debug: false,
CORS: true,
DisableCSRF: false,
DisableCSRF: false,
CORSAllowOrigins: "https://test.com",
P2PToken: "round-trip-token",
P2PNetworkID: "round-trip-network",

View File

@@ -0,0 +1,188 @@
package config
import (
"cmp"
"fmt"
"time"
"github.com/mudler/xlog"
)
// DistributedConfig holds configuration for horizontal scaling mode.
// When Enabled is true, PostgreSQL and NATS are required.
type DistributedConfig struct {
	Enabled           bool   // --distributed / LOCALAI_DISTRIBUTED
	InstanceID        string // --instance-id / LOCALAI_INSTANCE_ID (auto-generated UUID if empty)
	NatsURL           string // --nats-url / LOCALAI_NATS_URL
	StorageURL        string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
	RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
	AutoApproveNodes  bool   // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)

	// S3 configuration (used when StorageURL is set)
	StorageBucket    string // --storage-bucket / LOCALAI_STORAGE_BUCKET
	StorageRegion    string // --storage-region / LOCALAI_STORAGE_REGION
	StorageAccessKey string // --storage-access-key / LOCALAI_STORAGE_ACCESS_KEY
	StorageSecretKey string // --storage-secret-key / LOCALAI_STORAGE_SECRET_KEY

	// Timeout configuration (all have sensible defaults — zero means use default)
	MCPToolTimeout      time.Duration // MCP tool execution timeout (default 360s)
	MCPDiscoveryTimeout time.Duration // MCP discovery timeout (default 60s)
	WorkerWaitTimeout   time.Duration // Max wait for healthy worker at startup (default 5m)
	DrainTimeout        time.Duration // Time to wait for in-flight requests during drain (default 30s)
	HealthCheckInterval time.Duration // Health monitor check interval (default 15s)
	StaleNodeThreshold  time.Duration // Time before a node is considered stale (default 60s)
	PerModelHealthCheck bool          // Enable per-model backend health checking (default false)
	MCPCIJobTimeout     time.Duration // MCP CI job execution timeout (default 10m)
	MaxUploadSize       int64         // Maximum upload body size in bytes (default 50 GB)

	// NOTE(review): only the two fields below carry yaml/json/env struct tags,
	// unlike every other field in this struct — confirm whether they are
	// populated through a separate file/env decoding path or whether the tags
	// are vestigial.
	AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
	JobWorkerConcurrency   int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
}
// Validate checks that the distributed configuration is internally consistent.
// It returns nil if distributed mode is disabled.
func (c DistributedConfig) Validate() error {
	if !c.Enabled {
		return nil
	}
	if c.NatsURL == "" {
		return fmt.Errorf("distributed mode requires --nats-url / LOCALAI_NATS_URL")
	}

	// S3 credentials must be paired: exactly one of the two being set is an error.
	if (c.StorageAccessKey != "") != (c.StorageSecretKey != "") {
		return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
	}

	// Warn about missing registration token (not an error)
	if c.RegistrationToken == "" {
		xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
	}

	// Reject negative durations; zero is allowed and means "use the default".
	durations := map[string]time.Duration{
		"mcp-tool-timeout":      c.MCPToolTimeout,
		"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
		"worker-wait-timeout":   c.WorkerWaitTimeout,
		"drain-timeout":         c.DrainTimeout,
		"health-check-interval": c.HealthCheckInterval,
		"stale-node-threshold":  c.StaleNodeThreshold,
		"mcp-ci-job-timeout":    c.MCPCIJobTimeout,
	}
	for name, d := range durations {
		if d < 0 {
			return fmt.Errorf("%s must not be negative", name)
		}
	}
	return nil
}
// Distributed config options

// EnableDistributed turns on distributed (horizontal scaling) mode.
var EnableDistributed = func(o *ApplicationConfig) {
	o.Distributed.Enabled = true
}

// WithDistributedInstanceID sets this instance's identifier.
func WithDistributedInstanceID(id string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.InstanceID = id
	}
}

// WithNatsURL sets the NATS server URL used for inter-node messaging.
func WithNatsURL(url string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.NatsURL = url
	}
}

// WithRegistrationToken sets the token required for node registration.
func WithRegistrationToken(token string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.RegistrationToken = token
	}
}

// WithStorageURL sets the S3-compatible storage endpoint.
func WithStorageURL(url string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageURL = url
	}
}

// WithStorageBucket sets the S3 bucket name.
func WithStorageBucket(bucket string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageBucket = bucket
	}
}

// WithStorageRegion sets the S3 region.
func WithStorageRegion(region string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageRegion = region
	}
}

// WithStorageAccessKey sets the S3 access key (must be paired with the secret key).
func WithStorageAccessKey(key string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageAccessKey = key
	}
}

// WithStorageSecretKey sets the S3 secret key (must be paired with the access key).
func WithStorageSecretKey(key string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageSecretKey = key
	}
}

// EnableAutoApproveNodes skips admin approval for newly registered workers.
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
	o.Distributed.AutoApproveNodes = true
}
// Defaults for distributed timeouts. Each corresponds to a zero-valued field
// in DistributedConfig and is applied by the matching *OrDefault accessor.
const (
	DefaultMCPToolTimeout      = 360 * time.Second
	DefaultMCPDiscoveryTimeout = 60 * time.Second
	DefaultWorkerWaitTimeout   = 5 * time.Minute
	DefaultDrainTimeout        = 30 * time.Second
	DefaultHealthCheckInterval = 15 * time.Second
	DefaultStaleNodeThreshold  = 60 * time.Second
	DefaultMCPCIJobTimeout     = 10 * time.Minute
)

// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
	if c.MCPToolTimeout != 0 {
		return c.MCPToolTimeout
	}
	return DefaultMCPToolTimeout
}

// MCPDiscoveryTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPDiscoveryTimeoutOrDefault() time.Duration {
	if c.MCPDiscoveryTimeout != 0 {
		return c.MCPDiscoveryTimeout
	}
	return DefaultMCPDiscoveryTimeout
}

// WorkerWaitTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) WorkerWaitTimeoutOrDefault() time.Duration {
	if c.WorkerWaitTimeout != 0 {
		return c.WorkerWaitTimeout
	}
	return DefaultWorkerWaitTimeout
}

// DrainTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) DrainTimeoutOrDefault() time.Duration {
	if c.DrainTimeout != 0 {
		return c.DrainTimeout
	}
	return DefaultDrainTimeout
}

// HealthCheckIntervalOrDefault returns the configured interval or the default.
func (c DistributedConfig) HealthCheckIntervalOrDefault() time.Duration {
	if c.HealthCheckInterval != 0 {
		return c.HealthCheckInterval
	}
	return DefaultHealthCheckInterval
}

// StaleNodeThresholdOrDefault returns the configured threshold or the default.
func (c DistributedConfig) StaleNodeThresholdOrDefault() time.Duration {
	if c.StaleNodeThreshold != 0 {
		return c.StaleNodeThreshold
	}
	return DefaultStaleNodeThreshold
}

// MCPCIJobTimeoutOrDefault returns the configured MCP CI job timeout or the default.
func (c DistributedConfig) MCPCIJobTimeoutOrDefault() time.Duration {
	if c.MCPCIJobTimeout != 0 {
		return c.MCPCIJobTimeout
	}
	return DefaultMCPCIJobTimeout
}

// MaxUploadSizeOrDefault returns the configured max upload size or the default.
func (c DistributedConfig) MaxUploadSizeOrDefault() int64 {
	if c.MaxUploadSize != 0 {
		return c.MaxUploadSize
	}
	return DefaultMaxUploadSize
}

View File

@@ -46,11 +46,11 @@ type ModelConfig struct {
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]any `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
@@ -105,6 +105,11 @@ type AgentConfig struct {
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
}
// HasMCPServers returns true if any MCP servers (remote or stdio) are configured.
func (c MCPConfig) HasMCPServers() bool {
return c.Servers != "" || c.Stdio != ""
}
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers]
@@ -619,15 +624,32 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
// Backends that are clearly not text-generation
nonTextGenBackends := []string{
"whisper", "piper", "kokoro",
"diffusers", "stablediffusion", "stablediffusion-ggml",
"rerankers", "silero-vad", "rfdetr",
"transformers-musicgen", "ace-step", "acestep-cpp",
}
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
return false
}
if slices.Contains(nonTextGenBackends, c.Backend) {
return false
}
if c.Embeddings != nil && *c.Embeddings {
return false
}
}
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
if c.TemplateConfig.Completion == "" {
return false
}
if slices.Contains(nonTextGenBackends, c.Backend) {
return false
}
}
if (u & FLAG_EDIT) == FLAG_EDIT {
if c.TemplateConfig.Edit == "" {

View File

@@ -1,35 +1,35 @@
package config
import "regexp"
type ModelConfigFilterFn func(string, *ModelConfig) bool
func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
if filter == "" {
return NoFilterFn, nil
}
rxp, err := regexp.Compile(filter)
if err != nil {
return nil, err
}
return func(name string, config *ModelConfig) bool {
if config != nil {
return rxp.MatchString(config.Name)
}
return rxp.MatchString(name)
}, nil
}
func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
if usecases == FLAG_ANY {
return NoFilterFn
}
return func(name string, config *ModelConfig) bool {
if config == nil {
return false // TODO: Potentially make this a param, for now, no known usecase to include
}
return config.HasUsecases(usecases)
}
}
package config
import "regexp"
// ModelConfigFilterFn decides whether a model — identified by its name and,
// when available, its parsed *ModelConfig — should be included in a listing.
type ModelConfigFilterFn func(string, *ModelConfig) bool

// NoFilterFn accepts every model unconditionally.
func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
// BuildNameFilterFn compiles filter as a regular expression and returns a
// ModelConfigFilterFn that matches model names against it. An empty filter
// means no filtering (NoFilterFn); an invalid pattern returns the regexp
// compile error.
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
	if filter == "" {
		return NoFilterFn, nil
	}
	rxp, err := regexp.Compile(filter)
	if err != nil {
		return nil, err
	}
	return func(name string, config *ModelConfig) bool {
		// When a config is present, match against its canonical Name field
		// rather than the name the caller passed in.
		if config != nil {
			return rxp.MatchString(config.Name)
		}
		return rxp.MatchString(name)
	}, nil
}
// BuildUsecaseFilterFn returns a ModelConfigFilterFn that keeps only models
// whose config reports support for the given usecases. FLAG_ANY disables
// filtering entirely; models with no parsed config are excluded.
func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
	if usecases == FLAG_ANY {
		return NoFilterFn
	}
	return func(name string, config *ModelConfig) bool {
		if config == nil {
			return false // TODO: Potentially make this a param, for now, no known usecase to include
		}
		return config.HasUsecases(usecases)
	}
}

View File

@@ -1,12 +1,13 @@
package config
import (
"cmp"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"slices"
"strings"
"sync"
@@ -215,8 +216,8 @@ func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
res = append(res, v)
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Name < res[j].Name
slices.SortStableFunc(res, func(a, b ModelConfig) int {
return cmp.Compare(a.Name, b.Name)
})
return res

View File

@@ -27,15 +27,15 @@ type RuntimeSettings struct {
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
// Eviction settings
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
// Performance settings
Threads *int `json:"threads,omitempty"`
ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"`
Threads *int `json:"threads,omitempty"`
ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"`
EnableTracing *bool `json:"enable_tracing,omitempty"`
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
@@ -66,11 +66,11 @@ type RuntimeSettings struct {
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
// Agent Pool settings
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"`
}

View File

@@ -3,9 +3,10 @@ package explorer
// A simple JSON database for storing and retrieving p2p network tokens and a name and description.
import (
"cmp"
"encoding/json"
"os"
"sort"
"slices"
"sync"
"github.com/gofrs/flock"
@@ -89,9 +90,8 @@ func (db *Database) TokenList() []string {
tokens = append(tokens, k)
}
sort.Slice(tokens, func(i, j int) bool {
// sort by token
return tokens[i] < tokens[j]
slices.SortFunc(tokens, func(a, b string) int {
return cmp.Compare(a, b)
})
return tokens

View File

@@ -15,7 +15,7 @@ import (
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
type modelConfigCacheEntry struct {
configMap map[string]interface{}
configMap map[string]any
lastUpdated time.Time
}
@@ -57,7 +57,7 @@ func resolveBackend(m *GalleryModel, basePath string) string {
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
// inside it, and returns the result as a map. Results are cached for 1 hour.
// Local file:// URLs skip the cache so edits are picked up immediately.
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
func fetchModelConfigMap(modelURL, basePath string) map[string]any {
// Check cache (skip for file:// URLs so local edits are picked up immediately)
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
if !isLocal && modelConfigCache.Exists(modelURL) {
@@ -75,15 +75,15 @@ func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
// Cache the failure for remote URLs to avoid repeated fetch attempts
if !isLocal {
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
configMap: map[string]interface{}{},
configMap: map[string]any{},
lastUpdated: time.Now(),
})
}
return map[string]interface{}{}
return map[string]any{}
}
// Parse the config_file YAML string into a map
configMap := make(map[string]interface{})
configMap := make(map[string]any)
if modelConfig.ConfigFile != "" {
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
@@ -108,13 +108,11 @@ func prefetchModelConfigs(urls []string, basePath string) {
sem := make(chan struct{}, maxConcurrency)
var wg sync.WaitGroup
for _, url := range urls {
wg.Add(1)
go func(u string) {
defer wg.Done()
wg.Go(func() {
sem <- struct{}{}
defer func() { <-sem }()
fetchModelConfigMap(u, basePath)
}(url)
fetchModelConfigMap(url, basePath)
})
}
wg.Wait()
}

View File

@@ -4,10 +4,10 @@ package gallery
import (
"context"
"os"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -20,6 +20,9 @@ import (
cp "github.com/otiai10/copy"
)
// ErrBackendNotFound is returned when a backend is not found in the system.
var ErrBackendNotFound = errors.New("backend not found")
const (
metadataFile = "metadata.json"
runFile = "run.sh"
@@ -198,9 +201,16 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
} else {
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
// Don't remove backendPath here — fallback OCI extractions need the directory to exist
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
// resetBackendPath cleans up partial state from a failed OCI extraction
// so the next download attempt starts fresh. The directory is re-created
// because OCI image extractors need it to exist for writing files into.
resetBackendPath := func() {
os.RemoveAll(backendPath)
os.MkdirAll(backendPath, 0750)
}
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
@@ -210,6 +220,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
return ctx.Err()
default:
}
resetBackendPath()
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
@@ -221,28 +232,22 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1)
if fallbackURI != string(config.URI) {
xlog.Debug("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
resetBackendPath()
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
success = true
} else {
// Try another fallback: add "-" + devSuffix suffix to the backend name
// For example: master-gpu-nvidia-cuda-13-ace-step -> master-gpu-nvidia-cuda-13-ace-step-development
xlog.Info("Fallback URI failed", "fallback", fallbackURI, "error", err)
if !strings.Contains(fallbackURI, "-"+devSuffix) {
// Extract backend name from URI and add -development
parts := strings.Split(fallbackURI, "-")
if len(parts) >= 2 {
// Find where the backend name ends (usually the last part before the tag)
// Pattern: quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step
lastDash := strings.LastIndex(fallbackURI, "-")
if lastDash > 0 {
devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix
xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI)
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
success = true
}
}
resetBackendPath()
devFallbackURI := fallbackURI + "-" + devSuffix
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
success = true
} else {
xlog.Info("Development fallback URI failed", "fallback", devFallbackURI, "error", err)
}
}
}
@@ -295,7 +300,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
backend, ok := backends.Get(name)
if !ok {
return fmt.Errorf("backend %q not found", name)
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
}
if backend.IsSystem {

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"slices"
"strings"
"time"
@@ -106,64 +106,64 @@ func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
}
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool {
if sortOrder == "asc" {
return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
} else {
return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName())
slices.SortFunc(gm, func(a, b T) int {
r := strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
if sortOrder == "desc" {
return -r
}
return r
})
return gm
}
func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool {
if sortOrder == "asc" {
return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name)
} else {
return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name)
slices.SortFunc(gm, func(a, b T) int {
r := strings.Compare(strings.ToLower(a.GetGallery().Name), strings.ToLower(b.GetGallery().Name))
if sortOrder == "desc" {
return -r
}
return r
})
return gm
}
func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool {
licenseI := gm[i].GetLicense()
licenseJ := gm[j].GetLicense()
var result bool
if licenseI == "" && licenseJ != "" {
return sortOrder == "desc"
} else if licenseI != "" && licenseJ == "" {
return sortOrder == "asc"
} else if licenseI == "" && licenseJ == "" {
return false
slices.SortFunc(gm, func(a, b T) int {
licenseA := a.GetLicense()
licenseB := b.GetLicense()
var r int
if licenseA == "" && licenseB != "" {
r = 1
} else if licenseA != "" && licenseB == "" {
r = -1
} else {
result = strings.ToLower(licenseI) < strings.ToLower(licenseJ)
r = strings.Compare(strings.ToLower(licenseA), strings.ToLower(licenseB))
}
if sortOrder == "desc" {
return !result
} else {
return result
return -r
}
return r
})
return gm
}
func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool {
var result bool
slices.SortFunc(gm, func(a, b T) int {
var r int
// Sort by installed status: installed items first (true > false)
if gm[i].GetInstalled() != gm[j].GetInstalled() {
result = gm[i].GetInstalled()
if a.GetInstalled() != b.GetInstalled() {
if a.GetInstalled() {
r = -1
} else {
r = 1
}
} else {
result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
r = strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
}
if sortOrder == "desc" {
return !result
} else {
return result
return -r
}
return r
})
return gm
}

View File

@@ -27,7 +27,7 @@ var _ = Describe("Gallery", func() {
Describe("ReadConfigFile", func() {
It("should read and unmarshal a valid YAML file", func() {
testConfig := map[string]interface{}{
testConfig := map[string]any{
"name": "test-model",
"description": "A test model",
"license": "MIT",
@@ -39,8 +39,8 @@ var _ = Describe("Gallery", func() {
err = os.WriteFile(filePath, yamlData, 0644)
Expect(err).NotTo(HaveOccurred())
var result map[string]interface{}
config, err := ReadConfigFile[map[string]interface{}](filePath)
var result map[string]any
config, err := ReadConfigFile[map[string]any](filePath)
Expect(err).NotTo(HaveOccurred())
Expect(config).NotTo(BeNil())
result = *config
@@ -50,7 +50,7 @@ var _ = Describe("Gallery", func() {
})
It("should return error when file does not exist", func() {
_, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml")
_, err := ReadConfigFile[map[string]any]("nonexistent.yaml")
Expect(err).To(HaveOccurred())
})
@@ -59,7 +59,7 @@ var _ = Describe("Gallery", func() {
err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
Expect(err).NotTo(HaveOccurred())
_, err = ReadConfigFile[map[string]interface{}](filePath)
_, err = ReadConfigFile[map[string]any](filePath)
Expect(err).To(HaveOccurred())
})
})
@@ -552,32 +552,32 @@ var _ = Describe("Gallery", func() {
// Verify first model
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
Expect(models[0].Overrides).NotTo(BeNil())
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
params := models[0].Overrides["parameters"].(map[string]interface{})
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
params := models[0].Overrides["parameters"].(map[string]any)
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
// Verify second model (merged)
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
Expect(models[1].Overrides).NotTo(BeNil())
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
params = models[1].Overrides["parameters"].(map[string]interface{})
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
params = models[1].Overrides["parameters"].(map[string]any)
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
// Simulate the mergo.Merge call that was failing in models.go:251
// This should not panic with yaml.v3
configMap := make(map[string]interface{})
configMap := make(map[string]any)
configMap["name"] = "test"
configMap["backend"] = "llama-cpp"
configMap["parameters"] = map[string]interface{}{
configMap["parameters"] = map[string]any{
"model": "original.gguf",
}
err = mergo.Merge(&configMap, models[1].Overrides, mergo.WithOverride)
Expect(err).NotTo(HaveOccurred())
Expect(configMap["parameters"]).NotTo(BeNil())
// Verify the merge worked correctly
mergedParams := configMap["parameters"].(map[string]interface{})
mergedParams := configMap["parameters"].(map[string]any)
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
})
})

View File

@@ -59,7 +59,7 @@ var _ = Describe("ImportLocalPath", func() {
adapterConfig := map[string]any{
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
"peft_type": "LORA",
"peft_type": "LORA",
}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

View File

@@ -158,7 +158,7 @@ func InstallModelFromGallery(
return applyModel(model)
}
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]any, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
@@ -239,7 +239,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
configFilePath := filepath.Join(basePath, name+".yaml")
// Read and update config file as map[string]interface{}
configMap := make(map[string]interface{})
configMap := make(map[string]any)
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)

View File

@@ -35,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -43,7 +43,7 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred())
}
content := map[string]interface{}{}
content := map[string]any{}
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
Expect(err).ToNot(HaveOccurred())
@@ -95,7 +95,7 @@ var _ = Describe("Model test", func() {
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
content := map[string]any{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@@ -130,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -150,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -158,7 +158,7 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred())
}
content := map[string]interface{}{}
content := map[string]any{}
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
@@ -180,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})

View File

@@ -12,9 +12,9 @@ import (
type GalleryModel struct {
Metadata `json:",inline" yaml:",inline"`
// config_file is read in the situation where URL is blank - and therefore this is a base config.
ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"`
ConfigFile map[string]any `json:"config_file,omitempty" yaml:"config_file,omitempty"`
// Overrides are used to override the configuration of the model located at URL
Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"`
Overrides map[string]any `json:"overrides,omitempty" yaml:"overrides,omitempty"`
}
func (m *GalleryModel) GetInstalled() bool {

66
core/gallery/worker.go Normal file
View File

@@ -0,0 +1,66 @@
package gallery
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/xlog"
)
// DeleteStagedModelFiles cleans up every file staged for modelName under a
// worker's modelsPath.
//
// Routers stage model files over S3/HTTP into a per-model subdirectory named
// after the model's tracking key, so workers lack the YAML configs that
// DeleteModelFromSystem requires. The primary path removes that subdirectory
// wholesale. When no subdirectory exists (single-file models or legacy
// layouts), a fallback removes the exact file plus any "<name>.*" siblings
// (e.g. model.gguf.mmproj next to model.gguf).
//
// An empty or directory-escaping model name is rejected. Individual file
// removal failures are logged and skipped rather than aborting cleanup.
func DeleteStagedModelFiles(modelsPath, modelName string) error {
	if modelName == "" {
		return fmt.Errorf("empty model name")
	}

	root := filepath.Clean(modelsPath)
	target := filepath.Clean(filepath.Join(modelsPath, modelName))

	// Containment check: the resolved target must live strictly below root.
	if !strings.HasPrefix(target, root+string(filepath.Separator)) {
		return fmt.Errorf("model name %q escapes models directory", modelName)
	}

	// Primary layout: a per-model subdirectory holding all staged files.
	if st, err := os.Stat(target); err == nil && st.IsDir() {
		return os.RemoveAll(target)
	}

	// Fallback layout: exact file match, then glob-based sibling cleanup.
	deletedSomething := false
	if _, err := os.Stat(target); err == nil {
		if err := os.Remove(target); err != nil {
			xlog.Warn("Failed to remove model file", "path", target, "error", err)
		} else {
			deletedSomething = true
		}
	}

	siblings, _ := filepath.Glob(target + ".*")
	for _, sibling := range siblings {
		cleaned := filepath.Clean(sibling)
		// Re-validate each glob hit so nothing outside root is ever touched.
		if !strings.HasPrefix(cleaned, root+string(filepath.Separator)) {
			continue
		}
		if err := os.Remove(cleaned); err != nil {
			xlog.Warn("Failed to remove model-related file", "path", cleaned, "error", err)
		} else {
			deletedSomething = true
		}
	}

	if !deletedSomething {
		xlog.Debug("No files found to delete for model", "model", modelName, "path", target)
	}
	return nil
}

View File

@@ -0,0 +1,99 @@
package gallery_test
import (
"os"
"path/filepath"
"testing"
"github.com/mudler/LocalAI/core/gallery"
)
// TestDeleteStagedModelFiles covers input validation (empty names, path
// traversal), the primary per-model-subdirectory cleanup path, and the
// fallback single-file + glob-sibling cleanup path.
func TestDeleteStagedModelFiles(t *testing.T) {
	// mustWrite creates a fixture file and aborts the subtest on failure, so
	// later assertions never run against a half-built fixture. (Previously
	// os.WriteFile errors were silently ignored, which could mask setup
	// problems as confusing downstream failures.)
	mustWrite := func(t *testing.T, path string, data []byte) {
		t.Helper()
		if err := os.WriteFile(path, data, 0o644); err != nil {
			t.Fatal(err)
		}
	}

	t.Run("rejects empty model name", func(t *testing.T) {
		dir := t.TempDir()
		err := gallery.DeleteStagedModelFiles(dir, "")
		if err == nil {
			t.Fatal("expected error for empty model name")
		}
	})
	t.Run("rejects path traversal via ..", func(t *testing.T) {
		dir := t.TempDir()
		err := gallery.DeleteStagedModelFiles(dir, "../../etc/passwd")
		if err == nil {
			t.Fatal("expected error for path traversal attempt")
		}
	})
	t.Run("rejects path traversal via ../foo", func(t *testing.T) {
		dir := t.TempDir()
		err := gallery.DeleteStagedModelFiles(dir, "../foo")
		if err == nil {
			t.Fatal("expected error for path traversal attempt")
		}
	})
	t.Run("removes model subdirectory with all files", func(t *testing.T) {
		dir := t.TempDir()
		modelDir := filepath.Join(dir, "my-model", "sd-cpp", "models")
		if err := os.MkdirAll(modelDir, 0o755); err != nil {
			t.Fatal(err)
		}
		// Create model files in subdirectory
		mustWrite(t, filepath.Join(modelDir, "flux.gguf"), []byte("model"))
		mustWrite(t, filepath.Join(modelDir, "flux.gguf.mmproj"), []byte("mmproj"))
		err := gallery.DeleteStagedModelFiles(dir, "my-model")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Entire my-model directory should be gone
		if _, err := os.Stat(filepath.Join(dir, "my-model")); !os.IsNotExist(err) {
			t.Fatal("expected model directory to be removed")
		}
	})
	t.Run("removes single file model", func(t *testing.T) {
		dir := t.TempDir()
		modelFile := filepath.Join(dir, "model.gguf")
		mustWrite(t, modelFile, []byte("model"))
		err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
			t.Fatal("expected model file to be removed")
		}
	})
	t.Run("removes sibling files via glob", func(t *testing.T) {
		dir := t.TempDir()
		modelFile := filepath.Join(dir, "model.gguf")
		siblingFile := filepath.Join(dir, "model.gguf.mmproj")
		mustWrite(t, modelFile, []byte("model"))
		mustWrite(t, siblingFile, []byte("mmproj"))
		err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
			t.Fatal("expected model file to be removed")
		}
		if _, err := os.Stat(siblingFile); !os.IsNotExist(err) {
			t.Fatal("expected sibling file to be removed")
		}
	})
	t.Run("no error when model does not exist", func(t *testing.T) {
		dir := t.TempDir()
		err := gallery.DeleteStagedModelFiles(dir, "nonexistent")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})
}

View File

@@ -16,12 +16,17 @@ import (
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/finetune"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/monitoring"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/quantization"
"github.com/mudler/xlog"
)
@@ -155,7 +160,7 @@ func API(application *application.Application) (*echo.Echo, error) {
// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
metricsService, err := monitoring.NewLocalAIMetricsService()
if err != nil {
return nil, err
}
@@ -295,9 +300,9 @@ func API(application *application.Application) (*echo.Echo, error) {
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
var opcache *galleryop.OpCache
if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
opcache = galleryop.NewOpCache(application.GalleryService())
}
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
@@ -305,22 +310,51 @@ func API(application *application.Application) (*echo.Echo, error) {
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
// Fine-tuning routes
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
ftService := services.NewFineTuneService(
ftService := finetune.NewFineTuneService(
application.ApplicationConfig(),
application.ModelLoader(),
application.ModelConfigLoader(),
)
if d := application.Distributed(); d != nil {
ftService.SetNATSClient(d.Nats)
if d.DistStores != nil && d.DistStores.FineTune != nil {
ftService.SetFineTuneStore(d.DistStores.FineTune)
}
}
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
// Quantization routes
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
qService := services.NewQuantizationService(
qService := quantization.NewQuantizationService(
application.ApplicationConfig(),
application.ModelLoader(),
application.ModelConfigLoader(),
)
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
// Node management routes (distributed mode)
distCfg := application.ApplicationConfig().Distributed
var registry *nodes.NodeRegistry
var remoteUnloader nodes.NodeCommandSender
if d := application.Distributed(); d != nil {
registry = d.Registry
if d.Router != nil {
remoteUnloader = d.Router.Unloader()
}
}
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
// Distributed SSE routes (job progress + agent events via NATS)
if d := application.Distributed(); d != nil {
if d.Dispatcher != nil {
e.GET("/api/agent/jobs/:id/progress", d.Dispatcher.SSEHandler(), mcpJobsMw)
}
if d.AgentBridge != nil {
e.GET("/api/agents/:name/sse/distributed", d.AgentBridge.SSEHandler(), agentsMw)
}
}
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)

View File

@@ -44,14 +44,14 @@ Say hello.
### Response:`
type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"`
ConfigURL string `json:"config_url"`
Name string `json:"name"`
Overrides map[string]interface{} `json:"overrides"`
ID string `json:"id"`
URL string `json:"url"`
ConfigURL string `json:"config_url"`
Name string `json:"name"`
Overrides map[string]any `json:"overrides"`
}
func getModelStatus(url string) (response map[string]interface{}) {
func getModelStatus(url string) (response map[string]any) {
// Create the HTTP request
req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")
@@ -94,7 +94,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
return response, err
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]any) {
//url := "http://localhost:AI/models/apply"
@@ -336,7 +336,7 @@ var _ = Describe("API test", func() {
Name: "bert",
URL: bertEmbeddingsURL,
},
Overrides: map[string]interface{}{"backend": "llama-cpp"},
Overrides: map[string]any{"backend": "llama-cpp"},
},
{
Metadata: gallery.Metadata{
@@ -344,7 +344,7 @@ var _ = Describe("API test", func() {
URL: bertEmbeddingsURL,
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
},
Overrides: map[string]interface{}{"foo": "bar"},
Overrides: map[string]any{"foo": "bar"},
},
}
out, err := yaml.Marshal(g)
@@ -464,7 +464,7 @@ var _ = Describe("API test", func() {
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
resp := map[string]interface{}{}
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
@@ -479,7 +479,7 @@ var _ = Describe("API test", func() {
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
content := map[string]any{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@@ -503,7 +503,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]interface{}{
Overrides: map[string]any{
"backend": "llama",
},
})
@@ -520,7 +520,7 @@ var _ = Describe("API test", func() {
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
content := map[string]any{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("llama"))
@@ -529,7 +529,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]interface{}{},
Overrides: map[string]any{},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@@ -544,7 +544,7 @@ var _ = Describe("API test", func() {
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
content := map[string]any{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@@ -586,7 +586,7 @@ parameters:
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]interface{}{}
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
@@ -601,7 +601,7 @@ parameters:
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
content := map[string]any{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["name"]).To(Equal("test-import-model"))
@@ -657,7 +657,7 @@ parameters:
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]interface{}{}
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
@@ -1248,7 +1248,7 @@ parameters:
Context("Agent Jobs", Label("agent-jobs"), func() {
It("creates and manages tasks", func() {
// Create a task
taskBody := map[string]interface{}{
taskBody := map[string]any{
"name": "Test Task",
"description": "Test Description",
"model": "testmodel.ggml",
@@ -1256,7 +1256,7 @@ parameters:
"enabled": true,
}
var createResp map[string]interface{}
var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
Expect(createResp["id"]).ToNot(BeEmpty())
@@ -1302,20 +1302,20 @@ parameters:
It("executes and monitors jobs", func() {
// Create a task first
taskBody := map[string]interface{}{
taskBody := map[string]any{
"name": "Job Test Task",
"model": "testmodel.ggml",
"prompt": "Say hello",
"enabled": true,
}
var createResp map[string]interface{}
var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
taskID := createResp["id"].(string)
// Execute a job
jobBody := map[string]interface{}{
jobBody := map[string]any{
"task_id": taskID,
"parameters": map[string]string{},
}
@@ -1357,14 +1357,14 @@ parameters:
It("executes task by name", func() {
// Create a task with a specific name
taskBody := map[string]interface{}{
taskBody := map[string]any{
"name": "Named Task",
"model": "testmodel.ggml",
"prompt": "Hello",
"enabled": true,
}
var createResp map[string]interface{}
var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())

View File

@@ -516,6 +516,17 @@ func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
return true
}
// Node self-service endpoints — authenticated via registration token, not global auth.
// Only exempt the specific known endpoints, not the entire prefix.
if strings.HasPrefix(path, "/api/node/") {
if path == "/api/node/register" ||
strings.HasSuffix(path, "/heartbeat") ||
strings.HasSuffix(path, "/drain") ||
strings.HasSuffix(path, "/deregister") {
return true
}
}
// Check configured exempt paths
for _, p := range appConfig.PathWithoutAuth {
if strings.HasPrefix(path, p) {
@@ -540,6 +551,14 @@ func isAPIPath(path string) bool {
strings.HasPrefix(path, "/system") ||
strings.HasPrefix(path, "/ws/") ||
strings.HasPrefix(path, "/generated-") ||
strings.HasPrefix(path, "/chat/") ||
strings.HasPrefix(path, "/completions") ||
strings.HasPrefix(path, "/edits") ||
strings.HasPrefix(path, "/embeddings") ||
strings.HasPrefix(path, "/audio/") ||
strings.HasPrefix(path, "/images/") ||
strings.HasPrefix(path, "/messages") ||
strings.HasPrefix(path, "/responses") ||
path == "/metrics"
}

View File

@@ -9,24 +9,25 @@ import (
// Auth provider constants identifying how a User authenticated.
// The diff residue listed the three pre-existing providers twice; this keeps
// one copy plus the newly added agent-worker provider.
const (
	ProviderLocal       = "local"
	ProviderGitHub      = "github"
	ProviderOIDC        = "oidc"
	ProviderAgentWorker = "agent-worker"
)
// User represents an authenticated user.
type User struct {
ID string `gorm:"primaryKey;size:36"`
Email string `gorm:"size:255;index"`
Name string `gorm:"size:255"`
AvatarURL string `gorm:"size:512"`
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
Subject string `gorm:"size:255"` // provider-specific user ID
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
ID string `gorm:"primaryKey;size:36"`
Email string `gorm:"size:255;index"`
Name string `gorm:"size:255"`
AvatarURL string `gorm:"size:512"`
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
Subject string `gorm:"size:255"` // provider-specific user ID
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
Role string `gorm:"size:20;default:user"`
Status string `gorm:"size:20;default:active"` // "active", "pending"
CreatedAt time.Time
UpdatedAt time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// Session represents a user login session.
@@ -90,16 +91,16 @@ func (p *PermissionMap) Scan(value any) error {
// InviteCode represents an admin-generated invitation for user registration.
// UsedBy/UsedAt are nil until a registrant consumes the code. The diff
// residue duplicated every field; this keeps the post-change declaration.
type InviteCode struct {
	ID         string  `gorm:"primaryKey;size:36"`
	Code       string  `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
	CodePrefix string  `gorm:"size:12"`                      // first 8 chars for admin display
	CreatedBy  string  `gorm:"size:36;not null"`
	UsedBy     *string `gorm:"size:36"`
	UsedAt     *time.Time
	ExpiresAt  time.Time `gorm:"not null;index"`
	CreatedAt  time.Time
	Creator    User  `gorm:"foreignKey:CreatedBy"`
	Consumer   *User `gorm:"foreignKey:UsedBy"`
}
// ModelAllowlist controls which models a user can access.

View File

// Per-user feature flag constants. General features default to OFF for new
// users; API features default to ON. The diff residue duplicated most
// entries (old and new formatting); this keeps a single copy of each.
// NOTE(review): the upstream const block may contain additional entries
// above FeatureMCPJobs that were outside the diff hunk — confirm against
// the full file before merging.
const (
	FeatureMCPJobs = "mcp_jobs"

	// General features (default OFF for new users)
	FeatureFineTuning   = "fine_tuning"
	FeatureQuantization = "quantization"

	// API features (default ON for new users)
	FeatureChat               = "chat"
	FeatureImages             = "images"
	FeatureAudioSpeech        = "audio_speech"
	FeatureAudioTranscription = "audio_transcription"
	FeatureVAD                = "vad"
	FeatureDetection          = "detection"
	FeatureVideo              = "video"
	FeatureEmbeddings         = "embeddings"
	FeatureSound              = "sound"
	FeatureRealtime           = "realtime"
	FeatureRerank             = "rerank"
	FeatureTokenize           = "tokenize"
	FeatureMCP                = "mcp"
	FeatureStores             = "stores"
)
// AgentFeatures lists agent-related features (default OFF).

View File

@@ -24,14 +24,14 @@ type QuotaRule struct {
// QuotaStatus is returned to clients with current usage included.
type QuotaStatus struct {
ID string `json:"id"`
Model string `json:"model"`
MaxRequests *int64 `json:"max_requests"`
MaxTotalTokens *int64 `json:"max_total_tokens"`
Window string `json:"window"`
CurrentRequests int64 `json:"current_requests"`
CurrentTokens int64 `json:"current_total_tokens"`
ResetsAt string `json:"resets_at,omitempty"`
ID string `json:"id"`
Model string `json:"model"`
MaxRequests *int64 `json:"max_requests"`
MaxTotalTokens *int64 `json:"max_total_tokens"`
Window string `json:"window"`
CurrentRequests int64 `json:"current_requests"`
CurrentTokens int64 `json:"current_total_tokens"`
ResetsAt string `json:"resets_at,omitempty"`
}
// ── CRUD ──
@@ -209,9 +209,9 @@ func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) {
var quotaCache = newQuotaCacheStore()
// quotaCacheStore caches quota rules and usage counters in memory so that
// quota checks do not hit the database on every request. All access goes
// through mu. The diff residue duplicated the field list; this keeps one.
type quotaCacheStore struct {
	mu    sync.RWMutex
	rules map[string]cachedRules // userID -> rules
	usage map[string]cachedUsage // "userID|model|windowStart" -> counts
}
type cachedRules struct {

View File

@@ -13,7 +13,7 @@ import (
const (
sessionDuration = 30 * 24 * time.Hour // 30 days
sessionIDBytes = 32 // 32 bytes = 64 hex chars
sessionIDBytes = 32 // 32 bytes = 64 hex chars
sessionCookie = "session"
sessionRotationInterval = 1 * time.Hour
)

View File

@@ -10,15 +10,15 @@ import (
// UsageRecord represents a single API request's token usage.
// The composite index idx_usage_user_time (UserID + CreatedAt) supports
// per-user time-range aggregation queries. The diff residue duplicated the
// first fields; this keeps a single copy.
type UsageRecord struct {
	ID               uint   `gorm:"primaryKey;autoIncrement"`
	UserID           string `gorm:"size:36;index:idx_usage_user_time"`
	UserName         string `gorm:"size:255"`
	Model            string `gorm:"size:255;index"`
	Endpoint         string `gorm:"size:255"`
	PromptTokens     int64
	CompletionTokens int64
	TotalTokens      int64
	Duration         int64     // milliseconds
	CreatedAt        time.Time `gorm:"index:idx_usage_user_time"`
}
@@ -127,10 +127,10 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
query := db.Model(&UsageRecord{}).
Select(bucketExpr+", model, user_id, user_name, "+
"SUM(prompt_tokens) as prompt_tokens, "+
"SUM(completion_tokens) as completion_tokens, "+
"SUM(total_tokens) as total_tokens, "+
Select(bucketExpr + ", model, user_id, user_name, " +
"SUM(prompt_tokens) as prompt_tokens, " +
"SUM(completion_tokens) as completion_tokens, " +
"SUM(total_tokens) as total_tokens, " +
"COUNT(*) as request_count").
Group("bucket, model, user_id, user_name").
Order("bucket ASC")

View File

@@ -36,7 +36,7 @@ var _ = Describe("Usage", func() {
db := testDB()
// Insert records for two users
for i := 0; i < 3; i++ {
for range 3 {
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-a",
UserName: "Alice",

View File

@@ -3,7 +3,6 @@ package anthropic
import (
"encoding/json"
"fmt"
"strings"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
@@ -25,7 +24,7 @@ import (
// @Param request body schema.AnthropicRequest true "query params"
// @Success 200 {object} schema.AnthropicResponse "Response"
// @Router /v1/messages [post]
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
return func(c echo.Context) error {
id := uuid.New().String()
@@ -52,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
funcs, shouldUseFn := convertAnthropicTools(input, cfg)
// MCP injection: prompts, resources, and tools
var mcpToolInfos []mcpTools.MCPToolInfo
var mcpExecutor mcpTools.ToolExecutor
mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)
@@ -60,76 +59,29 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") {
remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
if mcpErr == nil {
mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, mcpServers)
// Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode)
namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
if sessErr == nil && len(namedSessions) > 0 {
// Prompt injection
if mcpPromptName != "" {
prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions)
if discErr == nil {
promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs)
if getErr == nil {
var injected []schema.Message
for _, pm := range promptMsgs {
injected = append(injected, schema.Message{
Role: string(pm.Role),
Content: mcpTools.PromptMessageToText(pm),
})
}
openAIMessages = append(injected, openAIMessages...)
xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected))
} else {
xlog.Error("Failed to get MCP prompt", "error", getErr)
}
}
mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs)
if mcpCtx != nil {
openAIMessages = append(mcpCtx.PromptMessages, openAIMessages...)
mcpTools.AppendResourceSuffix(openAIMessages, mcpCtx.ResourceSuffix)
}
}
// Resource injection
if len(mcpResourceURIs) > 0 {
resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions)
if discErr == nil {
var resourceTexts []string
for _, uri := range mcpResourceURIs {
content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri)
if readErr != nil {
xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri)
continue
}
name := uri
for _, r := range resources {
if r.URI == uri {
name = r.Name
break
}
}
resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content))
}
if len(resourceTexts) > 0 && len(openAIMessages) > 0 {
lastIdx := len(openAIMessages) - 1
suffix := "\n\n" + strings.Join(resourceTexts, "\n\n")
switch ct := openAIMessages[lastIdx].Content.(type) {
case string:
openAIMessages[lastIdx].Content = ct + suffix
default:
openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix)
}
xlog.Debug("Anthropic MCP resources injected", "count", len(resourceTexts))
}
}
}
// Tool injection
if len(mcpServers) > 0 {
discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
if discErr == nil {
mcpToolInfos = discovered
for _, ti := range mcpToolInfos {
funcs = append(funcs, ti.Function)
}
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs))
} else {
xlog.Error("Failed to discover MCP tools", "error", discErr)
// Tool injection via executor
if mcpExecutor.HasTools() {
mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context())
if discErr == nil {
for _, fn := range mcpFuncs {
funcs = append(funcs, fn)
}
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs))
} else {
xlog.Error("Failed to discover MCP tools", "error", discErr)
}
}
} else {
@@ -177,19 +129,19 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
if input.Stream {
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
}
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
}
}
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
mcpMaxIterations := 10
if cfg.Agent.MaxIterations > 0 {
mcpMaxIterations = cfg.Agent.MaxIterations
}
hasMCPTools := len(mcpToolInfos) > 0
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
// Re-template on each MCP iteration since messages may have changed
@@ -227,7 +179,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
if hasMCPTools && shouldUseFn && len(toolCalls) > 0 {
var hasMCPCalls bool
for _, tc := range toolCalls {
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
hasMCPCalls = true
break
}
@@ -257,13 +209,12 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
// Execute each MCP tool call and append results
for _, tc := range assistantMsg.ToolCalls {
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
continue
}
xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
c.Request().Context(), mcpToolInfos,
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
toolResult, toolErr := mcpExecutor.ExecuteTool(
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
)
if toolErr != nil {
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
@@ -290,10 +241,10 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
if shouldUseFn && len(toolCalls) > 0 {
stopReason = "tool_use"
for _, tc := range toolCalls {
var inputArgs map[string]interface{}
var inputArgs map[string]any
if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
inputArgs = map[string]interface{}{"raw": tc.Arguments}
inputArgs = map[string]any{"raw": tc.Arguments}
}
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
Type: "tool_use",
@@ -316,9 +267,9 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
}
for i, fc := range parsed {
var inputArgs map[string]interface{}
var inputArgs map[string]any
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
inputArgs = map[string]interface{}{"raw": fc.Arguments}
inputArgs = map[string]any{"raw": fc.Arguments}
}
toolCallID := fc.ID
if toolCallID == "" {
@@ -365,7 +316,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
}
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
@@ -388,7 +339,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
if cfg.Agent.MaxIterations > 0 {
mcpMaxIterations = cfg.Agent.MaxIterations
}
hasMCPTools := len(mcpToolInfos) > 0
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
// Re-template on MCP iterations
@@ -483,7 +434,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
if err != nil {
xlog.Error("Anthropic stream model inference failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
Type: "error",
Error: &schema.AnthropicError{
Type: "api_error",
Message: fmt.Sprintf("model inference failed: %v", err),
},
})
return nil
}
// Also check chat deltas for tool calls
@@ -495,7 +453,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
if hasMCPTools && len(collectedToolCalls) > 0 {
var hasMCPCalls bool
for _, tc := range collectedToolCalls {
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
hasMCPCalls = true
break
}
@@ -525,13 +483,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
// Execute MCP tool calls
for _, tc := range assistantMsg.ToolCalls {
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
continue
}
xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
c.Request().Context(), mcpToolInfos,
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
toolResult, toolErr := mcpExecutor.ExecuteTool(
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
)
if toolErr != nil {
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
@@ -686,7 +643,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
case string:
openAIMsg.StringContent = content
openAIMsg.Content = content
case []interface{}:
case []any:
// Handle array of content blocks
var textContent string
var stringImages []string
@@ -694,7 +651,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
toolCallIndex := 0
for _, block := range content {
if blockMap, ok := block.(map[string]interface{}); ok {
if blockMap, ok := block.(map[string]any); ok {
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
@@ -703,7 +660,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
}
case "image":
// Handle image content
if source, ok := blockMap["source"].(map[string]interface{}); ok {
if source, ok := blockMap["source"].(map[string]any); ok {
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
if data, ok := source["data"].(string); ok {
mediaType, _ := source["media_type"].(string)
@@ -718,14 +675,14 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
toolID, _ := blockMap["id"].(string)
toolName, _ := blockMap["name"].(string)
toolInput := blockMap["input"]
// Serialize input to JSON string
inputJSON, err := json.Marshal(toolInput)
if err != nil {
xlog.Warn("Failed to marshal tool input", "error", err)
inputJSON = []byte("{}")
}
toolCalls = append(toolCalls, schema.ToolCall{
Index: toolCallIndex,
ID: toolID,
@@ -745,16 +702,16 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr != nil {
isError = *isErrorPtr
}
var resultText string
if resultContent, ok := blockMap["content"]; ok {
switch rc := resultContent.(type) {
case string:
resultText = rc
case []interface{}:
case []any:
// Array of content blocks
for _, cb := range rc {
if cbMap, ok := cb.(map[string]interface{}); ok {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultText += text
@@ -764,7 +721,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
}
}
}
// Add tool result as a tool role message
// We need to handle this differently - create a new message
if msg.Role == "user" {
@@ -781,7 +738,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
openAIMsg.StringContent = textContent
openAIMsg.Content = textContent
openAIMsg.StringImages = stringImages
// Add tool calls if present
if len(toolCalls) > 0 {
openAIMsg.ToolCalls = toolCalls
@@ -799,7 +756,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
if len(input.Tools) == 0 {
return nil, false
}
var funcs functions.Functions
for _, tool := range input.Tools {
f := functions.Function{
@@ -809,7 +766,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
}
funcs = append(funcs, f)
}
// Handle tool_choice
if input.ToolChoice != nil {
switch tc := input.ToolChoice.(type) {
@@ -823,7 +780,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
return nil, false
}
// "auto" is the default - let model decide
case map[string]interface{}:
case map[string]any:
// Specific tool selection: {"type": "tool", "name": "tool_name"}
if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
if name, ok := tc["name"].(string); ok {
@@ -833,6 +790,6 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
}
}
}
return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
}

View File

@@ -1,9 +1,10 @@
package explorer
import (
"cmp"
"encoding/base64"
"net/http"
"sort"
"slices"
"strings"
"github.com/labstack/echo/v4"
@@ -14,7 +15,7 @@ import (
func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error {
summary := map[string]interface{}{
summary := map[string]any{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c),
@@ -61,8 +62,8 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
}
// order by number of clusters
sort.Slice(results, func(i, j int) bool {
return len(results[i].Clusters) > len(results[j].Clusters)
slices.SortFunc(results, func(a, b Network) int {
return cmp.Compare(len(b.Clusters), len(a.Clusters))
})
return c.JSON(http.StatusOK, results)
@@ -73,36 +74,36 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
request := new(AddNetworkRequest)
if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Cannot parse JSON"})
}
if request.Token == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token is required"})
}
if request.Name == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Name is required"})
}
if request.Description == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"})
}
// TODO: check if token is valid, otherwise reject
// try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"})
}
if _, exists := db.Get(request.Token); exists {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token already exists"})
}
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "Cannot add token"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
return c.JSON(http.StatusOK, map[string]any{"message": "Token added"})
}
}

View File

@@ -1,6 +1,7 @@
package localai
import (
"errors"
"fmt"
"net/http"
"strconv"
@@ -8,12 +9,12 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/services/agentpool"
)
// getJobService returns the job service for the current user.
// Falls back to the global service when no user is authenticated.
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService {
func getJobService(app *application.Application, c echo.Context) *agentpool.AgentJobService {
userID := getUserID(c)
if userID == "" {
return app.AgentJobService()
@@ -54,7 +55,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
}
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
if err.Error() == "task not found: "+id {
if errors.Is(err, agentpool.ErrTaskNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
@@ -68,7 +69,7 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := getJobService(app, c).DeleteTask(id); err != nil {
if err.Error() == "task not found: "+id {
if errors.Is(err, agentpool.ErrTaskNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
@@ -244,7 +245,7 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := getJobService(app, c).CancelJob(id); err != nil {
if err.Error() == "job not found: "+id {
if errors.Is(err, agentpool.ErrJobNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
@@ -258,7 +259,7 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := getJobService(app, c).DeleteJob(id); err != nil {
if err.Error() == "job not found: "+id {
if errors.Is(err, agentpool.ErrJobNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
@@ -275,7 +276,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
if c.Request().ContentLength > 0 {
if err := c.Bind(&params); err != nil {
body := make(map[string]interface{})
body := make(map[string]any)
if err := c.Bind(&body); err == nil {
params = make(map[string]string)
for k, v := range body {

View File

@@ -2,6 +2,7 @@ package localai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -10,8 +11,9 @@ import (
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
coreTypes "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/xlog"
"github.com/sashabaranov/go-openai"
)
@@ -50,55 +52,105 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
return next(c)
}
// Check if this model name is an agent
ag := svc.GetAgent(req.Model)
if ag == nil {
return next(c)
}
// This is an agent — handle the request directly
// Check if this model name is an agent — try in-process agent first,
// fall back to config lookup (covers distributed mode where agents
// don't run in-process).
messages := parseInputToMessages(req.Input)
if len(messages) == 0 {
return c.JSON(http.StatusBadRequest, map[string]any{
"error": map[string]string{
"type": "invalid_request_error",
"message": "no input messages provided",
},
})
userID := effectiveUserID(c)
ag := svc.GetAgent(req.Model)
if ag == nil && svc.GetAgentConfigForUser(userID, req.Model) == nil {
return next(c) // not an agent
}
jobOptions := []coreTypes.JobOption{
coreTypes.WithConversationHistory(messages),
// Extract the last user message for the executor
var userMessage string
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
userMessage = messages[i].Content
break
}
}
res := ag.Ask(jobOptions...)
var responseText string
if res == nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{
"type": "server_error",
"message": "agent request failed or was cancelled",
},
})
}
if res.Error != nil {
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{
"type": "server_error",
"message": res.Error.Error(),
},
if ag != nil {
// Local mode: use LocalAGI agent directly
jobOptions := []coreTypes.JobOption{
coreTypes.WithConversationHistory(messages),
}
res := ag.Ask(jobOptions...)
if res == nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{
"type": "server_error",
"message": "agent request failed or was cancelled",
},
})
}
if res.Error != nil {
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{
"type": "server_error",
"message": res.Error.Error(),
},
})
}
responseText = res.Response
} else {
// Distributed mode: dispatch via NATS + wait for response synchronously
var bridge *agents.EventBridge
if d := app.Distributed(); d != nil {
bridge = d.AgentBridge
}
if bridge == nil {
return next(c)
}
// Subscribe BEFORE dispatching so we never miss a fast response
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Minute)
defer cancel()
responseCh := make(chan string, 1)
sub, err := bridge.SubscribeEvents(req.Model, userID, func(evt agents.AgentEvent) {
if evt.EventType == "json_message" && evt.Sender == "agent" {
responseCh <- evt.Content
}
})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{"type": "server_error", "message": "failed to subscribe to agent events"},
})
}
defer sub.Unsubscribe()
// Now dispatch via ChatForUser (publishes to NATS)
_, err = svc.ChatForUser(userID, req.Model, userMessage)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{"type": "server_error", "message": err.Error()},
})
}
select {
case responseText = <-responseCh:
// Got the response
case <-ctx.Done():
return c.JSON(http.StatusGatewayTimeout, map[string]any{
"error": map[string]string{"type": "server_error", "message": "agent response timeout"},
})
}
}
id := fmt.Sprintf("resp_%s", uuid.New().String())
return c.JSON(http.StatusOK, map[string]any{
"id": id,
"object": "response",
"created_at": time.Now().Unix(),
"status": "completed",
"model": req.Model,
"id": id,
"object": "response",
"created_at": time.Now().Unix(),
"status": "completed",
"model": req.Model,
"previous_response_id": nil,
"output": []any{
map[string]any{
@@ -109,7 +161,7 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
"content": []map[string]any{
{
"type": "output_text",
"text": res.Response,
"text": responseText,
"annotations": []any{},
},
},

View File

@@ -7,6 +7,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
skillsManager "github.com/mudler/LocalAI/core/services/skills"
skilldomain "github.com/mudler/skillserver/pkg/domain"
)
@@ -41,27 +42,48 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
return out
}
// getSkillManager returns a SkillManager for the request's user.
func getSkillManager(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
svc := app.AgentPoolService()
userID := getUserID(c)
return svc.SkillManagerForUser(userID)
}
func getSkillManagerEffective(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
return svc.SkillManagerForUser(userID)
}
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
skills, err := svc.ListSkillsForUser(userID)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
skills, err := mgr.List()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Admin cross-user aggregation
if wantsAllUsers(c) {
svc := app.AgentPoolService()
usm := svc.UserServicesManager()
if usm != nil {
userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{}
userID := getUserID(c)
for _, uid := range userIDs {
if uid == userID {
continue
}
userSkills, err := svc.ListSkillsForUser(uid)
if err != nil || len(userSkills) == 0 {
uidMgr, mgrErr := svc.SkillManagerForUser(uid)
if mgrErr != nil {
continue
}
userSkills, listErr := uidMgr.List()
if listErr != nil || len(userSkills) == 0 {
continue
}
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
@@ -76,25 +98,28 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
}
}
return c.JSON(http.StatusOK, skillsToResponses(skills))
return c.JSON(http.StatusOK, map[string]any{"skills": skillsToResponses(skills)})
}
}
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
cfg := svc.GetSkillsConfigForUser(userID)
return c.JSON(http.StatusOK, cfg)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusOK, map[string]string{})
}
return c.JSON(http.StatusOK, mgr.GetConfig())
}
}
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
query := c.QueryParam("q")
skills, err := svc.SearchSkillsForUser(userID, query)
skills, err := mgr.Search(query)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
@@ -104,8 +129,10 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -118,7 +145,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
skill, err := mgr.Create(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil {
if strings.Contains(err.Error(), "already exists") {
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
@@ -131,9 +158,11 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
skill, err := svc.GetSkillForUser(userID, c.Param("name"))
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
skill, err := mgr.Get(c.Param("name"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -143,8 +172,10 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct {
Description string `json:"description"`
Content string `json:"content"`
@@ -156,7 +187,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
skill, err := mgr.Update(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -169,9 +200,11 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil {
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.Delete(c.Param("name")); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -180,10 +213,12 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
name := c.Param("*")
data, err := svc.ExportSkillForUser(userID, name)
data, err := mgr.Export(name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -195,8 +230,10 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
@@ -210,7 +247,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.ImportSkillForUser(userID, data)
skill, err := mgr.Import(data)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -222,9 +259,11 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name"))
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
resources, skill, err := mgr.ListResources(c.Param("name"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -260,9 +299,11 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*"))
mgr, err := getSkillManagerEffective(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
content, info, err := mgr.GetResource(c.Param("name"), c.Param("*"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -281,10 +322,12 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
file, err := c.FormFile("file")
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
file, fileErr := c.FormFile("file")
if fileErr != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
}
path := c.FormValue("path")
@@ -300,7 +343,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil {
if err := mgr.CreateResource(c.Param("name"), path, data); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"path": path})
@@ -309,15 +352,17 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct {
Content string `json:"content"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil {
if err := mgr.UpdateResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -326,9 +371,11 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil {
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.DeleteResource(c.Param("name"), c.Param("*")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -339,9 +386,11 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
repos, err := svc.ListGitReposForUser(userID)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
repos, err := mgr.ListGitRepos()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
@@ -351,15 +400,17 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct {
URL string `json:"url"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
repo, err := svc.AddGitRepoForUser(userID, payload.URL)
repo, err := mgr.AddGitRepo(payload.URL)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -369,8 +420,10 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct {
URL string `json:"url"`
Enabled *bool `json:"enabled"`
@@ -378,7 +431,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled)
repo, err := mgr.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -391,9 +444,11 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil {
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.DeleteGitRepo(c.Param("id")); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -405,9 +460,11 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil {
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.SyncGitRepo(c.Param("id")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
@@ -416,9 +473,11 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id"))
mgr, err := getSkillManager(c, app)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
repo, err := mgr.ToggleGitRepo(c.Param("id"))
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}

View File

@@ -4,20 +4,23 @@ import (
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"os"
"path/filepath"
"sort"
"slices"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAGI/core/state"
coreTypes "github.com/mudler/LocalAGI/core/types"
agiServices "github.com/mudler/LocalAGI/services"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
)
// getUserID extracts the scoped user ID from the request context.
@@ -42,25 +45,39 @@ func wantsAllUsers(c echo.Context) bool {
}
// effectiveUserID returns the user ID to scope operations to.
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's
// resources. Non-admin callers always get their own ID regardless of query params.
// SECURITY: Only admins and agent-worker service accounts may supply
// ?user_id=<id> to operate on another user's resources. Agent-worker users are
// created exclusively server-side during node registration and need to access
// collections on behalf of the user whose agent they are executing.
// Regular callers always get their own ID regardless of query params.
func effectiveUserID(c echo.Context) string {
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) {
if targetUID := c.QueryParam("user_id"); targetUID != "" && canImpersonateUser(c) {
if callerID := getUserID(c); callerID != targetUID {
xlog.Info("User impersonation", "caller", callerID, "target", targetUID, "path", c.Path())
}
return targetUID
}
return getUserID(c)
}
// canImpersonateUser returns true if the caller is allowed to use ?user_id= to
// scope operations to another user. Allowed for admins and agent-worker service
// accounts (ProviderAgentWorker is set server-side during node registration and
// cannot be self-assigned).
func canImpersonateUser(c echo.Context) bool {
user := auth.GetUser(c)
if user == nil {
return false
}
return user.Role == auth.RoleAdmin || user.Provider == auth.ProviderAgentWorker
}
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
statuses := svc.ListAgentsForUser(userID)
agents := make([]string, 0, len(statuses))
for name := range statuses {
agents = append(agents, name)
}
sort.Strings(agents)
agents := slices.Sorted(maps.Keys(statuses))
resp := map[string]any{
"agents": agents,
"agentCount": len(agents),
@@ -111,13 +128,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
ag := svc.GetAgentForUser(userID, name)
if ag == nil {
statuses := svc.ListAgentsForUser(userID)
active, exists := statuses[name]
if !exists {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
return c.JSON(http.StatusOK, map[string]any{
"active": !ag.Paused(),
})
return c.JSON(http.StatusOK, map[string]any{"active": active})
}
}
@@ -192,9 +209,13 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
history := svc.GetAgentStatusForUser(userID, name)
if history == nil {
history = &state.Status{ActionResults: []coreTypes.ActionState{}}
return c.JSON(http.StatusOK, map[string]any{
"Name": name,
"History": []string{},
})
}
entries := []string{}
for i := len(history.Results()) - 1; i >= 0; i-- {
@@ -221,10 +242,14 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
history, err := svc.GetAgentObservablesForUser(userID, name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
if history == nil {
history = []json.RawMessage{}
}
return c.JSON(http.StatusOK, map[string]any{
"Name": name,
"History": history,
@@ -278,26 +303,30 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
manager := svc.GetSSEManagerForUser(userID, name)
if manager == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
return services.HandleSSE(c, manager)
}
}
type agentConfigMetaResponse struct {
state.AgentConfigMeta
OutputsDir string `json:"OutputsDir"`
// Try local SSE manager first
manager := svc.GetSSEManagerForUser(userID, name)
if manager != nil {
return agentpool.HandleSSE(c, manager)
}
// Fall back to distributed EventBridge SSE
var bridge *agents.EventBridge
if d := app.Distributed(); d != nil {
bridge = d.AgentBridge
}
if bridge != nil {
return bridge.HandleSSE(c, name, userID)
}
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
}
func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
return c.JSON(http.StatusOK, agentConfigMetaResponse{
AgentConfigMeta: svc.GetConfigMeta(),
OutputsDir: svc.OutputsDir(),
})
return c.JSON(http.StatusOK, svc.GetConfigMetaResult())
}
}

Some files were not shown because too many files have changed in this diff Show More