mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
feat: add distributed mode (#9124)
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
4c870288d9
commit
59108fbe32
2
.github/gallery-agent/agent.go
vendored
2
.github/gallery-agent/agent.go
vendored
@@ -406,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string {
|
||||
}
|
||||
|
||||
// Parse the response to get avatar URL
|
||||
var userInfo map[string]interface{}
|
||||
var userInfo map[string]any
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
40
.github/gallery-agent/testing.go
vendored
40
.github/gallery-agent/testing.go
vendored
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
|
||||
generator := NewSyntheticDataGenerator()
|
||||
|
||||
// Generate a random number of synthetic models (1-3)
|
||||
numModels := generator.rand.Intn(3) + 1
|
||||
numModels := generator.rand.IntN(3) + 1
|
||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||
|
||||
var models []ProcessedModel
|
||||
for i := 0; i < numModels; i++ {
|
||||
for i := range numModels {
|
||||
model := generator.GenerateProcessedModel()
|
||||
models = append(models, model)
|
||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
|
||||
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||
return &SyntheticDataGenerator{
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||
fileTypes := []string{"model", "readme", "other"}
|
||||
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
||||
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
|
||||
|
||||
var path string
|
||||
var isReadme bool
|
||||
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
|
||||
|
||||
return ProcessedModelFile{
|
||||
Path: path,
|
||||
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
||||
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
|
||||
SHA256: g.randomSHA256(),
|
||||
IsReadme: isReadme,
|
||||
FileType: fileType,
|
||||
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||
|
||||
author := authors[g.rand.Intn(len(authors))]
|
||||
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
||||
author := authors[g.rand.IntN(len(authors))]
|
||||
modelName := modelNames[g.rand.IntN(len(modelNames))]
|
||||
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||
|
||||
// Generate files
|
||||
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
||||
numFiles := g.rand.IntN(5) + 2 // 2-6 files
|
||||
files := make([]ProcessedModelFile, numFiles)
|
||||
|
||||
// Ensure at least one model file and one readme
|
||||
hasModelFile := false
|
||||
hasReadme := false
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
for i := range numFiles {
|
||||
files[i] = g.GenerateProcessedModelFile()
|
||||
if files[i].FileType == "model" {
|
||||
hasModelFile = true
|
||||
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
|
||||
// Generate sample metadata
|
||||
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
||||
license := licenses[g.rand.Intn(len(licenses))]
|
||||
license := licenses[g.rand.IntN(len(licenses))]
|
||||
|
||||
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
||||
numTags := g.rand.Intn(4) + 3 // 3-6 tags
|
||||
numTags := g.rand.IntN(4) + 3 // 3-6 tags
|
||||
tags := make([]string, numTags)
|
||||
for i := 0; i < numTags; i++ {
|
||||
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
|
||||
for i := range numTags {
|
||||
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
|
||||
}
|
||||
// Remove duplicates
|
||||
tags = g.removeDuplicates(tags)
|
||||
|
||||
// Optionally include icon (50% chance)
|
||||
icon := ""
|
||||
if g.rand.Intn(2) == 0 {
|
||||
if g.rand.IntN(2) == 0 {
|
||||
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
||||
}
|
||||
|
||||
return ProcessedModel{
|
||||
ModelID: modelID,
|
||||
Author: author,
|
||||
Downloads: g.rand.Intn(1000000) + 1000,
|
||||
Downloads: g.rand.IntN(1000000) + 1000,
|
||||
LastModified: g.randomDate(),
|
||||
Files: files,
|
||||
PreferredModelFile: preferredModelFile,
|
||||
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
|
||||
const charset = "0123456789abcdef"
|
||||
b := make([]byte, 64)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (g *SyntheticDataGenerator) randomDate() string {
|
||||
now := time.Now()
|
||||
daysAgo := g.rand.Intn(365) // Random date within last year
|
||||
daysAgo := g.rand.IntN(365) // Random date within last year
|
||||
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||
}
|
||||
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
|
||||
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||
}
|
||||
|
||||
return templates[g.rand.Intn(len(templates))]
|
||||
return templates[g.rand.IntN(len(templates))]
|
||||
}
|
||||
|
||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.25.x']
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
@@ -179,7 +179,7 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.25.x']
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
|
||||
@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
|
||||
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
||||
FROM requirements-drivers AS build-requirements
|
||||
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG GO_VERSION=1.26.0
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG TARGETARCH
|
||||
@@ -319,7 +319,6 @@ COPY ./.git ./.git
|
||||
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
||||
COPY ./pkg/grpc ./pkg/grpc
|
||||
COPY ./pkg/utils ./pkg/utils
|
||||
COPY ./pkg/langchain ./pkg/langchain
|
||||
|
||||
RUN ls -l ./
|
||||
RUN make protogen-go
|
||||
|
||||
@@ -154,6 +154,7 @@ For older news and full release notes, see [GitHub Releases](https://github.com/
|
||||
- [Object Detection](https://localai.io/features/object-detection/)
|
||||
- [Reranker API](https://localai.io/features/reranker/)
|
||||
- [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
|
||||
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
||||
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
||||
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
||||
|
||||
@@ -51,6 +51,7 @@ service Backend {
|
||||
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -676,3 +677,4 @@ message QuantizationProgressUpdate {
|
||||
message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
#include <regex>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <signal.h>
|
||||
#include <thread>
|
||||
@@ -37,6 +39,47 @@ using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
||||
public:
|
||||
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
||||
|
||||
bool IsBlocking() const override { return false; }
|
||||
|
||||
grpc::Status Process(const InputMetadata& auth_metadata,
|
||||
grpc::AuthContext* /*context*/,
|
||||
OutputMetadata* /*consumed_auth_metadata*/,
|
||||
OutputMetadata* /*response_metadata*/) override {
|
||||
auto it = auth_metadata.find("authorization");
|
||||
if (it != auth_metadata.end()) {
|
||||
std::string expected = "Bearer " + token_;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
// Constant-time comparison
|
||||
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
private:
|
||||
std::string token_;
|
||||
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// END LocalAI
|
||||
|
||||
|
||||
@@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
|
||||
BackendServiceImpl service(ctx_server);
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
std::shared_ptr<grpc::ServerCredentials> creds;
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
creds->SetAuthMetadataProcessor(
|
||||
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
} else {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
}
|
||||
|
||||
builder.AddListeningPort(server_address, creds);
|
||||
builder.RegisterService(&service);
|
||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
||||
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]()
|
||||
|
||||
@@ -106,10 +106,10 @@ func TestLoadModel(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
|
||||
// Get base directory from main model file for relative paths
|
||||
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
|
||||
|
||||
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
@@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
||||
)
|
||||
|
||||
@@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
||||
var textEncoderModel, ditModel, vaeModel string
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
key, value, found := strings.Cut(oo, ":")
|
||||
if !found {
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
continue
|
||||
}
|
||||
switch parts[0] {
|
||||
switch key {
|
||||
case "text_encoder_model":
|
||||
textEncoderModel = parts[1]
|
||||
textEncoderModel = value
|
||||
case "dit_model":
|
||||
ditModel = parts[1]
|
||||
ditModel = value
|
||||
case "vae_model":
|
||||
vaeModel = parts[1]
|
||||
vaeModel = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ type LLM struct {
|
||||
draftModel *llama.LLama
|
||||
}
|
||||
|
||||
|
||||
// Free releases GPU resources and frees the llama model
|
||||
// This should be called when the model is being unloaded to properly release VRAM
|
||||
func (llm *LLM) Free() error {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build debug
|
||||
// +build debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !debug
|
||||
// +build !debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot float32
|
||||
for i := 0; i < len(k1); i++ {
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
}
|
||||
|
||||
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := 0; i < len(k1); i++ {
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
|
||||
// to one-shot (only difference is resampler batch boundaries).
|
||||
var maxDiff float64
|
||||
var sumDiffSq float64
|
||||
for i := 0; i < minLen; i++ {
|
||||
for i := range minLen {
|
||||
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
|
||||
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
||||
|
||||
var persistentMaxDiff, freshMaxDiff float64
|
||||
for i := 0; i < minLen; i++ {
|
||||
for i := range minLen {
|
||||
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
||||
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
||||
if pd > persistentMaxDiff {
|
||||
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
|
||||
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
||||
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
||||
|
||||
Expect(stddev / mean).To(BeNumerically("<", 0.15),
|
||||
Expect(stddev/mean).To(BeNumerically("<", 0.15),
|
||||
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
||||
|
||||
// Also check frequency is correct
|
||||
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
|
||||
|
||||
// Every sample must be identical — the resampler is deterministic
|
||||
var maxDiff float64
|
||||
for i := 0; i < len(oneShot); i++ {
|
||||
for i := range len(oneShot) {
|
||||
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
|
||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
||||
copy(hdr[8:12], "WAVE")
|
||||
copy(hdr[12:16], "fmt ")
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
copy(hdr[36:40], "data")
|
||||
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
||||
|
||||
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
|
||||
)
|
||||
|
||||
pcm := make([]byte, toneNumSamples*2)
|
||||
for i := 0; i < toneNumSamples; i++ {
|
||||
for i := range toneNumSamples {
|
||||
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
||||
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||
|
||||
@@ -19,6 +19,10 @@ import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from acestep.inference import (
|
||||
GenerationParams,
|
||||
GenerationConfig,
|
||||
@@ -444,6 +448,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,6 +16,10 @@ import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import tempfile
|
||||
|
||||
def is_float(s):
|
||||
@@ -225,7 +229,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
78
backend/python/common/grpc_auth.py
Normal file
78
backend/python/common/grpc_auth.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
|
||||
|
||||
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
|
||||
a valid Bearer token in the 'authorization' metadata header are rejected with
|
||||
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
|
||||
performed (backward compatible).
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class _AbortHandler(grpc.RpcMethodHandler):
|
||||
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||
|
||||
def __init__(self):
|
||||
self.request_streaming = False
|
||||
self.response_streaming = False
|
||||
self.request_deserializer = None
|
||||
self.response_serializer = None
|
||||
self.unary_unary = self._abort
|
||||
self.unary_stream = None
|
||||
self.stream_unary = None
|
||||
self.stream_stream = None
|
||||
|
||||
@staticmethod
|
||||
def _abort(request, context):
|
||||
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
|
||||
|
||||
|
||||
class TokenAuthInterceptor(grpc.ServerInterceptor):
|
||||
"""Sync gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
self._abort_handler = _AbortHandler()
|
||||
|
||||
def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return self._abort_handler
|
||||
return continuation(handler_call_details)
|
||||
|
||||
|
||||
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
|
||||
"""Async gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
|
||||
async def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return _AbortHandler()
|
||||
return await continuation(handler_call_details)
|
||||
|
||||
|
||||
def get_auth_interceptors(*, aio: bool = False):
|
||||
"""Return a list of gRPC interceptors for bearer token auth.
|
||||
|
||||
Args:
|
||||
aio: If True, return async-compatible interceptors for grpc.aio.server().
|
||||
If False (default), return sync interceptors for grpc.server().
|
||||
|
||||
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||
"""
|
||||
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||
if not token:
|
||||
return []
|
||||
if aio:
|
||||
return [AsyncTokenAuthInterceptor(token)]
|
||||
return [TokenAuthInterceptor(token)]
|
||||
@@ -15,6 +15,10 @@ import torch
|
||||
from TTS.api import TTS
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -93,7 +97,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -22,6 +22,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
# Import dynamic loader for pipeline discovery
|
||||
from diffusers_dynamic_loader import (
|
||||
@@ -1042,7 +1046,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -15,6 +15,10 @@ import torch
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -165,6 +169,8 @@ def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
]
|
||||
,
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -70,7 +74,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -19,6 +19,10 @@ import numpy as np
|
||||
import json
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -424,6 +428,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,6 +16,10 @@ from kittentts import KittenTTS
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -77,7 +81,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,6 +16,10 @@ from kokoro import KPipeline
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -84,7 +88,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -17,6 +17,10 @@ import time
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
@@ -398,7 +402,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -15,6 +15,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_audio.tts.utils import load_model
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
@@ -436,7 +440,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -23,6 +23,10 @@ import tempfile
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
@@ -468,6 +472,8 @@ async def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -12,6 +12,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_vlm import load, generate, stream_generate
|
||||
from mlx_vlm.prompt_utils import apply_chat_template
|
||||
from mlx_vlm.utils import load_config, load_image
|
||||
@@ -446,7 +450,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -12,6 +12,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_lm import load, generate, stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||
@@ -421,7 +425,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -17,6 +17,10 @@ from moonshine_voice import (
|
||||
)
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -128,7 +132,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -119,7 +123,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -130,7 +134,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import outetts
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -116,7 +120,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
|
||||
@@ -16,6 +16,10 @@ import torch
|
||||
from pocket_tts import TTSModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -225,7 +229,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -184,7 +188,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -23,6 +23,10 @@ import hashlib
|
||||
import pickle
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -900,6 +904,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,6 +14,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
from rerankers import Reranker
|
||||
|
||||
@@ -97,7 +101,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -13,6 +13,10 @@ import base64
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import requests
|
||||
|
||||
@@ -139,7 +143,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,6 +16,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
@@ -532,7 +536,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -17,6 +17,10 @@ import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
@@ -832,6 +836,8 @@ def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
|
||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -724,7 +728,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -27,6 +27,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
from vllm_omni.entrypoints.omni import Omni
|
||||
from vllm_omni.outputs import OmniRequestOutput
|
||||
@@ -650,7 +654,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -12,6 +12,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -338,7 +342,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -18,6 +18,10 @@ import backend_pb2_grpc
|
||||
import torch
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -297,7 +301,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -13,6 +13,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -137,7 +141,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -3,7 +3,7 @@ package application
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -22,13 +22,23 @@ func (a *Application) RestartAgentJobService() error {
|
||||
}
|
||||
|
||||
// Create new service instance
|
||||
agentJobService := services.NewAgentJobService(
|
||||
agentJobService := agentpool.NewAgentJobService(
|
||||
a.ApplicationConfig(),
|
||||
a.ModelLoader(),
|
||||
a.ModelConfigLoader(),
|
||||
a.TemplatesEvaluator(),
|
||||
)
|
||||
|
||||
// Re-apply distributed wiring if available (matches startup.go logic)
|
||||
if d := a.Distributed(); d != nil {
|
||||
if d.Dispatcher != nil {
|
||||
agentJobService.SetDistributedBackends(d.Dispatcher)
|
||||
}
|
||||
if d.JobStore != nil {
|
||||
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||
}
|
||||
}
|
||||
|
||||
// Start the service
|
||||
err := agentJobService.Start(a.ApplicationConfig().Context)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,12 +2,16 @@ package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -20,9 +24,9 @@ type Application struct {
|
||||
applicationConfig *config.ApplicationConfig
|
||||
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
|
||||
templatesEvaluator *templates.Evaluator
|
||||
galleryService *services.GalleryService
|
||||
agentJobService *services.AgentJobService
|
||||
agentPoolService atomic.Pointer[services.AgentPoolService]
|
||||
galleryService *galleryop.GalleryService
|
||||
agentJobService *agentpool.AgentJobService
|
||||
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
|
||||
authDB *gorm.DB
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
@@ -30,6 +34,9 @@ type Application struct {
|
||||
p2pCtx context.Context
|
||||
p2pCancel context.CancelFunc
|
||||
agentJobMutex sync.Mutex
|
||||
|
||||
// Distributed mode services (nil when not in distributed mode)
|
||||
distributed *DistributedServices
|
||||
}
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
@@ -64,15 +71,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator {
|
||||
return a.templatesEvaluator
|
||||
}
|
||||
|
||||
func (a *Application) GalleryService() *services.GalleryService {
|
||||
func (a *Application) GalleryService() *galleryop.GalleryService {
|
||||
return a.galleryService
|
||||
}
|
||||
|
||||
func (a *Application) AgentJobService() *services.AgentJobService {
|
||||
func (a *Application) AgentJobService() *agentpool.AgentJobService {
|
||||
return a.agentJobService
|
||||
}
|
||||
|
||||
func (a *Application) AgentPoolService() *services.AgentPoolService {
|
||||
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||
return a.agentPoolService.Load()
|
||||
}
|
||||
|
||||
@@ -86,8 +93,53 @@ func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
}
|
||||
|
||||
// Distributed returns the distributed services, or nil if not in distributed mode.
|
||||
func (a *Application) Distributed() *DistributedServices {
|
||||
return a.distributed
|
||||
}
|
||||
|
||||
// IsDistributed returns true if the application is running in distributed mode.
|
||||
func (a *Application) IsDistributed() bool {
|
||||
return a.distributed != nil
|
||||
}
|
||||
|
||||
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
|
||||
// This prevents the agent pool from failing during startup when workers haven't connected yet.
|
||||
func (a *Application) waitForHealthyWorker() {
|
||||
maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault()
|
||||
const basePoll = 2 * time.Second
|
||||
|
||||
xlog.Info("Waiting for at least one healthy backend worker before starting agent pool")
|
||||
deadline := time.Now().Add(maxWait)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
registered, err := a.distributed.Registry.List(context.Background())
|
||||
if err == nil {
|
||||
for _, n := range registered {
|
||||
if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy {
|
||||
xlog.Info("Healthy backend worker found", "node", n.Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add 0-1s jitter to prevent thundering-herd on the node registry
|
||||
jitter := time.Duration(rand.Int64N(int64(time.Second)))
|
||||
select {
|
||||
case <-a.applicationConfig.Context.Done():
|
||||
return
|
||||
case <-time.After(basePoll + jitter):
|
||||
}
|
||||
}
|
||||
xlog.Warn("No healthy backend worker found after waiting, proceeding anyway")
|
||||
}
|
||||
|
||||
// InstanceID returns the unique identifier for this frontend instance.
|
||||
func (a *Application) InstanceID() string {
|
||||
return a.applicationConfig.Distributed.InstanceID
|
||||
}
|
||||
|
||||
func (a *Application) start() error {
|
||||
galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
||||
galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
||||
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -95,19 +147,14 @@ func (a *Application) start() error {
|
||||
|
||||
a.galleryService = galleryService
|
||||
|
||||
// Initialize agent job service
|
||||
agentJobService := services.NewAgentJobService(
|
||||
// Initialize agent job service (Start() is deferred to after distributed wiring)
|
||||
agentJobService := agentpool.NewAgentJobService(
|
||||
a.ApplicationConfig(),
|
||||
a.ModelLoader(),
|
||||
a.ModelConfigLoader(),
|
||||
a.TemplatesEvaluator(),
|
||||
)
|
||||
|
||||
err = agentJobService.Start(a.ApplicationConfig().Context)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.agentJobService = agentJobService
|
||||
|
||||
return nil
|
||||
@@ -120,27 +167,56 @@ func (a *Application) StartAgentPool() {
|
||||
if !a.applicationConfig.AgentPool.Enabled {
|
||||
return
|
||||
}
|
||||
aps, err := services.NewAgentPoolService(a.applicationConfig)
|
||||
// Build options struct from available dependencies
|
||||
opts := agentpool.AgentPoolOptions{
|
||||
AuthDB: a.authDB,
|
||||
}
|
||||
if d := a.Distributed(); d != nil {
|
||||
if d.DistStores != nil && d.DistStores.Skills != nil {
|
||||
opts.SkillStore = d.DistStores.Skills
|
||||
}
|
||||
opts.NATSClient = d.Nats
|
||||
opts.EventBridge = d.AgentBridge
|
||||
opts.AgentStore = d.AgentStore
|
||||
}
|
||||
|
||||
aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to create agent pool service", "error", err)
|
||||
return
|
||||
}
|
||||
if a.authDB != nil {
|
||||
aps.SetAuthDB(a.authDB)
|
||||
|
||||
// Wire distributed mode components
|
||||
if d := a.Distributed(); d != nil {
|
||||
// Wait for at least one healthy backend worker before starting the agent pool.
|
||||
// Collections initialization calls embeddings which require a worker.
|
||||
if d.Registry != nil {
|
||||
a.waitForHealthyWorker()
|
||||
}
|
||||
}
|
||||
|
||||
if err := aps.Start(a.applicationConfig.Context); err != nil {
|
||||
xlog.Error("Failed to start agent pool", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
|
||||
usm := services.NewUserServicesManager(
|
||||
usm := agentpool.NewUserServicesManager(
|
||||
aps.UserStorage(),
|
||||
a.applicationConfig,
|
||||
a.modelLoader,
|
||||
a.backendLoader,
|
||||
a.templatesEvaluator,
|
||||
)
|
||||
// Wire distributed backends to per-user job services
|
||||
if a.agentJobService != nil {
|
||||
if d := a.agentJobService.Dispatcher(); d != nil {
|
||||
usm.SetJobDispatcher(d)
|
||||
}
|
||||
if s := a.agentJobService.DBStore(); s != nil {
|
||||
usm.SetJobDBStore(s)
|
||||
}
|
||||
}
|
||||
aps.SetUserServicesManager(usm)
|
||||
|
||||
a.agentPoolService.Store(aps)
|
||||
|
||||
267
core/application/distributed.go
Normal file
267
core/application/distributed.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/agents"
|
||||
"github.com/mudler/LocalAI/core/services/distributed"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DistributedServices holds all services initialized for distributed mode.
|
||||
type DistributedServices struct {
|
||||
Nats *messaging.Client
|
||||
Store storage.ObjectStore
|
||||
Registry *nodes.NodeRegistry
|
||||
Router *nodes.SmartRouter
|
||||
Health *nodes.HealthMonitor
|
||||
JobStore *jobs.JobStore
|
||||
Dispatcher *jobs.Dispatcher
|
||||
AgentStore *agents.AgentStore
|
||||
AgentBridge *agents.EventBridge
|
||||
DistStores *distributed.Stores
|
||||
FileMgr *storage.FileManager
|
||||
FileStager nodes.FileStager
|
||||
ModelAdapter *nodes.ModelRouterAdapter
|
||||
Unloader *nodes.RemoteUnloaderAdapter
|
||||
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// Shutdown stops all distributed services in reverse initialization order.
|
||||
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
|
||||
func (ds *DistributedServices) Shutdown() {
|
||||
if ds == nil {
|
||||
return
|
||||
}
|
||||
ds.shutdownOnce.Do(func() {
|
||||
if ds.Health != nil {
|
||||
ds.Health.Stop()
|
||||
}
|
||||
if ds.Dispatcher != nil {
|
||||
ds.Dispatcher.Stop()
|
||||
}
|
||||
if closer, ok := ds.Store.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
// AgentBridge has no Close method — its NATS subscriptions are cleaned up
|
||||
// when the NATS client is closed below.
|
||||
if ds.Nats != nil {
|
||||
ds.Nats.Close()
|
||||
}
|
||||
xlog.Info("Distributed services shut down")
|
||||
})
|
||||
}
|
||||
|
||||
// initDistributed validates distributed mode prerequisites and initializes
|
||||
// NATS, object storage, node registry, and instance identity.
|
||||
// Returns nil if distributed mode is not enabled.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
|
||||
if !cfg.Distributed.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
xlog.Info("Distributed mode enabled — validating prerequisites")
|
||||
|
||||
// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
|
||||
if err := cfg.Distributed.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate PostgreSQL is configured (auth DB must be PostgreSQL for distributed mode)
|
||||
if !cfg.Auth.Enabled {
|
||||
return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
|
||||
}
|
||||
if !isPostgresURL(cfg.Auth.DatabaseURL) {
|
||||
return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
|
||||
}
|
||||
|
||||
// Generate instance ID if not set
|
||||
if cfg.Distributed.InstanceID == "" {
|
||||
cfg.Distributed.InstanceID = uuid.New().String()
|
||||
}
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))
|
||||
|
||||
// Ensure NATS is closed if any subsequent initialization step fails.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
natsClient.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize object storage
|
||||
var store storage.ObjectStore
|
||||
if cfg.Distributed.StorageURL != "" {
|
||||
if cfg.Distributed.StorageBucket == "" {
|
||||
return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
|
||||
}
|
||||
s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
|
||||
Endpoint: cfg.Distributed.StorageURL,
|
||||
Region: cfg.Distributed.StorageRegion,
|
||||
Bucket: cfg.Distributed.StorageBucket,
|
||||
AccessKeyID: cfg.Distributed.StorageAccessKey,
|
||||
SecretAccessKey: cfg.Distributed.StorageSecretKey,
|
||||
ForcePathStyle: true, // required for MinIO
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing S3 storage: %w", err)
|
||||
}
|
||||
xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
|
||||
store = s3Store
|
||||
} else {
|
||||
// Fallback to filesystem storage in distributed mode (useful for single-node testing)
|
||||
fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing filesystem storage: %w", err)
|
||||
}
|
||||
xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore")
|
||||
store = fsStore
|
||||
}
|
||||
|
||||
// Initialize node registry (requires the auth DB which is PostgreSQL)
|
||||
if authDB == nil {
|
||||
return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
|
||||
}
|
||||
|
||||
registry, err := nodes.NewNodeRegistry(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing node registry: %w", err)
|
||||
}
|
||||
xlog.Info("Node registry initialized")
|
||||
|
||||
// Collect SmartRouter option values; the router itself is created after all
|
||||
// dependencies (including FileStager and Unloader) are ready.
|
||||
var routerAuthToken string
|
||||
if cfg.Distributed.RegistrationToken != "" {
|
||||
routerAuthToken = cfg.Distributed.RegistrationToken
|
||||
}
|
||||
var routerGalleriesJSON string
|
||||
if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
|
||||
routerGalleriesJSON = string(galleriesJSON)
|
||||
}
|
||||
|
||||
healthMon := nodes.NewHealthMonitor(registry, authDB,
|
||||
cfg.Distributed.HealthCheckIntervalOrDefault(),
|
||||
cfg.Distributed.StaleNodeThresholdOrDefault(),
|
||||
routerAuthToken,
|
||||
cfg.Distributed.PerModelHealthCheck,
|
||||
)
|
||||
|
||||
// Initialize job store
|
||||
jobStore, err := jobs.NewJobStore(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing job store: %w", err)
|
||||
}
|
||||
xlog.Info("Distributed job store initialized")
|
||||
|
||||
// Initialize job dispatcher
|
||||
dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)
|
||||
|
||||
// Initialize agent store
|
||||
agentStore, err := agents.NewAgentStore(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing agent store: %w", err)
|
||||
}
|
||||
xlog.Info("Distributed agent store initialized")
|
||||
|
||||
// Initialize agent event bridge
|
||||
agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)
|
||||
|
||||
// Start observable persister — captures observable_update events from workers
|
||||
// (which have no DB access) and persists them to PostgreSQL.
|
||||
if err := agentBridge.StartObservablePersister(); err != nil {
|
||||
xlog.Warn("Failed to start observable persister", "error", err)
|
||||
} else {
|
||||
xlog.Info("Observable persister started")
|
||||
}
|
||||
|
||||
// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
|
||||
distStores, err := distributed.InitStores(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing distributed stores: %w", err)
|
||||
}
|
||||
|
||||
// Initialize file manager with local cache
|
||||
cacheDir := cfg.DataPath + "/cache"
|
||||
fileMgr, err := storage.NewFileManager(store, cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing file manager: %w", err)
|
||||
}
|
||||
xlog.Info("File manager initialized", "cacheDir", cacheDir)
|
||||
|
||||
// Create FileStager for distributed file transfer
|
||||
var fileStager nodes.FileStager
|
||||
if cfg.Distributed.StorageURL != "" {
|
||||
fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
|
||||
xlog.Info("File stager initialized (S3+NATS)")
|
||||
} else {
|
||||
fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
node, err := registry.Get(context.Background(), nodeID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if node.HTTPAddress == "" {
|
||||
return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
|
||||
}
|
||||
return node.HTTPAddress, nil
|
||||
}, cfg.Distributed.RegistrationToken)
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
modelAdapter := nodes.NewModelRouterAdapter(router)
|
||||
|
||||
success = true
|
||||
return &DistributedServices{
|
||||
Nats: natsClient,
|
||||
Store: store,
|
||||
Registry: registry,
|
||||
Router: router,
|
||||
Health: healthMon,
|
||||
JobStore: jobStore,
|
||||
Dispatcher: dispatcher,
|
||||
AgentStore: agentStore,
|
||||
AgentBridge: agentBridge,
|
||||
DistStores: distStores,
|
||||
FileMgr: fileMgr,
|
||||
FileStager: fileStager,
|
||||
ModelAdapter: modelAdapter,
|
||||
Unloader: remoteUnloader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isPostgresURL(url string) bool {
|
||||
return strings.HasPrefix(url, "postgres://") || strings.HasPrefix(url, "postgresql://")
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
|
||||
"github.com/mudler/edgevpn/pkg/node"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -146,22 +146,14 @@ func (a *Application) RestartP2P() error {
|
||||
return fmt.Errorf("P2P token is not set")
|
||||
}
|
||||
|
||||
// Create new context for P2P
|
||||
ctx, cancel := context.WithCancel(appConfig.Context)
|
||||
a.p2pCtx = ctx
|
||||
a.p2pCancel = cancel
|
||||
|
||||
// Get API address from config
|
||||
address := appConfig.APIAddress
|
||||
if address == "" {
|
||||
address = "127.0.0.1:8080" // default
|
||||
}
|
||||
|
||||
// Start P2P stack in a goroutine
|
||||
// Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel
|
||||
go func() {
|
||||
if err := a.StartP2P(); err != nil {
|
||||
xlog.Error("Failed to start P2P stack", "error", err)
|
||||
cancel() // Cancel context on error
|
||||
if a.p2pCancel != nil {
|
||||
a.p2pCancel()
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("P2P stack restarted with new settings")
|
||||
@@ -228,7 +220,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error {
|
||||
continue
|
||||
}
|
||||
|
||||
app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: model,
|
||||
Galleries: app.ApplicationConfig().Galleries,
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -101,7 +105,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
|
||||
}
|
||||
application.authDB = authDB
|
||||
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
|
||||
xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL))
|
||||
|
||||
// Start session and expired API key cleanup goroutine
|
||||
go func() {
|
||||
@@ -123,12 +127,92 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
dbJobStore, err := jobs.NewJobStore(application.authDB)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to create job store for auth DB", "error", err)
|
||||
} else {
|
||||
application.agentJobService.SetDistributedJobStore(dbJobStore)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize distributed mode services (NATS, object storage, node registry)
|
||||
distSvc, err := initDistributed(options, application.authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
|
||||
}
|
||||
if distSvc != nil {
|
||||
application.distributed = distSvc
|
||||
// Wire remote model unloader so ShutdownModel works for remote nodes
|
||||
// Uses NATS to tell serve-backend nodes to Free + kill their backend process
|
||||
application.modelLoader.SetRemoteUnloader(distSvc.Unloader)
|
||||
// Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode
|
||||
application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter())
|
||||
// Wire DistributedModelStore so shutdown/list/watchdog can find remote models
|
||||
distStore := nodes.NewDistributedModelStore(
|
||||
model.NewInMemoryModelStore(),
|
||||
distSvc.Registry,
|
||||
)
|
||||
application.modelLoader.SetModelStore(distStore)
|
||||
// Start health monitor
|
||||
distSvc.Health.Start(options.Context)
|
||||
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
|
||||
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
|
||||
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
|
||||
// but does NOT set a workerFn — agent workers consume jobs from the same NATS queue.
|
||||
|
||||
// Wire model config loader so job events include model config for agent workers
|
||||
distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader)
|
||||
|
||||
// Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched
|
||||
if err := distSvc.Dispatcher.Start(options.Context); err != nil {
|
||||
return nil, fmt.Errorf("starting job dispatcher: %w", err)
|
||||
}
|
||||
// Start ephemeral file cleanup
|
||||
storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0)
|
||||
// Wire distributed backends into AgentJobService (before Start)
|
||||
if application.agentJobService != nil {
|
||||
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||
}
|
||||
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||
|
||||
// Wire NATS and gallery store into GalleryService for cross-instance progress/cancel
|
||||
if application.galleryService != nil {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Start AgentJobService (after distributed wiring so it knows whether to use local or NATS)
|
||||
if application.agentJobService != nil {
|
||||
if err := application.agentJobService.Start(options.Context); err != nil {
|
||||
return nil, fmt.Errorf("starting agent job service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
xlog.Error("error installing models", "error", err)
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -154,13 +238,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -184,6 +268,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
xlog.Debug("Context canceled, shutting down")
|
||||
application.distributed.Shutdown()
|
||||
err := application.ModelLoader().StopAllGRPC()
|
||||
if err != nil {
|
||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||
@@ -207,7 +292,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
var backendErr error
|
||||
_, backendErr = application.ModelLoader().Load(o...)
|
||||
if backendErr != nil {
|
||||
return nil, err
|
||||
return nil, backendErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -27,7 +27,7 @@ type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
AudioOutput string
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
|
||||
}
|
||||
|
||||
@@ -47,14 +47,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
|
||||
modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !slices.Contains(modelNames, c.Name) {
|
||||
modelName := c.Name
|
||||
if modelName == "" {
|
||||
modelName = c.Model
|
||||
}
|
||||
if !slices.Contains(modelNames, modelName) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
if err != nil {
|
||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||
//return nil, err
|
||||
@@ -252,12 +256,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"function_template": c.TemplateConfig.Functions,
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -86,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 {
|
||||
}
|
||||
|
||||
if seed == config.RAND_SEED {
|
||||
seed = rand.Int31()
|
||||
seed = rand.Int32()
|
||||
}
|
||||
|
||||
return seed
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VAD(request *schema.VADRequest,
|
||||
ctx context.Context,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := proto.VADRequest{
|
||||
Audio: request.Audio,
|
||||
}
|
||||
resp, err := vadModel.VAD(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segments := []schema.VADSegment{}
|
||||
for _, s := range resp.Segments {
|
||||
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
|
||||
}
|
||||
|
||||
return &schema.VADResponse{
|
||||
Segments: segments,
|
||||
}, nil
|
||||
}
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VAD(request *schema.VADRequest,
|
||||
ctx context.Context,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := proto.VADRequest{
|
||||
Audio: request.Audio,
|
||||
}
|
||||
resp, err := vadModel.VAD(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segments := []schema.VADSegment{}
|
||||
for _, s := range resp.Segments {
|
||||
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
|
||||
}
|
||||
|
||||
return &schema.VADResponse{
|
||||
Segments: segments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
appConfig := r.buildAppConfig()
|
||||
|
||||
poolService, err := services.NewAgentPoolService(appConfig)
|
||||
poolService, err := agentpool.NewAgentPoolService(appConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent pool service: %w", err)
|
||||
}
|
||||
|
||||
463
core/cli/agent_worker.go
Normal file
463
core/cli/agent_worker.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agents"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
mcpRemote "github.com/mudler/LocalAI/core/services/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/mudler/cogito/clients"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// AgentWorkerCMD starts a dedicated agent worker process for distributed mode.
|
||||
// It registers with the frontend, subscribes to the NATS agent execution queue,
|
||||
// and executes agent chats using cogito. The worker is a pure executor — it
|
||||
// receives the full agent config and skills in the NATS job payload, so it
|
||||
// does not need direct database access.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// localai agent-worker --nats-url nats://... --register-to http://localai:8080
|
||||
type AgentWorkerCMD struct {
|
||||
// NATS (required)
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
|
||||
// Registration (required)
|
||||
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
|
||||
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||
|
||||
// API access
|
||||
APIURL string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"`
|
||||
APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"`
|
||||
|
||||
// NATS subjects
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
|
||||
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||
|
||||
// Resolve API URL
|
||||
apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/"))
|
||||
|
||||
// Register with frontend
|
||||
regClient := &workerregistry.RegistrationClient{
|
||||
FrontendURL: cmd.RegisterTo,
|
||||
RegistrationToken: cmd.RegistrationToken,
|
||||
}
|
||||
|
||||
nodeName := cmd.NodeName
|
||||
if nodeName == "" {
|
||||
hostname, _ := os.Hostname()
|
||||
nodeName = "agent-" + hostname
|
||||
}
|
||||
registrationBody := map[string]any{
|
||||
"name": nodeName,
|
||||
"node_type": "agent",
|
||||
}
|
||||
if cmd.RegistrationToken != "" {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
|
||||
if err != nil && cmd.HeartbeatInterval != "" {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
defer natsClient.Close()
|
||||
|
||||
// Create event bridge for publishing results back via NATS
|
||||
eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID)
|
||||
|
||||
// Start cancel listener
|
||||
cancelSub, err := eventBridge.StartCancelListener()
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to start cancel listener", "error", err)
|
||||
} else {
|
||||
defer cancelSub.Unsubscribe()
|
||||
}
|
||||
|
||||
// Create and start the NATS dispatcher.
|
||||
// No ConfigProvider or SkillStore needed — config and skills arrive in the job payload.
|
||||
dispatcher := agents.NewNATSDispatcher(
|
||||
natsClient,
|
||||
eventBridge,
|
||||
nil, // no ConfigProvider: config comes in the enriched NATS payload
|
||||
apiURL, cmd.APIToken,
|
||||
cmd.Subject, cmd.Queue,
|
||||
0, // no concurrency limit (CLI worker)
|
||||
)
|
||||
|
||||
if err := dispatcher.Start(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("starting dispatcher: %w", err)
|
||||
}
|
||||
|
||||
// Subscribe to MCP tool execution requests (load-balanced across workers).
|
||||
// The frontend routes model-level MCP tool calls here via NATS request-reply.
|
||||
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
||||
handleMCPToolRequest(data, reply)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err)
|
||||
}
|
||||
|
||||
// Subscribe to MCP discovery requests (load-balanced across workers).
|
||||
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
||||
handleMCPDiscoveryRequest(data, reply)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err)
|
||||
}
|
||||
|
||||
// Subscribe to MCP CI job execution (load-balanced across agent workers).
|
||||
// In distributed mode, MCP CI jobs are routed here because the frontend
|
||||
// cannot create MCP sessions (e.g., stdio servers using docker).
|
||||
mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout)
|
||||
if err != nil && cmd.MCPCIJobTimeout != "" {
|
||||
xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err)
|
||||
}
|
||||
mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout)
|
||||
|
||||
if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) {
|
||||
handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err)
|
||||
}
|
||||
|
||||
// Subscribe to backend stop events to clean up cached MCP sessions.
|
||||
// In the main application this is done via ml.OnModelUnload, but the agent
|
||||
// worker has no model loader — we listen for the NATS stop event instead.
|
||||
if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) {
|
||||
var req struct {
|
||||
Backend string `json:"backend"`
|
||||
}
|
||||
if json.Unmarshal(data, &req) == nil && req.Backend != "" {
|
||||
mcpTools.CloseMCPSessions(req.Backend)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err)
|
||||
}
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
// The worker creates/caches MCP sessions from the serialized config and executes the tool.
|
||||
func handleMCPToolRequest(data []byte, reply func([]byte)) {
|
||||
var req mcpRemote.MCPToolRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create/cache named MCP sessions from the provided config
|
||||
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
||||
if err != nil {
|
||||
sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Discover tools to find the right session
|
||||
tools, err := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
||||
if err != nil {
|
||||
sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
argsJSON, _ := json.Marshal(req.Arguments)
|
||||
result, err := mcpTools.ExecuteMCPToolCall(ctx, tools, req.ToolName, string(argsJSON))
|
||||
if err != nil {
|
||||
sendMCPToolReply(reply, "", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sendMCPToolReply(reply, result, "")
|
||||
}
|
||||
|
||||
func sendMCPToolReply(reply func([]byte), result, errMsg string) {
|
||||
resp := mcpRemote.MCPToolResponse{Result: result, Error: errMsg}
|
||||
data, _ := json.Marshal(resp)
|
||||
reply(data)
|
||||
}
|
||||
|
||||
// handleMCPDiscoveryRequest handles a NATS request-reply for MCP tool/prompt/resource discovery.
|
||||
func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) {
|
||||
var req mcpRemote.MCPDiscoveryRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create/cache named MCP sessions
|
||||
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
||||
if err != nil {
|
||||
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// List servers with their tools/prompts/resources
|
||||
serverInfos, err := mcpTools.ListMCPServers(ctx, namedSessions)
|
||||
if err != nil {
|
||||
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Also get tool function schemas for the frontend
|
||||
tools, _ := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
||||
var toolDefs []mcpRemote.MCPToolDef
|
||||
for _, t := range tools {
|
||||
toolDefs = append(toolDefs, mcpRemote.MCPToolDef{
|
||||
ServerName: t.ServerName,
|
||||
ToolName: t.ToolName,
|
||||
Function: t.Function,
|
||||
})
|
||||
}
|
||||
|
||||
// Convert server infos
|
||||
var servers []mcpRemote.MCPServerInfo
|
||||
for _, s := range serverInfos {
|
||||
servers = append(servers, mcpRemote.MCPServerInfo{
|
||||
Name: s.Name,
|
||||
Type: s.Type,
|
||||
Tools: s.Tools,
|
||||
Prompts: s.Prompts,
|
||||
Resources: s.Resources,
|
||||
})
|
||||
}
|
||||
|
||||
sendMCPDiscoveryReply(reply, servers, toolDefs, "")
|
||||
}
|
||||
|
||||
func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) {
|
||||
resp := mcpRemote.MCPDiscoveryResponse{Servers: servers, Tools: tools, Error: errMsg}
|
||||
data, _ := json.Marshal(resp)
|
||||
reply(data)
|
||||
}
|
||||
|
||||
// handleMCPCIJob processes an MCP CI job on the agent worker.
|
||||
// The agent worker can create MCP sessions (has docker) and call the LocalAI API for inference.
|
||||
func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) {
|
||||
var evt jobs.JobEvent
|
||||
if err := json.Unmarshal(data, &evt); err != nil {
|
||||
xlog.Error("Failed to unmarshal job event", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
job := evt.Job
|
||||
task := evt.Task
|
||||
if job == nil || task == nil {
|
||||
xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID)
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event")
|
||||
return
|
||||
}
|
||||
|
||||
modelCfg := evt.ModelConfig
|
||||
if modelCfg == nil {
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event")
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model)
|
||||
|
||||
// Publish running status
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, Status: "running", Message: "Job started on agent worker",
|
||||
})
|
||||
|
||||
// Parse MCP config
|
||||
if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" {
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model")
|
||||
return
|
||||
}
|
||||
|
||||
remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML()
|
||||
if err != nil {
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Create MCP sessions locally (agent worker has docker)
|
||||
sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio)
|
||||
if err != nil || len(sessions) == 0 {
|
||||
errMsg := "no working MCP servers found"
|
||||
if err != nil {
|
||||
errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err)
|
||||
}
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
// Build prompt from template
|
||||
prompt := task.Prompt
|
||||
if task.CronParametersJSON != "" {
|
||||
var params map[string]string
|
||||
if err := json.Unmarshal([]byte(task.CronParametersJSON), ¶ms); err != nil {
|
||||
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
||||
}
|
||||
for k, v := range params {
|
||||
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
||||
}
|
||||
}
|
||||
if job.ParametersJSON != "" {
|
||||
var params map[string]string
|
||||
if err := json.Unmarshal([]byte(job.ParametersJSON), ¶ms); err != nil {
|
||||
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
||||
}
|
||||
for k, v := range params {
|
||||
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Create LLM client pointing back to the frontend API
|
||||
llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL)
|
||||
|
||||
// Build cogito options
|
||||
ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Update job status to running in DB
|
||||
publishJobStatus(natsClient, evt.JobID, "running", "")
|
||||
|
||||
// Buffer stream tokens and flush as complete blocks
|
||||
var reasoningBuf, contentBuf strings.Builder
|
||||
var lastStreamType cogito.StreamEventType
|
||||
|
||||
flushStreamBuf := func() {
|
||||
if reasoningBuf.Len() > 0 {
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(),
|
||||
})
|
||||
reasoningBuf.Reset()
|
||||
}
|
||||
if contentBuf.Len() > 0 {
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(),
|
||||
})
|
||||
contentBuf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
cogitoOpts := modelCfg.BuildCogitoOptions()
|
||||
cogitoOpts = append(cogitoOpts,
|
||||
cogito.WithContext(ctx),
|
||||
cogito.WithMCPs(sessions...),
|
||||
cogito.WithStatusCallback(func(status string) {
|
||||
flushStreamBuf()
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, TraceType: "status", TraceContent: status,
|
||||
})
|
||||
}),
|
||||
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
||||
flushStreamBuf()
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result),
|
||||
})
|
||||
}),
|
||||
cogito.WithStreamCallback(func(ev cogito.StreamEvent) {
|
||||
// Flush if stream type changed (e.g., reasoning → content)
|
||||
if ev.Type != lastStreamType {
|
||||
flushStreamBuf()
|
||||
lastStreamType = ev.Type
|
||||
}
|
||||
switch ev.Type {
|
||||
case cogito.StreamEventReasoning:
|
||||
reasoningBuf.WriteString(ev.Content)
|
||||
case cogito.StreamEventContent:
|
||||
contentBuf.WriteString(ev.Content)
|
||||
case cogito.StreamEventToolCall:
|
||||
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||
JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs),
|
||||
})
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
// Execute via cogito
|
||||
fragment := cogito.NewEmptyFragment()
|
||||
fragment = fragment.AddMessage("user", prompt)
|
||||
|
||||
f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...)
|
||||
flushStreamBuf() // flush any remaining buffered tokens
|
||||
|
||||
if err != nil {
|
||||
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
result := ""
|
||||
if msg := f.LastMessage(); msg != nil {
|
||||
result = msg.Content
|
||||
}
|
||||
publishJobResult(natsClient, evt.JobID, "completed", result, "")
|
||||
xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result))
|
||||
}
|
||||
|
||||
// publishJobStatus forwards a job status update for jobID onto the NATS
// progress channel (thin wrapper over jobs.PublishJobProgress).
func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) {
	jobs.PublishJobProgress(nc, jobID, status, message)
}
|
||||
|
||||
// publishJobResult publishes the terminal outcome of a job — status plus
// either a result payload or an error message (thin wrapper over
// jobs.PublishJobResult).
func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) {
	jobs.PublishJobResult(nc, jobID, status, result, errMsg)
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
@@ -103,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -15,7 +15,9 @@ var CLI struct {
|
||||
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
||||
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
|
||||
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
||||
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
||||
P2PWorker worker.Worker `cmd:"" name:"p2p-worker" help:"Run workers to distribute workload via p2p (llama.cpp-only)"`
|
||||
Worker WorkerCMD `cmd:"" help:"Start a worker for distributed mode (generic, backend-agnostic)"`
|
||||
AgentWorker AgentWorkerCMD `cmd:"" name:"agent-worker" help:"Start an agent worker for distributed mode (executes agent chats via NATS)"`
|
||||
Util UtilCMD `cmd:"" help:"Utility commands"`
|
||||
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
|
||||
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
||||
|
||||
@@ -186,9 +186,9 @@ _local_ai_completions()
|
||||
}
|
||||
subcmds := []string{}
|
||||
for _, sub := range cmds {
|
||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
||||
subcmds = append(subcmds, parts[1])
|
||||
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||
subcmds = append(subcmds, child)
|
||||
}
|
||||
}
|
||||
if len(subcmds) > 0 {
|
||||
@@ -279,8 +279,8 @@ _local_ai() {
|
||||
// Check for subcommands
|
||||
subcmds := []commandInfo{}
|
||||
for _, sub := range cmds {
|
||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
||||
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||
subcmds = append(subcmds, sub)
|
||||
}
|
||||
}
|
||||
@@ -289,11 +289,11 @@ _local_ai() {
|
||||
sb.WriteString(" local -a subcmds\n")
|
||||
sb.WriteString(" subcmds=(\n")
|
||||
for _, sub := range subcmds {
|
||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
||||
_, child, _ := strings.Cut(sub.fullName, " ")
|
||||
help := strings.ReplaceAll(sub.help, "'", "'\\''")
|
||||
help = strings.ReplaceAll(help, "[", "\\[")
|
||||
help = strings.ReplaceAll(help, "]", "\\]")
|
||||
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help))
|
||||
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", child, help))
|
||||
}
|
||||
sb.WriteString(" )\n")
|
||||
sb.WriteString(" _describe -t commands 'subcommands' subcmds\n")
|
||||
@@ -372,10 +372,10 @@ func generateFishCompletion(app *kong.Application) string {
|
||||
|
||||
// Subcommands
|
||||
for _, sub := range cmds {
|
||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
||||
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||
help := strings.ReplaceAll(sub.help, "'", "\\'")
|
||||
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help))
|
||||
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, child, help))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
func getTestApp() *kong.Application {
|
||||
var testCLI struct {
|
||||
Run struct{} `cmd:"" help:"Run the server"`
|
||||
Models struct {
|
||||
Run struct{} `cmd:"" help:"Run the server"`
|
||||
Models struct {
|
||||
List struct{} `cmd:"" help:"List models"`
|
||||
Install struct{} `cmd:"" help:"Install a model"`
|
||||
} `cmd:"" help:"Manage models"`
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/startup"
|
||||
@@ -80,7 +80,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
galleryService := services.NewGalleryService(&config.ApplicationConfig{
|
||||
galleryService := galleryop.NewGalleryService(&config.ApplicationConfig{
|
||||
SystemState: systemState,
|
||||
}, model.NewModelLoader(systemState))
|
||||
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
|
||||
|
||||
@@ -44,9 +44,9 @@ type RunCMD struct {
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
@@ -100,7 +100,7 @@ type RunCMD struct {
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
// Agent Pool (LocalAGI)
|
||||
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
|
||||
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
|
||||
AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"`
|
||||
AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"`
|
||||
AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"`
|
||||
@@ -109,17 +109,17 @@ type RunCMD struct {
|
||||
AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"`
|
||||
AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"`
|
||||
AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"`
|
||||
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
|
||||
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
|
||||
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
|
||||
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
|
||||
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
|
||||
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
|
||||
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
|
||||
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
|
||||
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
|
||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
|
||||
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
|
||||
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
|
||||
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
|
||||
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
|
||||
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
|
||||
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
|
||||
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
|
||||
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
|
||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||
|
||||
// Authentication
|
||||
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
||||
@@ -136,6 +136,18 @@ type RunCMD struct {
|
||||
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
|
||||
@@ -210,6 +222,38 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
}),
|
||||
}
|
||||
|
||||
// Distributed mode
|
||||
if r.Distributed {
|
||||
opts = append(opts, config.EnableDistributed)
|
||||
}
|
||||
if r.InstanceID != "" {
|
||||
opts = append(opts, config.WithDistributedInstanceID(r.InstanceID))
|
||||
}
|
||||
if r.NatsURL != "" {
|
||||
opts = append(opts, config.WithNatsURL(r.NatsURL))
|
||||
}
|
||||
if r.StorageURL != "" {
|
||||
opts = append(opts, config.WithStorageURL(r.StorageURL))
|
||||
}
|
||||
if r.StorageBucket != "" {
|
||||
opts = append(opts, config.WithStorageBucket(r.StorageBucket))
|
||||
}
|
||||
if r.StorageRegion != "" {
|
||||
opts = append(opts, config.WithStorageRegion(r.StorageRegion))
|
||||
}
|
||||
if r.StorageAccessKey != "" {
|
||||
opts = append(opts, config.WithStorageAccessKey(r.StorageAccessKey))
|
||||
}
|
||||
if r.StorageSecretKey != "" {
|
||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
}
|
||||
@@ -218,10 +262,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.DisableRuntimeSettings)
|
||||
}
|
||||
|
||||
if r.EnableTracing {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
|
||||
if r.EnableTracing {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
@@ -479,6 +519,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if err := app.ModelLoader().StopAllGRPC(); err != nil {
|
||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||
}
|
||||
// Clean up distributed services (idempotent — safe if already called)
|
||||
if d := app.Distributed(); d != nil {
|
||||
d.Shutdown()
|
||||
}
|
||||
})
|
||||
|
||||
// Start the agent pool after the HTTP server is listening, because
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/format"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -80,7 +79,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
switch t.ResponseFormat {
|
||||
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
|
||||
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
|
||||
fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat))
|
||||
case schema.TranscriptionResponseFormatJson:
|
||||
tr.Segments = nil
|
||||
fallthrough
|
||||
|
||||
897
core/cli/worker.go
Normal file
897
core/cli/worker.go
Normal file
@@ -0,0 +1,897 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// isPathAllowed reports whether path resolves to a location inside (or equal
// to) one of allowedDirs. Both the candidate path and each allowed directory
// are made absolute and symlink-resolved where possible, so a symlinked
// allowed directory (e.g. /tmp -> /private/tmp) still matches paths that
// resolve underneath it — the original code only resolved the candidate,
// causing false denials. Paths that do not exist yet fall back to their
// absolute form.
func isPathAllowed(path string, allowedDirs []string) bool {
	resolved, ok := resolvePathBestEffort(path)
	if !ok {
		return false
	}
	for _, dir := range allowedDirs {
		allowed, ok := resolvePathBestEffort(dir)
		if !ok {
			continue
		}
		// Match the directory itself or anything strictly below it; the
		// separator guard prevents "/foo" from matching "/foobar".
		if resolved == allowed || strings.HasPrefix(resolved, allowed+string(filepath.Separator)) {
			return true
		}
	}
	return false
}

// resolvePathBestEffort returns the absolute, symlink-resolved form of p.
// If p cannot be made absolute it reports false; if p does not exist yet,
// the plain absolute path is returned.
func resolvePathBestEffort(p string) (string, bool) {
	abs, err := filepath.Abs(p)
	if err != nil {
		return "", false
	}
	if real, err := filepath.EvalSymlinks(abs); err == nil {
		return real, true
	}
	// Path may not exist yet; use the absolute path.
	return abs, true
}
|
||||
|
||||
// WorkerCMD starts a generic worker process for distributed mode.
// Workers are backend-agnostic — they wait for backend.install NATS events
// from the SmartRouter to install and start the required backend.
//
// NATS is required. The worker acts as a process supervisor:
//   - Receives backend.install → installs backend from gallery, starts gRPC process, replies success
//   - Receives backend.stop → stops the gRPC process
//   - Receives stop → full shutdown (deregister + exit)
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
type WorkerCMD struct {
	// gRPC bind address and local filesystem layout for backends/models.
	Addr               string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
	BackendsPath       string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
	BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
	BackendGalleries   string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
	ModelsPath         string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`

	// HTTP file transfer
	HTTPAddr          string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
	AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`

	// Registration (required)
	AdvertiseAddr     string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`

	// NATS (required)
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`

	// S3 storage for distributed file transfer
	StorageURL       string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
	StorageBucket    string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"`
	StorageRegion    string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"`
	StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
	StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
}
|
||||
|
||||
// Run is the worker entrypoint. It builds the system state and model loader,
// registers the node with the frontend (with retry), starts the HTTP file
// transfer server and the heartbeat goroutine, connects to NATS and wires up
// the backend supervisor, then blocks until a shutdown signal arrives — either
// an OS SIGINT/SIGTERM or a NATS "stop" event forwarded through sigCh. On
// shutdown it deregisters, stops all backend processes, and closes the HTTP
// server.
//
// NOTE(review): the ctx parameter is not referenced in this body — internal
// contexts are created from context.Background(); confirm this is intended.
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting worker", "addr", cmd.Addr)

	systemState, err := system.GetSystemState(
		system.WithModelPath(cmd.ModelsPath),
		system.WithBackendPath(cmd.BackendsPath),
		system.WithBackendSystemPath(cmd.BackendsSystemPath),
	)
	if err != nil {
		return fmt.Errorf("getting system state: %w", err)
	}

	ml := model.NewModelLoader(systemState)
	ml.SetBackendLoggingEnabled(true)

	// Register already-installed backends
	gallery.RegisterBackends(systemState, ml)

	// Parse galleries config; a parse failure is non-fatal — the worker can
	// still serve backends that are already installed locally.
	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil {
		xlog.Warn("Failed to parse backend galleries", "error", err)
	}

	// Self-registration with frontend (with retry)
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}

	registrationBody := cmd.registrationBody()
	nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("failed to register with frontend: %w", err)
	}

	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
	// An unparsable interval falls through with the zero Duration and is
	// replaced by the 10s default below via cmp.Or.
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()

	// Start HTTP file transfer server; staging/ and data/ live next to the
	// models directory.
	httpAddr := cmd.resolveHTTPAddr()
	stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging")
	dataDir := filepath.Join(cmd.ModelsPath, "..", "data")
	httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs())
	if err != nil {
		return fmt.Errorf("starting HTTP file transfer server: %w", err)
	}

	// Connect to NATS; on failure the already-started HTTP server must be
	// shut down explicitly since we return before any deferred cleanup exists.
	xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL))
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		nodes.ShutdownFileTransferServer(httpServer)
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()

	// Start heartbeat goroutine (after NATS is connected so IsConnected check works)
	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()
		for {
			select {
			case <-shutdownCtx.Done():
				return
			case <-ticker.C:
				if !natsClient.IsConnected() {
					xlog.Warn("Skipping heartbeat: NATS disconnected")
					continue
				}
				body := cmd.heartbeatBody()
				if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil {
					xlog.Warn("Heartbeat failed", "error", err)
				}
			}
		}
	}()

	// Process supervisor — manages multiple backend gRPC processes on different ports
	basePort := 50051
	if cmd.Addr != "" {
		// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
		if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
			if p, err := strconv.Atoi(portStr); err == nil {
				basePort = p
			}
		}
	}
	// Buffered so NATS stop handler can send without blocking
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	// Set the registration token once before any backends are started, so
	// child gRPC processes inherit it via the environment.
	if cmd.RegistrationToken != "" {
		os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken)
	}

	supervisor := &backendSupervisor{
		cmd:         cmd,
		ml:          ml,
		systemState: systemState,
		galleries:   galleries,
		nodeID:      nodeID,
		nats:        natsClient,
		sigCh:       sigCh,
		processes:   make(map[string]*backendProcess),
		nextPort:    basePort,
	}
	supervisor.subscribeLifecycleEvents()

	// Subscribe to file staging NATS subjects if S3 is configured
	if cmd.StorageURL != "" {
		if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil {
			xlog.Error("Failed to subscribe to file staging subjects", "error", err)
		}
	}

	xlog.Info("Worker ready, waiting for backend.install events")
	// Block here until an OS signal or a NATS "stop" event (forwarded via
	// sigCh by subscribeLifecycleEvents) requests shutdown.
	<-sigCh

	xlog.Info("Shutting down worker")
	shutdownCancel() // stop heartbeat loop immediately
	regClient.GracefulDeregister(nodeID)
	supervisor.stopAllBackends()
	nodes.ShutdownFileTransferServer(httpServer)
	return nil
}
|
||||
|
||||
// subscribeFileStaging subscribes to NATS file staging subjects for this node.
//
// Four request-reply subjects are registered, all scoped to nodeID:
//   - files.ensure  — download an S3 key to a local path and return it
//   - files.stage   — upload a local path to S3 under a given key
//   - files.temp    — allocate a local temp file and return its path
//   - files.listdir — list files under a key prefix as relative paths
//
// Returns an error only when the S3 store or file manager cannot be built;
// subscription registration itself is not error-checked here.
func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error {
	// Create FileManager with same S3 config as the frontend
	// TODO: propagate a caller-provided context once WorkerCMD carries one
	s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
		Endpoint:        cmd.StorageURL,
		Region:          cmd.StorageRegion,
		Bucket:          cmd.StorageBucket,
		AccessKeyID:     cmd.StorageAccessKey,
		SecretAccessKey: cmd.StorageSecretKey,
		ForcePathStyle:  true,
	})
	if err != nil {
		return fmt.Errorf("initializing S3 store: %w", err)
	}

	// Cache directory sits next to the models directory.
	cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache")
	fm, err := storage.NewFileManager(s3Store, cacheDir)
	if err != nil {
		return fmt.Errorf("initializing file manager: %w", err)
	}

	// Subscribe: files.ensure — download S3 key to local, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			Key string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}

		localPath, err := fm.Download(context.Background(), req.Key)
		if err != nil {
			xlog.Error("File ensure failed", "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}

		xlog.Debug("File ensured locally", "key", req.Key, "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})

	// Subscribe: files.stage — upload local path to S3, reply with key
	natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			LocalPath string `json:"local_path"`
			Key       string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}

		// Only allow uploads from the cache dir and (if set) the models dir,
		// to keep a crafted local_path from exfiltrating arbitrary files.
		allowedDirs := []string{cacheDir}
		if cmd.ModelsPath != "" {
			allowedDirs = append(allowedDirs, cmd.ModelsPath)
		}
		if !isPathAllowed(req.LocalPath, allowedDirs) {
			replyJSON(reply, map[string]string{"error": "path outside allowed directories"})
			return
		}

		if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil {
			xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}

		xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key)
		replyJSON(reply, map[string]string{"key": req.Key})
	})

	// Subscribe: files.temp — allocate temp file, reply with local path
	// NOTE(review): temp files allocated here are never removed by this
	// handler — confirm cleanup happens elsewhere (e.g. on staging completion).
	natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) {
		tmpDir := filepath.Join(cacheDir, "staging-tmp")
		if err := os.MkdirAll(tmpDir, 0750); err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)})
			return
		}

		f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp")
		if err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)})
			return
		}
		// Only the name is needed; the caller writes to it later.
		localPath := f.Name()
		f.Close()

		xlog.Debug("Allocated temp file", "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})

	// Subscribe: files.listdir — list files in a local directory, reply with relative paths
	natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			KeyPrefix string `json:"key_prefix"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]any{"error": "invalid request"})
			return
		}

		// Resolve key prefix to local directory: model-prefixed keys map into
		// the models dir, data-prefixed keys into ../data, else the cache dir.
		dirPath := filepath.Join(cacheDir, req.KeyPrefix)
		if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" {
			dirPath = filepath.Join(cmd.ModelsPath, rel)
		} else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok {
			dirPath = filepath.Join(cacheDir, "..", "data", rel)
		}

		// Sanitize to prevent directory traversal via crafted key_prefix
		dirPath = filepath.Clean(dirPath)
		cleanCache := filepath.Clean(cacheDir)
		cleanModels := filepath.Clean(cmd.ModelsPath)
		cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data"))
		// cleanModels == "." means ModelsPath was empty, so the models-dir
		// prefix checks are skipped in that case.
		if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) ||
			dirPath == cleanCache ||
			(cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) ||
			dirPath == cleanModels ||
			strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) ||
			dirPath == cleanData) {
			replyJSON(reply, map[string]any{"error": "invalid key prefix"})
			return
		}

		// Walk errors are deliberately skipped: a missing or unreadable
		// directory simply yields an empty listing.
		var files []string
		filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				return nil
			}
			if !d.IsDir() {
				rel, err := filepath.Rel(dirPath, path)
				if err == nil {
					files = append(files, rel)
				}
			}
			return nil
		})

		xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files))
		replyJSON(reply, map[string]any{"files": files})
	})

	xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID)
	return nil
}
|
||||
|
||||
// replyJSON marshals v to JSON and calls the reply function.
|
||||
func replyJSON(reply func([]byte), v any) {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal NATS reply", "error", err)
|
||||
data = []byte(`{"error":"internal marshal error"}`)
|
||||
}
|
||||
reply(data)
|
||||
}
|
||||
|
||||
// backendProcess represents a single gRPC backend process.
type backendProcess struct {
	proc    *process.Process // supervised OS process; checked for nil before use in stopBackend
	backend string           // backend name this process serves
	addr    string           // gRPC address (host:port)
}
|
||||
|
||||
// backendSupervisor manages multiple backend gRPC processes on different ports.
// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port.
type backendSupervisor struct {
	cmd         *WorkerCMD
	ml          *model.ModelLoader
	systemState *system.SystemState
	galleries   []config.Gallery // backend galleries parsed at startup; may be overridden per-request
	nodeID      string           // node identifier assigned by the frontend at registration
	nats        messaging.MessagingClient
	sigCh       chan<- os.Signal // send shutdown signal instead of os.Exit

	// mu guards processes, nextPort, and freePorts.
	mu        sync.Mutex
	processes map[string]*backendProcess // key: backend name
	nextPort  int                        // next available port for new backends
	freePorts []int                      // ports freed by stopBackend, reused before nextPort
}
|
||||
|
||||
// startBackend starts a gRPC backend process on a dynamically allocated port.
// Returns the gRPC address (a 127.0.0.1 client address; the process itself
// binds 0.0.0.0 on the same port). If the backend is already tracked and
// alive, its existing address is returned without starting anything.
func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) {
	s.mu.Lock()

	// Already running?
	if bp, ok := s.processes[backend]; ok {
		if bp.proc != nil && bp.proc.IsAlive() {
			s.mu.Unlock()
			return bp.addr, nil
		}
		// Process died — clean up and restart
		// NOTE(review): the dead process's port is not pushed onto freePorts
		// here, so it is never recycled — confirm whether that is intentional
		// (e.g. to avoid rebinding a port a zombie may still hold).
		xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend)
		delete(s.processes, backend)
	}

	// Allocate port — recycle freed ports first, then grow upward from basePort
	var port int
	if len(s.freePorts) > 0 {
		port = s.freePorts[len(s.freePorts)-1]
		s.freePorts = s.freePorts[:len(s.freePorts)-1]
	} else {
		port = s.nextPort
		s.nextPort++
	}
	bindAddr := fmt.Sprintf("0.0.0.0:%d", port)
	clientAddr := fmt.Sprintf("127.0.0.1:%d", port)

	proc, err := s.ml.StartProcess(backendPath, backend, bindAddr)
	if err != nil {
		s.mu.Unlock()
		return "", fmt.Errorf("starting backend process: %w", err)
	}

	s.processes[backend] = &backendProcess{
		proc:    proc,
		backend: backend,
		addr:    clientAddr,
	}
	xlog.Info("Backend process started", "backend", backend, "addr", clientAddr)

	// Capture reference before unlocking for race-safe health check.
	// Another goroutine could stopBackend and recycle the port while we poll.
	bp := s.processes[backend]
	s.mu.Unlock()

	// Wait for the gRPC server to be ready: up to 20 polls, 200ms apart
	// (~4s), each with a 2s health-check timeout.
	client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken)
	for range 20 {
		time.Sleep(200 * time.Millisecond)
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		if ok, _ := client.HealthCheck(ctx); ok {
			cancel()
			// Verify the process wasn't stopped/replaced while health-checking
			s.mu.Lock()
			currentBP, exists := s.processes[backend]
			s.mu.Unlock()
			if !exists || currentBP != bp {
				return "", fmt.Errorf("backend %s was stopped during startup", backend)
			}
			xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr)
			return clientAddr, nil
		}
		cancel()
	}

	// Deliberately best-effort: the address is returned even when readiness
	// was never observed, leaving it to the caller's gRPC calls to fail.
	xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr)
	return clientAddr, nil
}
|
||||
|
||||
// stopBackend stops a specific backend's gRPC process.
|
||||
func (s *backendSupervisor) stopBackend(backend string) {
|
||||
s.mu.Lock()
|
||||
bp, ok := s.processes[backend]
|
||||
if !ok || bp.proc == nil {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Clean up map and recycle port while holding lock
|
||||
delete(s.processes, backend)
|
||||
if _, portStr, err := net.SplitHostPort(bp.addr); err == nil {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
s.freePorts = append(s.freePorts, p)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Network I/O outside the lock
|
||||
client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
|
||||
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
||||
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
|
||||
if err := freeFunc.Free(context.Background()); err != nil {
|
||||
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
|
||||
if err := bp.proc.Stop(); err != nil {
|
||||
xlog.Error("Error stopping backend process", "backend", backend, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stopAllBackends stops all running backend processes.
|
||||
func (s *backendSupervisor) stopAllBackends() {
|
||||
s.mu.Lock()
|
||||
backends := slices.Collect(maps.Keys(s.processes))
|
||||
s.mu.Unlock()
|
||||
|
||||
for _, b := range backends {
|
||||
s.stopBackend(b)
|
||||
}
|
||||
}
|
||||
|
||||
// isRunning returns whether a specific backend process is currently running.
|
||||
func (s *backendSupervisor) isRunning(backend string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
bp, ok := s.processes[backend]
|
||||
return ok && bp.proc != nil && bp.proc.IsAlive()
|
||||
}
|
||||
|
||||
// getAddr returns the gRPC address for a running backend, or empty string.
|
||||
func (s *backendSupervisor) getAddr(backend string) string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if bp, ok := s.processes[backend]; ok {
|
||||
return bp.addr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// installBackend handles the backend.install flow:
|
||||
// 1. If already running for this model, return existing address
|
||||
// 2. Install backend from gallery (if not already installed)
|
||||
// 3. Find backend binary
|
||||
// 4. Start gRPC process on a new port
|
||||
// Returns the gRPC address of the backend process.
|
||||
func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) {
|
||||
// Process key: use ModelID if provided (per-model process), else backend name
|
||||
processKey := req.ModelID
|
||||
if processKey == "" {
|
||||
processKey = req.Backend
|
||||
}
|
||||
|
||||
// If already running for this model, return its address
|
||||
if addr := s.getAddr(processKey); addr != "" {
|
||||
xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr)
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Parse galleries from request (override local config if provided)
|
||||
galleries := s.galleries
|
||||
if req.BackendGalleries != "" {
|
||||
var reqGalleries []config.Gallery
|
||||
if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil {
|
||||
galleries = reqGalleries
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find the backend binary
|
||||
backendPath := s.findBackend(req.Backend)
|
||||
if backendPath == "" {
|
||||
// Backend not found locally — try auto-installing from gallery
|
||||
xlog.Info("Backend not found locally, attempting gallery install", "backend", req.Backend)
|
||||
if err := gallery.InstallBackendFromGallery(
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, false,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
// Re-register after install and retry
|
||||
gallery.RegisterBackends(s.systemState, s.ml)
|
||||
backendPath = s.findBackend(req.Backend)
|
||||
}
|
||||
|
||||
if backendPath == "" {
|
||||
return "", fmt.Errorf("backend %q not found after install attempt", req.Backend)
|
||||
}
|
||||
|
||||
xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey)
|
||||
|
||||
// Start the gRPC process on a new port (keyed by model, not just backend)
|
||||
return s.startBackend(processKey, backendPath)
|
||||
}
|
||||
|
||||
// findBackend looks for the backend binary in the backends path and system path.
|
||||
func (s *backendSupervisor) findBackend(backend string) string {
|
||||
candidates := []string{
|
||||
filepath.Join(s.cmd.BackendsPath, backend),
|
||||
filepath.Join(s.cmd.BackendsPath, backend, backend),
|
||||
filepath.Join(s.cmd.BackendsSystemPath, backend),
|
||||
filepath.Join(s.cmd.BackendsSystemPath, backend, backend),
|
||||
}
|
||||
if uri := s.ml.GetExternalBackend(backend); uri != "" {
|
||||
if fi, err := os.Stat(uri); err == nil && !fi.IsDir() {
|
||||
return uri
|
||||
}
|
||||
}
|
||||
for _, path := range candidates {
|
||||
fi, err := os.Stat(path)
|
||||
if err == nil && !fi.IsDir() {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// subscribeLifecycleEvents subscribes to NATS backend lifecycle events.
//
// Request-reply subjects (each replies with a JSON payload via replyJSON):
// backend.install, backend.delete, backend.list, model.unload, model.delete.
// Fire-and-forget subjects: backend.stop and stop. All subjects are scoped to
// this node's ID.
func (s *backendSupervisor) subscribeLifecycleEvents() {
	// backend.install — install backend + start gRPC process (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.install event")
		var req messaging.BackendInstallRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}

		addr, err := s.installBackend(req)
		if err != nil {
			xlog.Error("Failed to install backend via NATS", "error", err)
			resp := messaging.BackendInstallReply{Success: false, Error: err.Error()}
			replyJSON(reply, resp)
			return
		}

		// Return the gRPC address so the router knows which port to use
		advertiseAddr := addr
		if s.cmd.AdvertiseAddr != "" {
			// Replace 0.0.0.0 with the advertised host but keep the dynamic port
			_, port, _ := net.SplitHostPort(addr)
			advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
			advertiseAddr = net.JoinHostPort(advertiseHost, port)
		}
		resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
		replyJSON(reply, resp)
	})

	// backend.stop — stop a specific backend process
	s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), func(data []byte) {
		// Try to parse backend name from payload; if empty, stop all
		var req struct {
			Backend string `json:"backend"`
		}
		if json.Unmarshal(data, &req) == nil && req.Backend != "" {
			xlog.Info("Received NATS backend.stop event", "backend", req.Backend)
			s.stopBackend(req.Backend)
		} else {
			xlog.Info("Received NATS backend.stop event (all)")
			s.stopAllBackends()
		}
	})

	// backend.delete — stop backend + delete files (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.delete event")
		var req messaging.BackendDeleteRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}

		// Stop if running this backend
		if s.isRunning(req.Backend) {
			s.stopBackend(req.Backend)
		}

		// Delete the backend files
		if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil {
			xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err)
			resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()}
			replyJSON(reply, resp)
			return
		}

		// Re-register backends after deletion
		gallery.RegisterBackends(s.systemState, s.ml)

		resp := messaging.BackendDeleteReply{Success: true}
		replyJSON(reply, resp)
	})

	// backend.list — list installed backends (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.list event")
		backends, err := gallery.ListSystemBackends(s.systemState)
		if err != nil {
			resp := messaging.BackendListReply{Error: err.Error()}
			replyJSON(reply, resp)
			return
		}

		var infos []messaging.NodeBackendInfo
		for name, b := range backends {
			info := messaging.NodeBackendInfo{
				Name:     name,
				IsSystem: b.IsSystem,
				IsMeta:   b.IsMeta,
			}
			if b.Metadata != nil {
				info.InstalledAt = b.Metadata.InstalledAt
				info.GalleryURL = b.Metadata.GalleryURL
			}
			infos = append(infos, info)
		}

		resp := messaging.BackendListReply{Backends: infos}
		replyJSON(reply, resp)
	})

	// model.unload — call gRPC Free() to release GPU memory (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS model.unload event")
		var req messaging.ModelUnloadRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}

		// Find the backend address for this model's backend type
		// The request includes an Address field if the router knows which process to target
		targetAddr := req.Address
		if targetAddr == "" {
			// Fallback: pick an arbitrary running backend — the loop breaks
			// after the first (randomly ordered) map entry.
			// NOTE(review): only ONE process is targeted here, not all of
			// them — confirm whether "try all running backends" was intended.
			s.mu.Lock()
			for _, bp := range s.processes {
				targetAddr = bp.addr
				break
			}
			s.mu.Unlock()
		}

		if targetAddr != "" {
			// Best-effort gRPC Free()
			client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
			if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
				if err := freeFunc.Free(context.Background()); err != nil {
					xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
				}
			}
		}

		// Success is reported even when no process was targeted or Free()
		// failed — the unload is strictly best-effort.
		resp := messaging.ModelUnloadReply{Success: true}
		replyJSON(reply, resp)
	})

	// model.delete — remove model files from disk (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS model.delete event")
		var req messaging.ModelDeleteRequest
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"})
			return
		}

		if err := gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil {
			xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err)
			replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()})
			return
		}

		replyJSON(reply, messaging.ModelDeleteReply{Success: true})
	})

	// stop — trigger the normal shutdown path via sigCh so deferred cleanup runs
	s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) {
		xlog.Info("Received NATS stop event — signaling shutdown")
		// Non-blocking send: sigCh is buffered with capacity 1; a duplicate
		// stop while shutdown is already pending is dropped.
		select {
		case s.sigCh <- syscall.SIGTERM:
		default:
			xlog.Debug("Shutdown already signaled, ignoring duplicate stop")
		}
	})
}
|
||||
|
||||
// advertiseAddr returns the address the frontend should use to reach this node.
|
||||
func (cmd *WorkerCMD) advertiseAddr() string {
|
||||
if cmd.AdvertiseAddr != "" {
|
||||
return cmd.AdvertiseAddr
|
||||
}
|
||||
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||
if ok && (host == "0.0.0.0" || host == "") {
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname + ":" + port
|
||||
}
|
||||
}
|
||||
return cmd.Addr
|
||||
}
|
||||
|
||||
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
|
||||
// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports
|
||||
// which grow upward from basePort.
|
||||
func (cmd *WorkerCMD) resolveHTTPAddr() string {
|
||||
if cmd.HTTPAddr != "" {
|
||||
return cmd.HTTPAddr
|
||||
}
|
||||
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||
if !ok {
|
||||
return "0.0.0.0:50050"
|
||||
}
|
||||
portNum, _ := strconv.Atoi(port)
|
||||
return fmt.Sprintf("%s:%d", host, portNum-1)
|
||||
}
|
||||
|
||||
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
|
||||
// this node for file transfer.
|
||||
func (cmd *WorkerCMD) advertiseHTTPAddr() string {
|
||||
if cmd.AdvertiseHTTPAddr != "" {
|
||||
return cmd.AdvertiseHTTPAddr
|
||||
}
|
||||
httpAddr := cmd.resolveHTTPAddr()
|
||||
host, port, ok := strings.Cut(httpAddr, ":")
|
||||
if ok && (host == "0.0.0.0" || host == "") {
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname + ":" + port
|
||||
}
|
||||
}
|
||||
return httpAddr
|
||||
}
|
||||
|
||||
// registrationBody builds the JSON body for node registration.
|
||||
func (cmd *WorkerCMD) registrationBody() map[string]any {
|
||||
nodeName := cmd.NodeName
|
||||
if nodeName == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
nodeName = fmt.Sprintf("node-%d", os.Getpid())
|
||||
} else {
|
||||
nodeName = hostname
|
||||
}
|
||||
}
|
||||
|
||||
// Detect GPU info for VRAM-aware scheduling
|
||||
totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
|
||||
gpuVendor, _ := xsysinfo.DetectGPUVendor()
|
||||
|
||||
body := map[string]any{
|
||||
"name": nodeName,
|
||||
"address": cmd.advertiseAddr(),
|
||||
"http_address": cmd.advertiseHTTPAddr(),
|
||||
"total_vram": totalVRAM,
|
||||
"available_vram": totalVRAM, // initially all VRAM is available
|
||||
"gpu_vendor": gpuVendor,
|
||||
}
|
||||
|
||||
// If no GPU detected, report system RAM so the scheduler/UI has capacity info
|
||||
if totalVRAM == 0 {
|
||||
if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
|
||||
body["total_ram"] = ramInfo.Total
|
||||
body["available_ram"] = ramInfo.Available
|
||||
}
|
||||
}
|
||||
if cmd.RegistrationToken != "" {
|
||||
body["token"] = cmd.RegistrationToken
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads.
|
||||
func (cmd *WorkerCMD) heartbeatBody() map[string]any {
|
||||
var availVRAM uint64
|
||||
aggregate := xsysinfo.GetGPUAggregateInfo()
|
||||
if aggregate.TotalVRAM > 0 {
|
||||
availVRAM = aggregate.FreeVRAM
|
||||
} else {
|
||||
// Fallback: report total as available (no usage tracking possible)
|
||||
availVRAM, _ = xsysinfo.TotalAvailableVRAM()
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"available_vram": availVRAM,
|
||||
}
|
||||
|
||||
// If no GPU, report system RAM usage instead
|
||||
if aggregate.TotalVRAM == 0 {
|
||||
if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
|
||||
body["available_ram"] = ramInfo.Available
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
272
core/cli/workerregistry/client.go
Normal file
272
core/cli/workerregistry/client.go
Normal file
@@ -0,0 +1,272 @@
|
||||
// Package workerregistry provides a shared HTTP client for worker node
|
||||
// registration, heartbeating, draining, and deregistration against a
|
||||
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
|
||||
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// RegistrationClient talks to the frontend's /api/node/* endpoints.
// The zero value is usable once FrontendURL is set; the underlying
// HTTP client is created lazily on first use and shared by all calls.
type RegistrationClient struct {
	FrontendURL       string        // base URL of the frontend, e.g. "http://host:8080"
	RegistrationToken string        // optional bearer token added to every request
	HTTPTimeout       time.Duration // used for registration calls; defaults to 10s
	client            *http.Client  // lazily initialized; always access via httpClient()
	clientOnce        sync.Once     // guards one-time creation of client
}
|
||||
|
||||
// httpTimeout returns the configured timeout or a sensible default.
|
||||
func (c *RegistrationClient) httpTimeout() time.Duration {
|
||||
if c.HTTPTimeout > 0 {
|
||||
return c.HTTPTimeout
|
||||
}
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
// httpClient returns the shared HTTP client, initializing it on first use.
// sync.Once guarantees exactly one client is built even under concurrent
// callers; the timeout value is captured at first use and never changes.
func (c *RegistrationClient) httpClient() *http.Client {
	c.clientOnce.Do(func() {
		c.client = &http.Client{Timeout: c.httpTimeout()}
	})
	return c.client
}
|
||||
|
||||
// baseURL returns FrontendURL with any trailing slash stripped.
// TrimRight removes all trailing slashes, so "http://host//" yields
// "http://host", keeping path concatenation below well-formed.
func (c *RegistrationClient) baseURL() string {
	return strings.TrimRight(c.FrontendURL, "/")
}
|
||||
|
||||
// setAuth adds an Authorization header when a token is configured.
// With no configured token the request is sent unauthenticated.
func (c *RegistrationClient) setAuth(req *http.Request) {
	if c.RegistrationToken != "" {
		req.Header.Set("Authorization", "Bearer "+c.RegistrationToken)
	}
}
|
||||
|
||||
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
	ID       string `json:"id"`                  // node ID assigned by the frontend
	APIToken string `json:"api_token,omitempty"` // auto-provisioned API token, if any
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// (optionally) an auto-provisioned API token.
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/register"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return result.ID, result.APIToken, nil
|
||||
}
|
||||
|
||||
// RegisterWithRetry retries registration with exponential backoff.
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
var nodeID, apiToken string
|
||||
var err error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
nodeID, apiToken, err = c.Register(ctx, body)
|
||||
if err == nil {
|
||||
return nodeID, apiToken, nil
|
||||
}
|
||||
if attempt == maxRetries {
|
||||
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", "", ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
}
|
||||
return nodeID, apiToken, err
|
||||
}
|
||||
|
||||
// Heartbeat sends a single heartbeat POST with the given body.
|
||||
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating heartbeat request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled.
|
||||
// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats).
|
||||
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
body := bodyFn()
|
||||
if err := c.Heartbeat(ctx, nodeID, body); err != nil {
|
||||
xlog.Warn("Heartbeat failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drain sets the node to draining status via POST /api/node/:id/drain.
|
||||
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
|
||||
url := c.baseURL() + "/api/node/" + nodeID + "/drain"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating drain request: %w", err)
|
||||
}
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("drain failed with status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitForDrain polls GET /api/node/:id/models until all models report 0
|
||||
// in-flight requests, or until timeout elapses.
|
||||
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
|
||||
url := c.baseURL() + "/api/node/" + nodeID + "/models"
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to create drain poll request", "error", err)
|
||||
return
|
||||
}
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("Drain poll failed, will retry", "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
xlog.Warn("Drain wait cancelled")
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
continue
|
||||
}
|
||||
var models []struct {
|
||||
InFlight int `json:"in_flight"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&models)
|
||||
resp.Body.Close()
|
||||
|
||||
total := 0
|
||||
for _, m := range models {
|
||||
total += m.InFlight
|
||||
}
|
||||
if total == 0 {
|
||||
xlog.Info("All in-flight requests drained")
|
||||
return
|
||||
}
|
||||
xlog.Info("Waiting for in-flight requests", "count", total)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
xlog.Warn("Drain wait cancelled")
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
xlog.Warn("Drain timeout reached, proceeding with shutdown")
|
||||
}
|
||||
|
||||
// Deregister marks the node as offline via POST /api/node/:id/deregister.
|
||||
// The node row is preserved in the database so re-registration restores
|
||||
// approval status.
|
||||
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
|
||||
url := c.baseURL() + "/api/node/" + nodeID + "/deregister"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating deregister request: %w", err)
|
||||
}
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("deregistration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GracefulDeregister performs drain -> wait -> deregister in sequence.
|
||||
// This is the standard shutdown sequence for backend workers.
|
||||
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
|
||||
if c.FrontendURL == "" || nodeID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := c.Drain(ctx, nodeID); err != nil {
|
||||
xlog.Warn("Failed to set drain status", "error", err)
|
||||
} else {
|
||||
c.WaitForDrain(ctx, nodeID, 30*time.Second)
|
||||
}
|
||||
|
||||
if err := c.Deregister(ctx, nodeID); err != nil {
|
||||
xlog.Error("Failed to deregister", "error", err)
|
||||
} else {
|
||||
xlog.Info("Deregistered from frontend")
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
|
||||
}
|
||||
|
||||
// Helper function to perform a request without expecting a response body
|
||||
func (c *StoreClient) doRequest(path string, data interface{}) error {
|
||||
func (c *StoreClient) doRequest(path string, data any) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -120,7 +120,7 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
|
||||
}
|
||||
|
||||
// Helper function to perform a request and parse the response body
|
||||
func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) {
|
||||
func (c *StoreClient) doRequestWithResponse(path string, data any) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -83,8 +83,8 @@ type ApplicationConfig struct {
|
||||
|
||||
APIAddress string
|
||||
|
||||
LlamaCPPTunnelCallback func(tunnels []string)
|
||||
MLXTunnelCallback func(tunnels []string)
|
||||
LlamaCPPTunnelCallback func(tunnels []string)
|
||||
MLXTunnelCallback func(tunnels []string)
|
||||
|
||||
DisableRuntimeSettings bool
|
||||
|
||||
@@ -99,47 +99,50 @@ type ApplicationConfig struct {
|
||||
|
||||
// Authentication & Authorization
|
||||
Auth AuthConfig
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
}
|
||||
|
||||
// AuthConfig holds configuration for user authentication and authorization.
|
||||
type AuthConfig struct {
|
||||
Enabled bool
|
||||
DatabaseURL string // "postgres://..." or file path for SQLite
|
||||
GitHubClientID string
|
||||
GitHubClientSecret string
|
||||
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
|
||||
AdminEmail string // auto-promote to admin on login
|
||||
RegistrationMode string // "open", "approval" (default when empty), "invite"
|
||||
DisableLocalAuth bool // disable local email/password registration and login
|
||||
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
|
||||
Enabled bool
|
||||
DatabaseURL string // "postgres://..." or file path for SQLite
|
||||
GitHubClientID string
|
||||
GitHubClientSecret string
|
||||
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
|
||||
AdminEmail string // auto-promote to admin on login
|
||||
RegistrationMode string // "open", "approval" (default when empty), "invite"
|
||||
DisableLocalAuth bool // disable local email/password registration and login
|
||||
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
|
||||
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
|
||||
}
|
||||
|
||||
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
|
||||
type AgentPoolConfig struct {
|
||||
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
|
||||
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
|
||||
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
|
||||
APIKey string // default: first API key from LocalAI config
|
||||
DefaultModel string
|
||||
MultimodalModel string
|
||||
TranscriptionModel string
|
||||
TranscriptionLanguage string
|
||||
TTSModel string
|
||||
Timeout string // default: "5m"
|
||||
EnableSkills bool
|
||||
EnableLogs bool
|
||||
CustomActionsDir string
|
||||
CollectionDBPath string
|
||||
VectorEngine string // default: "chromem"
|
||||
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
|
||||
MaxChunkingSize int // default: 400
|
||||
ChunkOverlap int // default: 0
|
||||
DatabaseURL string
|
||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
|
||||
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
|
||||
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
|
||||
APIKey string // default: first API key from LocalAI config
|
||||
DefaultModel string
|
||||
MultimodalModel string
|
||||
TranscriptionModel string
|
||||
TranscriptionLanguage string
|
||||
TTSModel string
|
||||
Timeout string // default: "5m"
|
||||
EnableSkills bool
|
||||
EnableLogs bool
|
||||
CustomActionsDir string
|
||||
CollectionDBPath string
|
||||
VectorEngine string // default: "chromem"
|
||||
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
|
||||
MaxChunkingSize int // default: 400
|
||||
ChunkOverlap int // default: 0
|
||||
DatabaseURL string
|
||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||
}
|
||||
|
||||
type AppOption func(*ApplicationConfig)
|
||||
@@ -155,12 +158,12 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
VectorEngine: "chromem",
|
||||
EmbeddingModel: "granite-embedding-107m-multilingual",
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
VectorEngine: "chromem",
|
||||
EmbeddingModel: "granite-embedding-107m-multilingual",
|
||||
MaxChunkingSize: 400,
|
||||
AgentHubURL: "https://agenthub.localai.io",
|
||||
AgentHubURL: "https://agenthub.localai.io",
|
||||
},
|
||||
PathWithoutAuth: []string{
|
||||
"/static/",
|
||||
@@ -904,40 +907,40 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
WatchdogBusyEnabled: &watchdogBusy,
|
||||
WatchdogIdleTimeout: &idleTimeout,
|
||||
WatchdogBusyTimeout: &busyTimeout,
|
||||
WatchdogInterval: &watchdogInterval,
|
||||
SingleBackend: &singleBackend,
|
||||
MaxActiveBackends: &maxActiveBackends,
|
||||
ParallelBackendRequests: ¶llelBackendRequests,
|
||||
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
|
||||
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
|
||||
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
|
||||
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
|
||||
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
|
||||
Threads: &threads,
|
||||
ContextSize: &contextSize,
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
CSRF: &csrf,
|
||||
CORSAllowOrigins: &corsAllowOrigins,
|
||||
P2PToken: &p2pToken,
|
||||
P2PNetworkID: &p2pNetworkID,
|
||||
Federated: &federated,
|
||||
Galleries: &galleries,
|
||||
BackendGalleries: &backendGalleries,
|
||||
AutoloadGalleries: &autoloadGalleries,
|
||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||
ApiKeys: &apiKeys,
|
||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
WatchdogBusyEnabled: &watchdogBusy,
|
||||
WatchdogIdleTimeout: &idleTimeout,
|
||||
WatchdogBusyTimeout: &busyTimeout,
|
||||
WatchdogInterval: &watchdogInterval,
|
||||
SingleBackend: &singleBackend,
|
||||
MaxActiveBackends: &maxActiveBackends,
|
||||
ParallelBackendRequests: ¶llelBackendRequests,
|
||||
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
|
||||
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
|
||||
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
|
||||
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
|
||||
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
|
||||
Threads: &threads,
|
||||
ContextSize: &contextSize,
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
CSRF: &csrf,
|
||||
CORSAllowOrigins: &corsAllowOrigins,
|
||||
P2PToken: &p2pToken,
|
||||
P2PNetworkID: &p2pNetworkID,
|
||||
Federated: &federated,
|
||||
Galleries: &galleries,
|
||||
BackendGalleries: &backendGalleries,
|
||||
AutoloadGalleries: &autoloadGalleries,
|
||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||
ApiKeys: &apiKeys,
|
||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
AgentPoolEnabled: &agentPoolEnabled,
|
||||
AgentPoolDefaultModel: &agentPoolDefaultModel,
|
||||
AgentPoolEmbeddingModel: &agentPoolEmbeddingModel,
|
||||
|
||||
@@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
F16: true,
|
||||
Debug: true,
|
||||
CORS: true,
|
||||
DisableCSRF: true,
|
||||
DisableCSRF: true,
|
||||
CORSAllowOrigins: "https://example.com",
|
||||
P2PToken: "test-token",
|
||||
P2PNetworkID: "test-network",
|
||||
@@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
F16: true,
|
||||
Debug: false,
|
||||
CORS: true,
|
||||
DisableCSRF: false,
|
||||
DisableCSRF: false,
|
||||
CORSAllowOrigins: "https://test.com",
|
||||
P2PToken: "round-trip-token",
|
||||
P2PNetworkID: "round-trip-network",
|
||||
|
||||
188
core/config/distributed_config.go
Normal file
188
core/config/distributed_config.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// DistributedConfig holds configuration for horizontal scaling mode.
// When Enabled is true, PostgreSQL and NATS are required.
// Zero-valued timeouts mean "use the default"; see the *OrDefault accessors.
type DistributedConfig struct {
	Enabled           bool   // --distributed / LOCALAI_DISTRIBUTED
	InstanceID        string // --instance-id / LOCALAI_INSTANCE_ID (auto-generated UUID if empty)
	NatsURL           string // --nats-url / LOCALAI_NATS_URL
	StorageURL        string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
	RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
	AutoApproveNodes  bool   // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)

	// S3 configuration (used when StorageURL is set)
	StorageBucket    string // --storage-bucket / LOCALAI_STORAGE_BUCKET
	StorageRegion    string // --storage-region / LOCALAI_STORAGE_REGION
	StorageAccessKey string // --storage-access-key / LOCALAI_STORAGE_ACCESS_KEY
	StorageSecretKey string // --storage-secret-key / LOCALAI_STORAGE_SECRET_KEY

	// Timeout configuration (all have sensible defaults — zero means use default)
	MCPToolTimeout      time.Duration // MCP tool execution timeout (default 360s)
	MCPDiscoveryTimeout time.Duration // MCP discovery timeout (default 60s)
	WorkerWaitTimeout   time.Duration // Max wait for healthy worker at startup (default 5m)
	DrainTimeout        time.Duration // Time to wait for in-flight requests during drain (default 30s)
	HealthCheckInterval time.Duration // Health monitor check interval (default 15s)
	StaleNodeThreshold  time.Duration // Time before a node is considered stale (default 60s)
	PerModelHealthCheck bool          // Enable per-model backend health checking (default false)
	MCPCIJobTimeout     time.Duration // MCP CI job execution timeout (default 10m)

	MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)

	// NOTE(review): only the two fields below carry yaml/json/env struct tags,
	// unlike the rest of this struct — confirm whether this type is ever
	// (un)marshalled directly and align the tagging if so.
	AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
	JobWorkerConcurrency   int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
// It returns nil if distributed mode is disabled.
|
||||
func (c DistributedConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
if c.NatsURL == "" {
|
||||
return fmt.Errorf("distributed mode requires --nats-url / LOCALAI_NATS_URL")
|
||||
}
|
||||
// S3 credentials must be paired
|
||||
if (c.StorageAccessKey != "" && c.StorageSecretKey == "") ||
|
||||
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
|
||||
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
|
||||
}
|
||||
// Warn about missing registration token (not an error)
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
}
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
} {
|
||||
if d < 0 {
|
||||
return fmt.Errorf("%s must not be negative", name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Distributed config options

// EnableDistributed turns on distributed (horizontal scaling) mode.
var EnableDistributed = func(o *ApplicationConfig) {
	o.Distributed.Enabled = true
}

// WithDistributedInstanceID sets this instance's ID for distributed mode.
func WithDistributedInstanceID(id string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.InstanceID = id
	}
}

// WithNatsURL sets the NATS server URL used in distributed mode.
func WithNatsURL(url string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.NatsURL = url
	}
}

// WithRegistrationToken sets the token required for node registration.
func WithRegistrationToken(token string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.RegistrationToken = token
	}
}

// WithStorageURL sets the S3-compatible storage endpoint URL.
func WithStorageURL(url string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageURL = url
	}
}

// WithStorageBucket sets the S3 bucket name used for storage.
func WithStorageBucket(bucket string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageBucket = bucket
	}
}

// WithStorageRegion sets the S3 region used for storage.
func WithStorageRegion(region string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageRegion = region
	}
}

// WithStorageAccessKey sets the S3 access key (pair with WithStorageSecretKey).
func WithStorageAccessKey(key string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageAccessKey = key
	}
}

// WithStorageSecretKey sets the S3 secret key (pair with WithStorageAccessKey).
func WithStorageSecretKey(key string) AppOption {
	return func(o *ApplicationConfig) {
		o.Distributed.StorageSecretKey = key
	}
}

// EnableAutoApproveNodes skips admin approval for newly registered workers.
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
	o.Distributed.AutoApproveNodes = true
}
|
||||
|
||||
// Defaults for distributed timeouts.
const (
	DefaultMCPToolTimeout      = 360 * time.Second
	DefaultMCPDiscoveryTimeout = 60 * time.Second
	DefaultWorkerWaitTimeout   = 5 * time.Minute
	DefaultDrainTimeout        = 30 * time.Second
	DefaultHealthCheckInterval = 15 * time.Second
	DefaultStaleNodeThreshold  = 60 * time.Second
	DefaultMCPCIJobTimeout     = 10 * time.Minute
)

// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30

// The *OrDefault accessors below return the configured value when it is
// non-zero and the matching default otherwise (cmp.Or picks the first
// non-zero argument).

// MCPToolTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
}

// MCPDiscoveryTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPDiscoveryTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPDiscoveryTimeout, DefaultMCPDiscoveryTimeout)
}

// WorkerWaitTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) WorkerWaitTimeoutOrDefault() time.Duration {
	return cmp.Or(c.WorkerWaitTimeout, DefaultWorkerWaitTimeout)
}

// DrainTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) DrainTimeoutOrDefault() time.Duration {
	return cmp.Or(c.DrainTimeout, DefaultDrainTimeout)
}

// HealthCheckIntervalOrDefault returns the configured interval or the default.
func (c DistributedConfig) HealthCheckIntervalOrDefault() time.Duration {
	return cmp.Or(c.HealthCheckInterval, DefaultHealthCheckInterval)
}

// StaleNodeThresholdOrDefault returns the configured threshold or the default.
func (c DistributedConfig) StaleNodeThresholdOrDefault() time.Duration {
	return cmp.Or(c.StaleNodeThreshold, DefaultStaleNodeThreshold)
}

// MCPCIJobTimeoutOrDefault returns the configured MCP CI job timeout or the default.
func (c DistributedConfig) MCPCIJobTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPCIJobTimeout, DefaultMCPCIJobTimeout)
}

// MaxUploadSizeOrDefault returns the configured max upload size or the default.
func (c DistributedConfig) MaxUploadSizeOrDefault() int64 {
	return cmp.Or(c.MaxUploadSize, DefaultMaxUploadSize)
}
|
||||
@@ -46,11 +46,11 @@ type ModelConfig struct {
|
||||
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
|
||||
InputToken [][]int `yaml:"-" json:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
|
||||
ResponseFormat string `yaml:"-" json:"-"`
|
||||
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
|
||||
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
|
||||
InputToken [][]int `yaml:"-" json:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
|
||||
ResponseFormat string `yaml:"-" json:"-"`
|
||||
ResponseFormatMap map[string]any `yaml:"-" json:"-"`
|
||||
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
@@ -105,6 +105,11 @@ type AgentConfig struct {
|
||||
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
|
||||
}
|
||||
|
||||
// HasMCPServers returns true if any MCP servers (remote or stdio) are configured.
// It only checks that the raw server strings are non-empty; it does not
// validate that they parse (see MCPConfigFromYAML for that).
func (c MCPConfig) HasMCPServers() bool {
	return c.Servers != "" || c.Stdio != ""
}
|
||||
|
||||
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
|
||||
var remote MCPGenericConfig[MCPRemoteServers]
|
||||
var stdio MCPGenericConfig[MCPSTDIOServers]
|
||||
@@ -619,15 +624,32 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
|
||||
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
|
||||
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
|
||||
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
// Backends that are clearly not text-generation
|
||||
nonTextGenBackends := []string{
|
||||
"whisper", "piper", "kokoro",
|
||||
"diffusers", "stablediffusion", "stablediffusion-ggml",
|
||||
"rerankers", "silero-vad", "rfdetr",
|
||||
"transformers-musicgen", "ace-step", "acestep-cpp",
|
||||
}
|
||||
|
||||
if (u & FLAG_CHAT) == FLAG_CHAT {
|
||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
if c.Embeddings != nil && *c.Embeddings {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
|
||||
if c.TemplateConfig.Completion == "" {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_EDIT) == FLAG_EDIT {
|
||||
if c.TemplateConfig.Edit == "" {
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
package config
|
||||
|
||||
import "regexp"
|
||||
|
||||
type ModelConfigFilterFn func(string, *ModelConfig) bool
|
||||
|
||||
func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
|
||||
|
||||
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
|
||||
if filter == "" {
|
||||
return NoFilterFn, nil
|
||||
}
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config != nil {
|
||||
return rxp.MatchString(config.Name)
|
||||
}
|
||||
return rxp.MatchString(name)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
|
||||
if usecases == FLAG_ANY {
|
||||
return NoFilterFn
|
||||
}
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config == nil {
|
||||
return false // TODO: Potentially make this a param, for now, no known usecase to include
|
||||
}
|
||||
return config.HasUsecases(usecases)
|
||||
}
|
||||
}
|
||||
package config
|
||||
|
||||
import "regexp"
|
||||
|
||||
type ModelConfigFilterFn func(string, *ModelConfig) bool
|
||||
|
||||
func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
|
||||
|
||||
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
|
||||
if filter == "" {
|
||||
return NoFilterFn, nil
|
||||
}
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config != nil {
|
||||
return rxp.MatchString(config.Name)
|
||||
}
|
||||
return rxp.MatchString(name)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
|
||||
if usecases == FLAG_ANY {
|
||||
return NoFilterFn
|
||||
}
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config == nil {
|
||||
return false // TODO: Potentially make this a param, for now, no known usecase to include
|
||||
}
|
||||
return config.HasUsecases(usecases)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -215,8 +216,8 @@ func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
|
||||
res = append(res, v)
|
||||
}
|
||||
|
||||
sort.SliceStable(res, func(i, j int) bool {
|
||||
return res[i].Name < res[j].Name
|
||||
slices.SortStableFunc(res, func(a, b ModelConfig) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
@@ -27,15 +27,15 @@ type RuntimeSettings struct {
|
||||
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
||||
|
||||
// Eviction settings
|
||||
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
|
||||
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
|
||||
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
|
||||
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
|
||||
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
|
||||
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
|
||||
|
||||
// Performance settings
|
||||
Threads *int `json:"threads,omitempty"`
|
||||
ContextSize *int `json:"context_size,omitempty"`
|
||||
F16 *bool `json:"f16,omitempty"`
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
Threads *int `json:"threads,omitempty"`
|
||||
ContextSize *int `json:"context_size,omitempty"`
|
||||
F16 *bool `json:"f16,omitempty"`
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||
@@ -66,11 +66,11 @@ type RuntimeSettings struct {
|
||||
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
|
||||
|
||||
// Agent Pool settings
|
||||
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
|
||||
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
|
||||
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
|
||||
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
|
||||
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
|
||||
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
|
||||
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
|
||||
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
|
||||
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
|
||||
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
|
||||
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
|
||||
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
|
||||
AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"`
|
||||
}
|
||||
|
||||
@@ -3,9 +3,10 @@ package explorer
|
||||
// A simple JSON database for storing and retrieving p2p network tokens and a name and description.
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sort"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/gofrs/flock"
|
||||
@@ -89,9 +90,8 @@ func (db *Database) TokenList() []string {
|
||||
tokens = append(tokens, k)
|
||||
}
|
||||
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
// sort by token
|
||||
return tokens[i] < tokens[j]
|
||||
slices.SortFunc(tokens, func(a, b string) int {
|
||||
return cmp.Compare(a, b)
|
||||
})
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
|
||||
type modelConfigCacheEntry struct {
|
||||
configMap map[string]interface{}
|
||||
configMap map[string]any
|
||||
lastUpdated time.Time
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func resolveBackend(m *GalleryModel, basePath string) string {
|
||||
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
|
||||
// inside it, and returns the result as a map. Results are cached for 1 hour.
|
||||
// Local file:// URLs skip the cache so edits are picked up immediately.
|
||||
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
|
||||
func fetchModelConfigMap(modelURL, basePath string) map[string]any {
|
||||
// Check cache (skip for file:// URLs so local edits are picked up immediately)
|
||||
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
|
||||
if !isLocal && modelConfigCache.Exists(modelURL) {
|
||||
@@ -75,15 +75,15 @@ func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
|
||||
// Cache the failure for remote URLs to avoid repeated fetch attempts
|
||||
if !isLocal {
|
||||
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
|
||||
configMap: map[string]interface{}{},
|
||||
configMap: map[string]any{},
|
||||
lastUpdated: time.Now(),
|
||||
})
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
// Parse the config_file YAML string into a map
|
||||
configMap := make(map[string]interface{})
|
||||
configMap := make(map[string]any)
|
||||
if modelConfig.ConfigFile != "" {
|
||||
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
|
||||
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
|
||||
@@ -108,13 +108,11 @@ func prefetchModelConfigs(urls []string, basePath string) {
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
for _, url := range urls {
|
||||
wg.Add(1)
|
||||
go func(u string) {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
fetchModelConfigMap(u, basePath)
|
||||
}(url)
|
||||
fetchModelConfigMap(url, basePath)
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
cp "github.com/otiai10/copy"
|
||||
)
|
||||
|
||||
// ErrBackendNotFound is returned when a backend is not found in the system.
|
||||
var ErrBackendNotFound = errors.New("backend not found")
|
||||
|
||||
const (
|
||||
metadataFile = "metadata.json"
|
||||
runFile = "run.sh"
|
||||
@@ -198,9 +201,16 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
} else {
|
||||
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
||||
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
// Don't remove backendPath here — fallback OCI extractions need the directory to exist
|
||||
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
||||
|
||||
// resetBackendPath cleans up partial state from a failed OCI extraction
|
||||
// so the next download attempt starts fresh. The directory is re-created
|
||||
// because OCI image extractors need it to exist for writing files into.
|
||||
resetBackendPath := func() {
|
||||
os.RemoveAll(backendPath)
|
||||
os.MkdirAll(backendPath, 0750)
|
||||
}
|
||||
|
||||
success := false
|
||||
// Try to download from mirrors
|
||||
for _, mirror := range config.Mirrors {
|
||||
@@ -210,6 +220,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
resetBackendPath()
|
||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
success = true
|
||||
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
||||
@@ -221,28 +232,22 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
|
||||
fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1)
|
||||
if fallbackURI != string(config.URI) {
|
||||
xlog.Debug("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
||||
resetBackendPath()
|
||||
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
||||
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
||||
success = true
|
||||
} else {
|
||||
// Try another fallback: add "-" + devSuffix suffix to the backend name
|
||||
// For example: master-gpu-nvidia-cuda-13-ace-step -> master-gpu-nvidia-cuda-13-ace-step-development
|
||||
xlog.Info("Fallback URI failed", "fallback", fallbackURI, "error", err)
|
||||
if !strings.Contains(fallbackURI, "-"+devSuffix) {
|
||||
// Extract backend name from URI and add -development
|
||||
parts := strings.Split(fallbackURI, "-")
|
||||
if len(parts) >= 2 {
|
||||
// Find where the backend name ends (usually the last part before the tag)
|
||||
// Pattern: quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step
|
||||
lastDash := strings.LastIndex(fallbackURI, "-")
|
||||
if lastDash > 0 {
|
||||
devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix
|
||||
xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI)
|
||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
||||
success = true
|
||||
}
|
||||
}
|
||||
resetBackendPath()
|
||||
devFallbackURI := fallbackURI + "-" + devSuffix
|
||||
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
|
||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
||||
success = true
|
||||
} else {
|
||||
xlog.Info("Development fallback URI failed", "fallback", devFallbackURI, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -295,7 +300,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
|
||||
backend, ok := backends.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q not found", name)
|
||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||
}
|
||||
|
||||
if backend.IsSystem {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -106,64 +106,64 @@ func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
|
||||
sort.Slice(gm, func(i, j int) bool {
|
||||
if sortOrder == "asc" {
|
||||
return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
|
||||
} else {
|
||||
return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName())
|
||||
slices.SortFunc(gm, func(a, b T) int {
|
||||
r := strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
|
||||
if sortOrder == "desc" {
|
||||
return -r
|
||||
}
|
||||
return r
|
||||
})
|
||||
return gm
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
|
||||
sort.Slice(gm, func(i, j int) bool {
|
||||
if sortOrder == "asc" {
|
||||
return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name)
|
||||
} else {
|
||||
return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name)
|
||||
slices.SortFunc(gm, func(a, b T) int {
|
||||
r := strings.Compare(strings.ToLower(a.GetGallery().Name), strings.ToLower(b.GetGallery().Name))
|
||||
if sortOrder == "desc" {
|
||||
return -r
|
||||
}
|
||||
return r
|
||||
})
|
||||
return gm
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
|
||||
sort.Slice(gm, func(i, j int) bool {
|
||||
licenseI := gm[i].GetLicense()
|
||||
licenseJ := gm[j].GetLicense()
|
||||
var result bool
|
||||
if licenseI == "" && licenseJ != "" {
|
||||
return sortOrder == "desc"
|
||||
} else if licenseI != "" && licenseJ == "" {
|
||||
return sortOrder == "asc"
|
||||
} else if licenseI == "" && licenseJ == "" {
|
||||
return false
|
||||
slices.SortFunc(gm, func(a, b T) int {
|
||||
licenseA := a.GetLicense()
|
||||
licenseB := b.GetLicense()
|
||||
var r int
|
||||
if licenseA == "" && licenseB != "" {
|
||||
r = 1
|
||||
} else if licenseA != "" && licenseB == "" {
|
||||
r = -1
|
||||
} else {
|
||||
result = strings.ToLower(licenseI) < strings.ToLower(licenseJ)
|
||||
r = strings.Compare(strings.ToLower(licenseA), strings.ToLower(licenseB))
|
||||
}
|
||||
if sortOrder == "desc" {
|
||||
return !result
|
||||
} else {
|
||||
return result
|
||||
return -r
|
||||
}
|
||||
return r
|
||||
})
|
||||
return gm
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
|
||||
sort.Slice(gm, func(i, j int) bool {
|
||||
var result bool
|
||||
slices.SortFunc(gm, func(a, b T) int {
|
||||
var r int
|
||||
// Sort by installed status: installed items first (true > false)
|
||||
if gm[i].GetInstalled() != gm[j].GetInstalled() {
|
||||
result = gm[i].GetInstalled()
|
||||
if a.GetInstalled() != b.GetInstalled() {
|
||||
if a.GetInstalled() {
|
||||
r = -1
|
||||
} else {
|
||||
r = 1
|
||||
}
|
||||
} else {
|
||||
result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
|
||||
r = strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
|
||||
}
|
||||
if sortOrder == "desc" {
|
||||
return !result
|
||||
} else {
|
||||
return result
|
||||
return -r
|
||||
}
|
||||
return r
|
||||
})
|
||||
return gm
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ var _ = Describe("Gallery", func() {
|
||||
|
||||
Describe("ReadConfigFile", func() {
|
||||
It("should read and unmarshal a valid YAML file", func() {
|
||||
testConfig := map[string]interface{}{
|
||||
testConfig := map[string]any{
|
||||
"name": "test-model",
|
||||
"description": "A test model",
|
||||
"license": "MIT",
|
||||
@@ -39,8 +39,8 @@ var _ = Describe("Gallery", func() {
|
||||
err = os.WriteFile(filePath, yamlData, 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var result map[string]interface{}
|
||||
config, err := ReadConfigFile[map[string]interface{}](filePath)
|
||||
var result map[string]any
|
||||
config, err := ReadConfigFile[map[string]any](filePath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(config).NotTo(BeNil())
|
||||
result = *config
|
||||
@@ -50,7 +50,7 @@ var _ = Describe("Gallery", func() {
|
||||
})
|
||||
|
||||
It("should return error when file does not exist", func() {
|
||||
_, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml")
|
||||
_, err := ReadConfigFile[map[string]any]("nonexistent.yaml")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -59,7 +59,7 @@ var _ = Describe("Gallery", func() {
|
||||
err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
_, err = ReadConfigFile[map[string]interface{}](filePath)
|
||||
_, err = ReadConfigFile[map[string]any](filePath)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
@@ -552,32 +552,32 @@ var _ = Describe("Gallery", func() {
|
||||
// Verify first model
|
||||
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
|
||||
Expect(models[0].Overrides).NotTo(BeNil())
|
||||
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
||||
params := models[0].Overrides["parameters"].(map[string]interface{})
|
||||
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
|
||||
params := models[0].Overrides["parameters"].(map[string]any)
|
||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
|
||||
|
||||
// Verify second model (merged)
|
||||
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
|
||||
Expect(models[1].Overrides).NotTo(BeNil())
|
||||
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
||||
params = models[1].Overrides["parameters"].(map[string]interface{})
|
||||
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
|
||||
params = models[1].Overrides["parameters"].(map[string]any)
|
||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||
|
||||
// Simulate the mergo.Merge call that was failing in models.go:251
|
||||
// This should not panic with yaml.v3
|
||||
configMap := make(map[string]interface{})
|
||||
configMap := make(map[string]any)
|
||||
configMap["name"] = "test"
|
||||
configMap["backend"] = "llama-cpp"
|
||||
configMap["parameters"] = map[string]interface{}{
|
||||
configMap["parameters"] = map[string]any{
|
||||
"model": "original.gguf",
|
||||
}
|
||||
|
||||
err = mergo.Merge(&configMap, models[1].Overrides, mergo.WithOverride)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(configMap["parameters"]).NotTo(BeNil())
|
||||
|
||||
|
||||
// Verify the merge worked correctly
|
||||
mergedParams := configMap["parameters"].(map[string]interface{})
|
||||
mergedParams := configMap["parameters"].(map[string]any)
|
||||
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -59,7 +59,7 @@ var _ = Describe("ImportLocalPath", func() {
|
||||
|
||||
adapterConfig := map[string]any{
|
||||
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
||||
"peft_type": "LORA",
|
||||
"peft_type": "LORA",
|
||||
}
|
||||
data, _ := json.Marshal(adapterConfig)
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||
|
||||
@@ -158,7 +158,7 @@ func InstallModelFromGallery(
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]any, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
basePath := systemState.Model.ModelsPath
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0750)
|
||||
@@ -239,7 +239,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||
configFilePath := filepath.Join(basePath, name+".yaml")
|
||||
|
||||
// Read and update config file as map[string]interface{}
|
||||
configMap := make(map[string]interface{})
|
||||
configMap := make(map[string]any)
|
||||
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
||||
|
||||
@@ -35,7 +35,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||
@@ -43,7 +43,7 @@ var _ = Describe("Model test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -95,7 +95,7 @@ var _ = Describe("Model test", func() {
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
err = yaml.Unmarshal(dat, &content)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
@@ -130,7 +130,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -150,7 +150,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -158,7 +158,7 @@ var _ = Describe("Model test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -180,7 +180,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
type GalleryModel struct {
|
||||
Metadata `json:",inline" yaml:",inline"`
|
||||
// config_file is read in the situation where URL is blank - and therefore this is a base config.
|
||||
ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"`
|
||||
ConfigFile map[string]any `json:"config_file,omitempty" yaml:"config_file,omitempty"`
|
||||
// Overrides are used to override the configuration of the model located at URL
|
||||
Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"`
|
||||
Overrides map[string]any `json:"overrides,omitempty" yaml:"overrides,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GalleryModel) GetInstalled() bool {
|
||||
|
||||
66
core/gallery/worker.go
Normal file
66
core/gallery/worker.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// DeleteStagedModelFiles removes all staged files for a model from a worker's
|
||||
// models directory. Files are expected to be in a subdirectory named after the
|
||||
// model's tracking key (created by stageModelFiles in the router).
|
||||
//
|
||||
// Workers receive model files via S3/HTTP file staging, not gallery install,
|
||||
// so they lack the YAML configs that DeleteModelFromSystem requires.
|
||||
//
|
||||
// Falls back to glob-based cleanup for single-file models or legacy layouts.
|
||||
func DeleteStagedModelFiles(modelsPath, modelName string) error {
|
||||
if modelName == "" {
|
||||
return fmt.Errorf("empty model name")
|
||||
}
|
||||
|
||||
// Clean and validate: resolved path must stay within modelsPath
|
||||
modelPath := filepath.Clean(filepath.Join(modelsPath, modelName))
|
||||
absModels := filepath.Clean(modelsPath)
|
||||
if !strings.HasPrefix(modelPath, absModels+string(filepath.Separator)) {
|
||||
return fmt.Errorf("model name %q escapes models directory", modelName)
|
||||
}
|
||||
|
||||
// Primary: remove the model's subdirectory (contains all staged files)
|
||||
if info, err := os.Stat(modelPath); err == nil && info.IsDir() {
|
||||
return os.RemoveAll(modelPath)
|
||||
}
|
||||
|
||||
// Fallback for single-file models or legacy layouts:
|
||||
// remove exact file match + glob siblings
|
||||
removed := false
|
||||
if _, err := os.Stat(modelPath); err == nil {
|
||||
if err := os.Remove(modelPath); err != nil {
|
||||
xlog.Warn("Failed to remove model file", "path", modelPath, "error", err)
|
||||
} else {
|
||||
removed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Remove sibling files (e.g., model.gguf.mmproj alongside model.gguf)
|
||||
matches, _ := filepath.Glob(modelPath + ".*")
|
||||
for _, m := range matches {
|
||||
clean := filepath.Clean(m)
|
||||
if !strings.HasPrefix(clean, absModels+string(filepath.Separator)) {
|
||||
continue // skip any glob result that escapes
|
||||
}
|
||||
if err := os.Remove(clean); err != nil {
|
||||
xlog.Warn("Failed to remove model-related file", "path", clean, "error", err)
|
||||
} else {
|
||||
removed = true
|
||||
}
|
||||
}
|
||||
|
||||
if !removed {
|
||||
xlog.Debug("No files found to delete for model", "model", modelName, "path", modelPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
99
core/gallery/worker_test.go
Normal file
99
core/gallery/worker_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
)
|
||||
|
||||
func TestDeleteStagedModelFiles(t *testing.T) {
|
||||
t.Run("rejects empty model name", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := gallery.DeleteStagedModelFiles(dir, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty model name")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects path traversal via ..", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := gallery.DeleteStagedModelFiles(dir, "../../etc/passwd")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path traversal attempt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects path traversal via ../foo", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := gallery.DeleteStagedModelFiles(dir, "../foo")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path traversal attempt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("removes model subdirectory with all files", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelDir := filepath.Join(dir, "my-model", "sd-cpp", "models")
|
||||
if err := os.MkdirAll(modelDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Create model files in subdirectory
|
||||
os.WriteFile(filepath.Join(modelDir, "flux.gguf"), []byte("model"), 0o644)
|
||||
os.WriteFile(filepath.Join(modelDir, "flux.gguf.mmproj"), []byte("mmproj"), 0o644)
|
||||
|
||||
err := gallery.DeleteStagedModelFiles(dir, "my-model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Entire my-model directory should be gone
|
||||
if _, err := os.Stat(filepath.Join(dir, "my-model")); !os.IsNotExist(err) {
|
||||
t.Fatal("expected model directory to be removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("removes single file model", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelFile := filepath.Join(dir, "model.gguf")
|
||||
os.WriteFile(modelFile, []byte("model"), 0o644)
|
||||
|
||||
err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
|
||||
t.Fatal("expected model file to be removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("removes sibling files via glob", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelFile := filepath.Join(dir, "model.gguf")
|
||||
siblingFile := filepath.Join(dir, "model.gguf.mmproj")
|
||||
os.WriteFile(modelFile, []byte("model"), 0o644)
|
||||
os.WriteFile(siblingFile, []byte("mmproj"), 0o644)
|
||||
|
||||
err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
|
||||
t.Fatal("expected model file to be removed")
|
||||
}
|
||||
if _, err := os.Stat(siblingFile); !os.IsNotExist(err) {
|
||||
t.Fatal("expected sibling file to be removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no error when model does not exist", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := gallery.DeleteStagedModelFiles(dir, "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -16,12 +16,17 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/finetune"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/quantization"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -155,7 +160,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
metricsService, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -295,9 +300,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
var opcache *services.OpCache
|
||||
var opcache *galleryop.OpCache
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
opcache = services.NewOpCache(application.GalleryService())
|
||||
opcache = galleryop.NewOpCache(application.GalleryService())
|
||||
}
|
||||
|
||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||
@@ -305,22 +310,51 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||
// Fine-tuning routes
|
||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||
ftService := services.NewFineTuneService(
|
||||
ftService := finetune.NewFineTuneService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
if d := application.Distributed(); d != nil {
|
||||
ftService.SetNATSClient(d.Nats)
|
||||
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
||||
ftService.SetFineTuneStore(d.DistStores.FineTune)
|
||||
}
|
||||
}
|
||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||
|
||||
// Quantization routes
|
||||
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
||||
qService := services.NewQuantizationService(
|
||||
qService := quantization.NewQuantizationService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
||||
|
||||
// Node management routes (distributed mode)
|
||||
distCfg := application.ApplicationConfig().Distributed
|
||||
var registry *nodes.NodeRegistry
|
||||
var remoteUnloader nodes.NodeCommandSender
|
||||
if d := application.Distributed(); d != nil {
|
||||
registry = d.Registry
|
||||
if d.Router != nil {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
if d.Dispatcher != nil {
|
||||
e.GET("/api/agent/jobs/:id/progress", d.Dispatcher.SSEHandler(), mcpJobsMw)
|
||||
}
|
||||
if d.AgentBridge != nil {
|
||||
e.GET("/api/agents/:name/sse/distributed", d.AgentBridge.SSEHandler(), agentsMw)
|
||||
}
|
||||
}
|
||||
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
|
||||
@@ -44,14 +44,14 @@ Say hello.
|
||||
### Response:`
|
||||
|
||||
type modelApplyRequest struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
ConfigURL string `json:"config_url"`
|
||||
Name string `json:"name"`
|
||||
Overrides map[string]interface{} `json:"overrides"`
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
ConfigURL string `json:"config_url"`
|
||||
Name string `json:"name"`
|
||||
Overrides map[string]any `json:"overrides"`
|
||||
}
|
||||
|
||||
func getModelStatus(url string) (response map[string]interface{}) {
|
||||
func getModelStatus(url string) (response map[string]any) {
|
||||
// Create the HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -94,7 +94,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
|
||||
return response, err
|
||||
}
|
||||
|
||||
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
|
||||
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]any) {
|
||||
|
||||
//url := "http://localhost:AI/models/apply"
|
||||
|
||||
@@ -336,7 +336,7 @@ var _ = Describe("API test", func() {
|
||||
Name: "bert",
|
||||
URL: bertEmbeddingsURL,
|
||||
},
|
||||
Overrides: map[string]interface{}{"backend": "llama-cpp"},
|
||||
Overrides: map[string]any{"backend": "llama-cpp"},
|
||||
},
|
||||
{
|
||||
Metadata: gallery.Metadata{
|
||||
@@ -344,7 +344,7 @@ var _ = Describe("API test", func() {
|
||||
URL: bertEmbeddingsURL,
|
||||
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
|
||||
},
|
||||
Overrides: map[string]interface{}{"foo": "bar"},
|
||||
Overrides: map[string]any{"foo": "bar"},
|
||||
},
|
||||
}
|
||||
out, err := yaml.Marshal(g)
|
||||
@@ -464,7 +464,7 @@ var _ = Describe("API test", func() {
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
|
||||
uuid := response["uuid"].(string)
|
||||
resp := map[string]interface{}{}
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
@@ -479,7 +479,7 @@ var _ = Describe("API test", func() {
|
||||
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
err = yaml.Unmarshal(dat, &content)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
@@ -503,7 +503,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]interface{}{
|
||||
Overrides: map[string]any{
|
||||
"backend": "llama",
|
||||
},
|
||||
})
|
||||
@@ -520,7 +520,7 @@ var _ = Describe("API test", func() {
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
err = yaml.Unmarshal(dat, &content)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["backend"]).To(Equal("llama"))
|
||||
@@ -529,7 +529,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]interface{}{},
|
||||
Overrides: map[string]any{},
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
@@ -544,7 +544,7 @@ var _ = Describe("API test", func() {
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
err = yaml.Unmarshal(dat, &content)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
@@ -586,7 +586,7 @@ parameters:
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]interface{}{}
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
resp = response
|
||||
@@ -601,7 +601,7 @@ parameters:
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
content := map[string]any{}
|
||||
err = yaml.Unmarshal(dat, &content)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["name"]).To(Equal("test-import-model"))
|
||||
@@ -657,7 +657,7 @@ parameters:
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]interface{}{}
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
resp = response
|
||||
@@ -1248,7 +1248,7 @@ parameters:
|
||||
Context("Agent Jobs", Label("agent-jobs"), func() {
|
||||
It("creates and manages tasks", func() {
|
||||
// Create a task
|
||||
taskBody := map[string]interface{}{
|
||||
taskBody := map[string]any{
|
||||
"name": "Test Task",
|
||||
"description": "Test Description",
|
||||
"model": "testmodel.ggml",
|
||||
@@ -1256,7 +1256,7 @@ parameters:
|
||||
"enabled": true,
|
||||
}
|
||||
|
||||
var createResp map[string]interface{}
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(createResp["id"]).ToNot(BeEmpty())
|
||||
@@ -1302,20 +1302,20 @@ parameters:
|
||||
|
||||
It("executes and monitors jobs", func() {
|
||||
// Create a task first
|
||||
taskBody := map[string]interface{}{
|
||||
taskBody := map[string]any{
|
||||
"name": "Job Test Task",
|
||||
"model": "testmodel.ggml",
|
||||
"prompt": "Say hello",
|
||||
"enabled": true,
|
||||
}
|
||||
|
||||
var createResp map[string]interface{}
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
taskID := createResp["id"].(string)
|
||||
|
||||
// Execute a job
|
||||
jobBody := map[string]interface{}{
|
||||
jobBody := map[string]any{
|
||||
"task_id": taskID,
|
||||
"parameters": map[string]string{},
|
||||
}
|
||||
@@ -1357,14 +1357,14 @@ parameters:
|
||||
|
||||
It("executes task by name", func() {
|
||||
// Create a task with a specific name
|
||||
taskBody := map[string]interface{}{
|
||||
taskBody := map[string]any{
|
||||
"name": "Named Task",
|
||||
"model": "testmodel.ggml",
|
||||
"prompt": "Hello",
|
||||
"enabled": true,
|
||||
}
|
||||
|
||||
var createResp map[string]interface{}
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
||||
@@ -516,6 +516,17 @@ func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Node self-service endpoints — authenticated via registration token, not global auth.
|
||||
// Only exempt the specific known endpoints, not the entire prefix.
|
||||
if strings.HasPrefix(path, "/api/node/") {
|
||||
if path == "/api/node/register" ||
|
||||
strings.HasSuffix(path, "/heartbeat") ||
|
||||
strings.HasSuffix(path, "/drain") ||
|
||||
strings.HasSuffix(path, "/deregister") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check configured exempt paths
|
||||
for _, p := range appConfig.PathWithoutAuth {
|
||||
if strings.HasPrefix(path, p) {
|
||||
@@ -540,6 +551,14 @@ func isAPIPath(path string) bool {
|
||||
strings.HasPrefix(path, "/system") ||
|
||||
strings.HasPrefix(path, "/ws/") ||
|
||||
strings.HasPrefix(path, "/generated-") ||
|
||||
strings.HasPrefix(path, "/chat/") ||
|
||||
strings.HasPrefix(path, "/completions") ||
|
||||
strings.HasPrefix(path, "/edits") ||
|
||||
strings.HasPrefix(path, "/embeddings") ||
|
||||
strings.HasPrefix(path, "/audio/") ||
|
||||
strings.HasPrefix(path, "/images/") ||
|
||||
strings.HasPrefix(path, "/messages") ||
|
||||
strings.HasPrefix(path, "/responses") ||
|
||||
path == "/metrics"
|
||||
}
|
||||
|
||||
|
||||
@@ -9,24 +9,25 @@ import (
|
||||
|
||||
// Auth provider constants.
|
||||
const (
|
||||
ProviderLocal = "local"
|
||||
ProviderGitHub = "github"
|
||||
ProviderOIDC = "oidc"
|
||||
ProviderLocal = "local"
|
||||
ProviderGitHub = "github"
|
||||
ProviderOIDC = "oidc"
|
||||
ProviderAgentWorker = "agent-worker"
|
||||
)
|
||||
|
||||
// User represents an authenticated user.
|
||||
type User struct {
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
Email string `gorm:"size:255;index"`
|
||||
Name string `gorm:"size:255"`
|
||||
AvatarURL string `gorm:"size:512"`
|
||||
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
|
||||
Subject string `gorm:"size:255"` // provider-specific user ID
|
||||
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
Email string `gorm:"size:255;index"`
|
||||
Name string `gorm:"size:255"`
|
||||
AvatarURL string `gorm:"size:512"`
|
||||
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
|
||||
Subject string `gorm:"size:255"` // provider-specific user ID
|
||||
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
|
||||
Role string `gorm:"size:20;default:user"`
|
||||
Status string `gorm:"size:20;default:active"` // "active", "pending"
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// Session represents a user login session.
|
||||
@@ -90,16 +91,16 @@ func (p *PermissionMap) Scan(value any) error {
|
||||
|
||||
// InviteCode represents an admin-generated invitation for user registration.
|
||||
type InviteCode struct {
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
|
||||
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
|
||||
CreatedBy string `gorm:"size:36;not null"`
|
||||
UsedBy *string `gorm:"size:36"`
|
||||
UsedAt *time.Time
|
||||
ExpiresAt time.Time `gorm:"not null;index"`
|
||||
CreatedAt time.Time
|
||||
Creator User `gorm:"foreignKey:CreatedBy"`
|
||||
Consumer *User `gorm:"foreignKey:UsedBy"`
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
|
||||
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
|
||||
CreatedBy string `gorm:"size:36;not null"`
|
||||
UsedBy *string `gorm:"size:36"`
|
||||
UsedAt *time.Time
|
||||
ExpiresAt time.Time `gorm:"not null;index"`
|
||||
CreatedAt time.Time
|
||||
Creator User `gorm:"foreignKey:CreatedBy"`
|
||||
Consumer *User `gorm:"foreignKey:UsedBy"`
|
||||
}
|
||||
|
||||
// ModelAllowlist controls which models a user can access.
|
||||
|
||||
@@ -33,24 +33,24 @@ const (
|
||||
FeatureMCPJobs = "mcp_jobs"
|
||||
|
||||
// General features (default OFF for new users)
|
||||
FeatureFineTuning = "fine_tuning"
|
||||
FeatureQuantization = "quantization"
|
||||
FeatureFineTuning = "fine_tuning"
|
||||
FeatureQuantization = "quantization"
|
||||
|
||||
// API features (default ON for new users)
|
||||
FeatureChat = "chat"
|
||||
FeatureImages = "images"
|
||||
FeatureAudioSpeech = "audio_speech"
|
||||
FeatureChat = "chat"
|
||||
FeatureImages = "images"
|
||||
FeatureAudioSpeech = "audio_speech"
|
||||
FeatureAudioTranscription = "audio_transcription"
|
||||
FeatureVAD = "vad"
|
||||
FeatureDetection = "detection"
|
||||
FeatureVideo = "video"
|
||||
FeatureEmbeddings = "embeddings"
|
||||
FeatureSound = "sound"
|
||||
FeatureRealtime = "realtime"
|
||||
FeatureRerank = "rerank"
|
||||
FeatureTokenize = "tokenize"
|
||||
FeatureMCP = "mcp"
|
||||
FeatureStores = "stores"
|
||||
FeatureVAD = "vad"
|
||||
FeatureDetection = "detection"
|
||||
FeatureVideo = "video"
|
||||
FeatureEmbeddings = "embeddings"
|
||||
FeatureSound = "sound"
|
||||
FeatureRealtime = "realtime"
|
||||
FeatureRerank = "rerank"
|
||||
FeatureTokenize = "tokenize"
|
||||
FeatureMCP = "mcp"
|
||||
FeatureStores = "stores"
|
||||
)
|
||||
|
||||
// AgentFeatures lists agent-related features (default OFF).
|
||||
|
||||
@@ -24,14 +24,14 @@ type QuotaRule struct {
|
||||
|
||||
// QuotaStatus is returned to clients with current usage included.
|
||||
type QuotaStatus struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
MaxRequests *int64 `json:"max_requests"`
|
||||
MaxTotalTokens *int64 `json:"max_total_tokens"`
|
||||
Window string `json:"window"`
|
||||
CurrentRequests int64 `json:"current_requests"`
|
||||
CurrentTokens int64 `json:"current_total_tokens"`
|
||||
ResetsAt string `json:"resets_at,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
MaxRequests *int64 `json:"max_requests"`
|
||||
MaxTotalTokens *int64 `json:"max_total_tokens"`
|
||||
Window string `json:"window"`
|
||||
CurrentRequests int64 `json:"current_requests"`
|
||||
CurrentTokens int64 `json:"current_total_tokens"`
|
||||
ResetsAt string `json:"resets_at,omitempty"`
|
||||
}
|
||||
|
||||
// ── CRUD ──
|
||||
@@ -209,9 +209,9 @@ func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) {
|
||||
var quotaCache = newQuotaCacheStore()
|
||||
|
||||
type quotaCacheStore struct {
|
||||
mu sync.RWMutex
|
||||
rules map[string]cachedRules // userID -> rules
|
||||
usage map[string]cachedUsage // "userID|model|windowStart" -> counts
|
||||
mu sync.RWMutex
|
||||
rules map[string]cachedRules // userID -> rules
|
||||
usage map[string]cachedUsage // "userID|model|windowStart" -> counts
|
||||
}
|
||||
|
||||
type cachedRules struct {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
sessionDuration = 30 * 24 * time.Hour // 30 days
|
||||
sessionIDBytes = 32 // 32 bytes = 64 hex chars
|
||||
sessionIDBytes = 32 // 32 bytes = 64 hex chars
|
||||
sessionCookie = "session"
|
||||
sessionRotationInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
@@ -10,15 +10,15 @@ import (
|
||||
|
||||
// UsageRecord represents a single API request's token usage.
|
||||
type UsageRecord struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
Model string `gorm:"size:255;index"`
|
||||
Endpoint string `gorm:"size:255"`
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
Model string `gorm:"size:255;index"`
|
||||
Endpoint string `gorm:"size:255"`
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
TotalTokens int64
|
||||
Duration int64 // milliseconds
|
||||
Duration int64 // milliseconds
|
||||
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
|
||||
}
|
||||
|
||||
@@ -127,10 +127,10 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", model, user_id, user_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
Select(bucketExpr + ", model, user_id, user_name, " +
|
||||
"SUM(prompt_tokens) as prompt_tokens, " +
|
||||
"SUM(completion_tokens) as completion_tokens, " +
|
||||
"SUM(total_tokens) as total_tokens, " +
|
||||
"COUNT(*) as request_count").
|
||||
Group("bucket, model, user_id, user_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
@@ -36,7 +36,7 @@ var _ = Describe("Usage", func() {
|
||||
db := testDB()
|
||||
|
||||
// Insert records for two users
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-a",
|
||||
UserName: "Alice",
|
||||
|
||||
@@ -3,7 +3,6 @@ package anthropic
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -25,7 +24,7 @@ import (
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -52,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
funcs, shouldUseFn := convertAnthropicTools(input, cfg)
|
||||
|
||||
// MCP injection: prompts, resources, and tools
|
||||
var mcpToolInfos []mcpTools.MCPToolInfo
|
||||
var mcpExecutor mcpTools.ToolExecutor
|
||||
mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
|
||||
mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
|
||||
mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)
|
||||
@@ -60,76 +59,29 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") {
|
||||
remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
|
||||
if mcpErr == nil {
|
||||
mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, mcpServers)
|
||||
|
||||
// Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode)
|
||||
namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
|
||||
if sessErr == nil && len(namedSessions) > 0 {
|
||||
// Prompt injection
|
||||
if mcpPromptName != "" {
|
||||
prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions)
|
||||
if discErr == nil {
|
||||
promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs)
|
||||
if getErr == nil {
|
||||
var injected []schema.Message
|
||||
for _, pm := range promptMsgs {
|
||||
injected = append(injected, schema.Message{
|
||||
Role: string(pm.Role),
|
||||
Content: mcpTools.PromptMessageToText(pm),
|
||||
})
|
||||
}
|
||||
openAIMessages = append(injected, openAIMessages...)
|
||||
xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected))
|
||||
} else {
|
||||
xlog.Error("Failed to get MCP prompt", "error", getErr)
|
||||
}
|
||||
}
|
||||
mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs)
|
||||
if mcpCtx != nil {
|
||||
openAIMessages = append(mcpCtx.PromptMessages, openAIMessages...)
|
||||
mcpTools.AppendResourceSuffix(openAIMessages, mcpCtx.ResourceSuffix)
|
||||
}
|
||||
}
|
||||
|
||||
// Resource injection
|
||||
if len(mcpResourceURIs) > 0 {
|
||||
resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions)
|
||||
if discErr == nil {
|
||||
var resourceTexts []string
|
||||
for _, uri := range mcpResourceURIs {
|
||||
content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri)
|
||||
if readErr != nil {
|
||||
xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri)
|
||||
continue
|
||||
}
|
||||
name := uri
|
||||
for _, r := range resources {
|
||||
if r.URI == uri {
|
||||
name = r.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content))
|
||||
}
|
||||
if len(resourceTexts) > 0 && len(openAIMessages) > 0 {
|
||||
lastIdx := len(openAIMessages) - 1
|
||||
suffix := "\n\n" + strings.Join(resourceTexts, "\n\n")
|
||||
switch ct := openAIMessages[lastIdx].Content.(type) {
|
||||
case string:
|
||||
openAIMessages[lastIdx].Content = ct + suffix
|
||||
default:
|
||||
openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix)
|
||||
}
|
||||
xlog.Debug("Anthropic MCP resources injected", "count", len(resourceTexts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool injection
|
||||
if len(mcpServers) > 0 {
|
||||
discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
|
||||
if discErr == nil {
|
||||
mcpToolInfos = discovered
|
||||
for _, ti := range mcpToolInfos {
|
||||
funcs = append(funcs, ti.Function)
|
||||
}
|
||||
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs))
|
||||
} else {
|
||||
xlog.Error("Failed to discover MCP tools", "error", discErr)
|
||||
// Tool injection via executor
|
||||
if mcpExecutor.HasTools() {
|
||||
mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context())
|
||||
if discErr == nil {
|
||||
for _, fn := range mcpFuncs {
|
||||
funcs = append(funcs, fn)
|
||||
}
|
||||
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs))
|
||||
} else {
|
||||
xlog.Error("Failed to discover MCP tools", "error", discErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -177,19 +129,19 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
||||
|
||||
if input.Stream {
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
}
|
||||
|
||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
|
||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
}
|
||||
}
|
||||
|
||||
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
|
||||
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||
mcpMaxIterations := 10
|
||||
if cfg.Agent.MaxIterations > 0 {
|
||||
mcpMaxIterations = cfg.Agent.MaxIterations
|
||||
}
|
||||
hasMCPTools := len(mcpToolInfos) > 0
|
||||
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
|
||||
|
||||
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
||||
// Re-template on each MCP iteration since messages may have changed
|
||||
@@ -227,7 +179,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
if hasMCPTools && shouldUseFn && len(toolCalls) > 0 {
|
||||
var hasMCPCalls bool
|
||||
for _, tc := range toolCalls {
|
||||
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
|
||||
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
|
||||
hasMCPCalls = true
|
||||
break
|
||||
}
|
||||
@@ -257,13 +209,12 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
|
||||
// Execute each MCP tool call and append results
|
||||
for _, tc := range assistantMsg.ToolCalls {
|
||||
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
|
||||
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
|
||||
continue
|
||||
}
|
||||
xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
||||
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
|
||||
c.Request().Context(), mcpToolInfos,
|
||||
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||
toolResult, toolErr := mcpExecutor.ExecuteTool(
|
||||
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||
)
|
||||
if toolErr != nil {
|
||||
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
||||
@@ -290,10 +241,10 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
if shouldUseFn && len(toolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for _, tc := range toolCalls {
|
||||
var inputArgs map[string]interface{}
|
||||
var inputArgs map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
|
||||
xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
|
||||
inputArgs = map[string]interface{}{"raw": tc.Arguments}
|
||||
inputArgs = map[string]any{"raw": tc.Arguments}
|
||||
}
|
||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
@@ -316,9 +267,9 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
|
||||
}
|
||||
for i, fc := range parsed {
|
||||
var inputArgs map[string]interface{}
|
||||
var inputArgs map[string]any
|
||||
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
|
||||
inputArgs = map[string]interface{}{"raw": fc.Arguments}
|
||||
inputArgs = map[string]any{"raw": fc.Arguments}
|
||||
}
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
@@ -365,7 +316,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
||||
}
|
||||
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
@@ -388,7 +339,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if cfg.Agent.MaxIterations > 0 {
|
||||
mcpMaxIterations = cfg.Agent.MaxIterations
|
||||
}
|
||||
hasMCPTools := len(mcpToolInfos) > 0
|
||||
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
|
||||
|
||||
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
||||
// Re-template on MCP iterations
|
||||
@@ -483,7 +434,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "error",
|
||||
Error: &schema.AnthropicError{
|
||||
Type: "api_error",
|
||||
Message: fmt.Sprintf("model inference failed: %v", err),
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Also check chat deltas for tool calls
|
||||
@@ -495,7 +453,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if hasMCPTools && len(collectedToolCalls) > 0 {
|
||||
var hasMCPCalls bool
|
||||
for _, tc := range collectedToolCalls {
|
||||
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
|
||||
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
|
||||
hasMCPCalls = true
|
||||
break
|
||||
}
|
||||
@@ -525,13 +483,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
|
||||
// Execute MCP tool calls
|
||||
for _, tc := range assistantMsg.ToolCalls {
|
||||
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
|
||||
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
|
||||
continue
|
||||
}
|
||||
xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
||||
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
|
||||
c.Request().Context(), mcpToolInfos,
|
||||
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||
toolResult, toolErr := mcpExecutor.ExecuteTool(
|
||||
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||
)
|
||||
if toolErr != nil {
|
||||
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
||||
@@ -686,7 +643,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
case string:
|
||||
openAIMsg.StringContent = content
|
||||
openAIMsg.Content = content
|
||||
case []interface{}:
|
||||
case []any:
|
||||
// Handle array of content blocks
|
||||
var textContent string
|
||||
var stringImages []string
|
||||
@@ -694,7 +651,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
toolCallIndex := 0
|
||||
|
||||
for _, block := range content {
|
||||
if blockMap, ok := block.(map[string]interface{}); ok {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
@@ -703,7 +660,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
}
|
||||
case "image":
|
||||
// Handle image content
|
||||
if source, ok := blockMap["source"].(map[string]interface{}); ok {
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
mediaType, _ := source["media_type"].(string)
|
||||
@@ -718,14 +675,14 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
toolID, _ := blockMap["id"].(string)
|
||||
toolName, _ := blockMap["name"].(string)
|
||||
toolInput := blockMap["input"]
|
||||
|
||||
|
||||
// Serialize input to JSON string
|
||||
inputJSON, err := json.Marshal(toolInput)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to marshal tool input", "error", err)
|
||||
inputJSON = []byte("{}")
|
||||
}
|
||||
|
||||
|
||||
toolCalls = append(toolCalls, schema.ToolCall{
|
||||
Index: toolCallIndex,
|
||||
ID: toolID,
|
||||
@@ -745,16 +702,16 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr != nil {
|
||||
isError = *isErrorPtr
|
||||
}
|
||||
|
||||
|
||||
var resultText string
|
||||
if resultContent, ok := blockMap["content"]; ok {
|
||||
switch rc := resultContent.(type) {
|
||||
case string:
|
||||
resultText = rc
|
||||
case []interface{}:
|
||||
case []any:
|
||||
// Array of content blocks
|
||||
for _, cb := range rc {
|
||||
if cbMap, ok := cb.(map[string]interface{}); ok {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultText += text
|
||||
@@ -764,7 +721,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Add tool result as a tool role message
|
||||
// We need to handle this differently - create a new message
|
||||
if msg.Role == "user" {
|
||||
@@ -781,7 +738,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||
openAIMsg.StringContent = textContent
|
||||
openAIMsg.Content = textContent
|
||||
openAIMsg.StringImages = stringImages
|
||||
|
||||
|
||||
// Add tool calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
openAIMsg.ToolCalls = toolCalls
|
||||
@@ -799,7 +756,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
if len(input.Tools) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
|
||||
var funcs functions.Functions
|
||||
for _, tool := range input.Tools {
|
||||
f := functions.Function{
|
||||
@@ -809,7 +766,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
}
|
||||
funcs = append(funcs, f)
|
||||
}
|
||||
|
||||
|
||||
// Handle tool_choice
|
||||
if input.ToolChoice != nil {
|
||||
switch tc := input.ToolChoice.(type) {
|
||||
@@ -823,7 +780,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
return nil, false
|
||||
}
|
||||
// "auto" is the default - let model decide
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
// Specific tool selection: {"type": "tool", "name": "tool_name"}
|
||||
if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
|
||||
if name, ok := tc["name"].(string); ok {
|
||||
@@ -833,6 +790,6 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package explorer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
|
||||
func Dashboard() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
summary := map[string]any{
|
||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
@@ -61,8 +62,8 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
// order by number of clusters
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return len(results[i].Clusters) > len(results[j].Clusters)
|
||||
slices.SortFunc(results, func(a, b Network) int {
|
||||
return cmp.Compare(len(b.Clusters), len(a.Clusters))
|
||||
})
|
||||
|
||||
return c.JSON(http.StatusOK, results)
|
||||
@@ -73,36 +74,36 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
request := new(AddNetworkRequest)
|
||||
if err := c.Bind(request); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
if request.Token == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token is required"})
|
||||
}
|
||||
|
||||
if request.Name == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Name is required"})
|
||||
}
|
||||
|
||||
if request.Description == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"})
|
||||
}
|
||||
|
||||
// TODO: check if token is valid, otherwise reject
|
||||
// try to decode the token from base64
|
||||
_, err := base64.StdEncoding.DecodeString(request.Token)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"})
|
||||
}
|
||||
|
||||
if _, exists := db.Get(request.Token); exists {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token already exists"})
|
||||
}
|
||||
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "Cannot add token"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
|
||||
return c.JSON(http.StatusOK, map[string]any{"message": "Token added"})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
)
|
||||
|
||||
// getJobService returns the job service for the current user.
|
||||
// Falls back to the global service when no user is authenticated.
|
||||
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService {
|
||||
func getJobService(app *application.Application, c echo.Context) *agentpool.AgentJobService {
|
||||
userID := getUserID(c)
|
||||
if userID == "" {
|
||||
return app.AgentJobService()
|
||||
@@ -54,7 +55,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
|
||||
if err.Error() == "task not found: "+id {
|
||||
if errors.Is(err, agentpool.ErrTaskNotFound) {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
@@ -68,7 +69,7 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := getJobService(app, c).DeleteTask(id); err != nil {
|
||||
if err.Error() == "task not found: "+id {
|
||||
if errors.Is(err, agentpool.ErrTaskNotFound) {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
@@ -244,7 +245,7 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := getJobService(app, c).CancelJob(id); err != nil {
|
||||
if err.Error() == "job not found: "+id {
|
||||
if errors.Is(err, agentpool.ErrJobNotFound) {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
@@ -258,7 +259,7 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := getJobService(app, c).DeleteJob(id); err != nil {
|
||||
if err.Error() == "job not found: "+id {
|
||||
if errors.Is(err, agentpool.ErrJobNotFound) {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
@@ -275,7 +276,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
if c.Request().ContentLength > 0 {
|
||||
if err := c.Bind(¶ms); err != nil {
|
||||
body := make(map[string]interface{})
|
||||
body := make(map[string]any)
|
||||
if err := c.Bind(&body); err == nil {
|
||||
params = make(map[string]string)
|
||||
for k, v := range body {
|
||||
|
||||
@@ -2,6 +2,7 @@ package localai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,8 +11,9 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/services/agents"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
@@ -50,55 +52,105 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Check if this model name is an agent
|
||||
ag := svc.GetAgent(req.Model)
|
||||
if ag == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// This is an agent — handle the request directly
|
||||
// Check if this model name is an agent — try in-process agent first,
|
||||
// fall back to config lookup (covers distributed mode where agents
|
||||
// don't run in-process).
|
||||
messages := parseInputToMessages(req.Input)
|
||||
if len(messages) == 0 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": "invalid_request_error",
|
||||
"message": "no input messages provided",
|
||||
},
|
||||
})
|
||||
userID := effectiveUserID(c)
|
||||
ag := svc.GetAgent(req.Model)
|
||||
if ag == nil && svc.GetAgentConfigForUser(userID, req.Model) == nil {
|
||||
return next(c) // not an agent
|
||||
}
|
||||
|
||||
jobOptions := []coreTypes.JobOption{
|
||||
coreTypes.WithConversationHistory(messages),
|
||||
// Extract the last user message for the executor
|
||||
var userMessage string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
userMessage = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
res := ag.Ask(jobOptions...)
|
||||
var responseText string
|
||||
|
||||
if res == nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": "server_error",
|
||||
"message": "agent request failed or was cancelled",
|
||||
},
|
||||
})
|
||||
}
|
||||
if res.Error != nil {
|
||||
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": "server_error",
|
||||
"message": res.Error.Error(),
|
||||
},
|
||||
if ag != nil {
|
||||
// Local mode: use LocalAGI agent directly
|
||||
jobOptions := []coreTypes.JobOption{
|
||||
coreTypes.WithConversationHistory(messages),
|
||||
}
|
||||
|
||||
res := ag.Ask(jobOptions...)
|
||||
if res == nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": "server_error",
|
||||
"message": "agent request failed or was cancelled",
|
||||
},
|
||||
})
|
||||
}
|
||||
if res.Error != nil {
|
||||
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": "server_error",
|
||||
"message": res.Error.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
responseText = res.Response
|
||||
} else {
|
||||
// Distributed mode: dispatch via NATS + wait for response synchronously
|
||||
var bridge *agents.EventBridge
|
||||
if d := app.Distributed(); d != nil {
|
||||
bridge = d.AgentBridge
|
||||
}
|
||||
if bridge == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Subscribe BEFORE dispatching so we never miss a fast response
|
||||
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
responseCh := make(chan string, 1)
|
||||
sub, err := bridge.SubscribeEvents(req.Model, userID, func(evt agents.AgentEvent) {
|
||||
if evt.EventType == "json_message" && evt.Sender == "agent" {
|
||||
responseCh <- evt.Content
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{"type": "server_error", "message": "failed to subscribe to agent events"},
|
||||
})
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Now dispatch via ChatForUser (publishes to NATS)
|
||||
_, err = svc.ChatForUser(userID, req.Model, userMessage)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||
"error": map[string]string{"type": "server_error", "message": err.Error()},
|
||||
})
|
||||
}
|
||||
|
||||
select {
|
||||
case responseText = <-responseCh:
|
||||
// Got the response
|
||||
case <-ctx.Done():
|
||||
return c.JSON(http.StatusGatewayTimeout, map[string]any{
|
||||
"error": map[string]string{"type": "server_error", "message": "agent response timeout"},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("resp_%s", uuid.New().String())
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"id": id,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"status": "completed",
|
||||
"model": req.Model,
|
||||
"id": id,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"status": "completed",
|
||||
"model": req.Model,
|
||||
"previous_response_id": nil,
|
||||
"output": []any{
|
||||
map[string]any{
|
||||
@@ -109,7 +161,7 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": res.Response,
|
||||
"text": responseText,
|
||||
"annotations": []any{},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
skillsManager "github.com/mudler/LocalAI/core/services/skills"
|
||||
skilldomain "github.com/mudler/skillserver/pkg/domain"
|
||||
)
|
||||
|
||||
@@ -41,27 +42,48 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
|
||||
return out
|
||||
}
|
||||
|
||||
// getSkillManager returns a SkillManager for the request's user.
|
||||
func getSkillManager(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
return svc.SkillManagerForUser(userID)
|
||||
}
|
||||
|
||||
func getSkillManagerEffective(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
return svc.SkillManagerForUser(userID)
|
||||
}
|
||||
|
||||
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
skills, err := svc.ListSkillsForUser(userID)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skills, err := mgr.List()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
svc := app.AgentPoolService()
|
||||
usm := svc.UserServicesManager()
|
||||
if usm != nil {
|
||||
userIDs, _ := usm.ListAllUserIDs()
|
||||
userGroups := map[string]any{}
|
||||
userID := getUserID(c)
|
||||
for _, uid := range userIDs {
|
||||
if uid == userID {
|
||||
continue
|
||||
}
|
||||
userSkills, err := svc.ListSkillsForUser(uid)
|
||||
if err != nil || len(userSkills) == 0 {
|
||||
uidMgr, mgrErr := svc.SkillManagerForUser(uid)
|
||||
if mgrErr != nil {
|
||||
continue
|
||||
}
|
||||
userSkills, listErr := uidMgr.List()
|
||||
if listErr != nil || len(userSkills) == 0 {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
|
||||
@@ -76,25 +98,28 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, skillsToResponses(skills))
|
||||
return c.JSON(http.StatusOK, map[string]any{"skills": skillsToResponses(skills)})
|
||||
}
|
||||
}
|
||||
|
||||
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
cfg := svc.GetSkillsConfigForUser(userID)
|
||||
return c.JSON(http.StatusOK, cfg)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusOK, map[string]string{})
|
||||
}
|
||||
return c.JSON(http.StatusOK, mgr.GetConfig())
|
||||
}
|
||||
}
|
||||
|
||||
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
query := c.QueryParam("q")
|
||||
skills, err := svc.SearchSkillsForUser(userID, query)
|
||||
skills, err := mgr.Search(query)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -104,8 +129,10 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
var payload struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
@@ -118,7 +145,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
skill, err := mgr.Create(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
|
||||
@@ -131,9 +158,11 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
skill, err := svc.GetSkillForUser(userID, c.Param("name"))
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := mgr.Get(c.Param("name"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -143,8 +172,10 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
var payload struct {
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
@@ -156,7 +187,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
skill, err := mgr.Update(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -169,9 +200,11 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil {
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := mgr.Delete(c.Param("name")); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -180,10 +213,12 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
name := c.Param("*")
|
||||
data, err := svc.ExportSkillForUser(userID, name)
|
||||
data, err := mgr.Export(name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -195,8 +230,10 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
||||
@@ -210,7 +247,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.ImportSkillForUser(userID, data)
|
||||
skill, err := mgr.Import(data)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -222,9 +259,11 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name"))
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
resources, skill, err := mgr.ListResources(c.Param("name"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -260,9 +299,11 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*"))
|
||||
mgr, err := getSkillManagerEffective(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
content, info, err := mgr.GetResource(c.Param("name"), c.Param("*"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -281,10 +322,12 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
file, err := c.FormFile("file")
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
file, fileErr := c.FormFile("file")
|
||||
if fileErr != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
|
||||
}
|
||||
path := c.FormValue("path")
|
||||
@@ -300,7 +343,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil {
|
||||
if err := mgr.CreateResource(c.Param("name"), path, data); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"path": path})
|
||||
@@ -309,15 +352,17 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
|
||||
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
var payload struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
||||
if err := mgr.UpdateResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -326,9 +371,11 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
|
||||
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil {
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := mgr.DeleteResource(c.Param("name"), c.Param("*")); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -339,9 +386,11 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
|
||||
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
repos, err := svc.ListGitReposForUser(userID)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repos, err := mgr.ListGitRepos()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -351,15 +400,17 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repo, err := svc.AddGitRepoForUser(userID, payload.URL)
|
||||
repo, err := mgr.AddGitRepo(payload.URL)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -369,8 +420,10 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
@@ -378,7 +431,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled)
|
||||
repo, err := mgr.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -391,9 +444,11 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil {
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := mgr.DeleteGitRepo(c.Param("id")); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -405,9 +460,11 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil {
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := mgr.SyncGitRepo(c.Param("id")); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
|
||||
@@ -416,9 +473,11 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
|
||||
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id"))
|
||||
mgr, err := getSkillManager(c, app)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repo, err := mgr.ToggleGitRepo(c.Param("id"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -4,20 +4,23 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||
agiServices "github.com/mudler/LocalAGI/services"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/agents"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// getUserID extracts the scoped user ID from the request context.
|
||||
@@ -42,25 +45,39 @@ func wantsAllUsers(c echo.Context) bool {
|
||||
}
|
||||
|
||||
// effectiveUserID returns the user ID to scope operations to.
|
||||
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's
|
||||
// resources. Non-admin callers always get their own ID regardless of query params.
|
||||
// SECURITY: Only admins and agent-worker service accounts may supply
|
||||
// ?user_id=<id> to operate on another user's resources. Agent-worker users are
|
||||
// created exclusively server-side during node registration and need to access
|
||||
// collections on behalf of the user whose agent they are executing.
|
||||
// Regular callers always get their own ID regardless of query params.
|
||||
func effectiveUserID(c echo.Context) string {
|
||||
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) {
|
||||
if targetUID := c.QueryParam("user_id"); targetUID != "" && canImpersonateUser(c) {
|
||||
if callerID := getUserID(c); callerID != targetUID {
|
||||
xlog.Info("User impersonation", "caller", callerID, "target", targetUID, "path", c.Path())
|
||||
}
|
||||
return targetUID
|
||||
}
|
||||
return getUserID(c)
|
||||
}
|
||||
|
||||
// canImpersonateUser returns true if the caller is allowed to use ?user_id= to
|
||||
// scope operations to another user. Allowed for admins and agent-worker service
|
||||
// accounts (ProviderAgentWorker is set server-side during node registration and
|
||||
// cannot be self-assigned).
|
||||
func canImpersonateUser(c echo.Context) bool {
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
return user.Role == auth.RoleAdmin || user.Provider == auth.ProviderAgentWorker
|
||||
}
|
||||
|
||||
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
statuses := svc.ListAgentsForUser(userID)
|
||||
agents := make([]string, 0, len(statuses))
|
||||
for name := range statuses {
|
||||
agents = append(agents, name)
|
||||
}
|
||||
sort.Strings(agents)
|
||||
agents := slices.Sorted(maps.Keys(statuses))
|
||||
resp := map[string]any{
|
||||
"agents": agents,
|
||||
"agentCount": len(agents),
|
||||
@@ -111,13 +128,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
ag := svc.GetAgentForUser(userID, name)
|
||||
if ag == nil {
|
||||
|
||||
statuses := svc.ListAgentsForUser(userID)
|
||||
active, exists := statuses[name]
|
||||
if !exists {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"active": !ag.Paused(),
|
||||
})
|
||||
return c.JSON(http.StatusOK, map[string]any{"active": active})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,9 +209,13 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
|
||||
history := svc.GetAgentStatusForUser(userID, name)
|
||||
if history == nil {
|
||||
history = &state.Status{ActionResults: []coreTypes.ActionState{}}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"Name": name,
|
||||
"History": []string{},
|
||||
})
|
||||
}
|
||||
entries := []string{}
|
||||
for i := len(history.Results()) - 1; i >= 0; i-- {
|
||||
@@ -221,10 +242,14 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
|
||||
history, err := svc.GetAgentObservablesForUser(userID, name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if history == nil {
|
||||
history = []json.RawMessage{}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"Name": name,
|
||||
"History": history,
|
||||
@@ -278,26 +303,30 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
manager := svc.GetSSEManagerForUser(userID, name)
|
||||
if manager == nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
return services.HandleSSE(c, manager)
|
||||
}
|
||||
}
|
||||
|
||||
type agentConfigMetaResponse struct {
|
||||
state.AgentConfigMeta
|
||||
OutputsDir string `json:"OutputsDir"`
|
||||
// Try local SSE manager first
|
||||
manager := svc.GetSSEManagerForUser(userID, name)
|
||||
if manager != nil {
|
||||
return agentpool.HandleSSE(c, manager)
|
||||
}
|
||||
|
||||
// Fall back to distributed EventBridge SSE
|
||||
var bridge *agents.EventBridge
|
||||
if d := app.Distributed(); d != nil {
|
||||
bridge = d.AgentBridge
|
||||
}
|
||||
if bridge != nil {
|
||||
return bridge.HandleSSE(c, name, userID)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
}
|
||||
|
||||
func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
return c.JSON(http.StatusOK, agentConfigMetaResponse{
|
||||
AgentConfigMeta: svc.GetConfigMeta(),
|
||||
OutputsDir: svc.OutputsDir(),
|
||||
})
|
||||
return c.JSON(http.StatusOK, svc.GetConfigMetaResult())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user