mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-20 06:35:41 -04:00
Compare commits
6 Commits
v4.1.3
...
feat/fine-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8997ff6042 | ||
|
|
f1223b45b2 | ||
|
|
fa8b1a8673 | ||
|
|
3451dbdccd | ||
|
|
7b8afc9609 | ||
|
|
ae4b758a5a |
@@ -49,4 +49,3 @@ The project documentation is located in `docs/content`. When adding new features
|
||||
- **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it.
|
||||
- **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`.
|
||||
- **Examples**: Providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly.
|
||||
- **Shortcodes**: Use `{{% notice note %}}`, `{{% notice tip %}}`, or `{{% notice warning %}}` for callout boxes. Do **not** use `{{% alert %}}` — that shortcode does not exist in this project's Hugo theme and will break the docs build.
|
||||
|
||||
3
.github/gallery-agent/agent.go
vendored
3
.github/gallery-agent/agent.go
vendored
@@ -133,7 +133,6 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||
result, err := cogito.ExecuteTools(llm, fragment,
|
||||
cogito.WithIterations(3),
|
||||
cogito.WithMaxAttempts(3),
|
||||
cogito.DisableSinkState,
|
||||
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -407,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string {
|
||||
}
|
||||
|
||||
// Parse the response to get avatar URL
|
||||
var userInfo map[string]any
|
||||
var userInfo map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
15
.github/gallery-agent/gallery.go
vendored
15
.github/gallery-agent/gallery.go
vendored
@@ -79,20 +79,7 @@ func generateYAMLEntry(model ProcessedModel, quantization string) string {
|
||||
description = cleanTextContent(description)
|
||||
formattedDescription := formatTextContent(description)
|
||||
|
||||
// Strip name and description from config file since they are
|
||||
// already present at the gallery entry level and should not
|
||||
// appear under overrides.
|
||||
configFileContent := modelConfig.ConfigFile
|
||||
var cfgMap map[string]any
|
||||
if err := yaml.Unmarshal([]byte(configFileContent), &cfgMap); err == nil {
|
||||
delete(cfgMap, "name")
|
||||
delete(cfgMap, "description")
|
||||
if cleaned, err := yaml.Marshal(cfgMap); err == nil {
|
||||
configFileContent = string(cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
configFile := formatTextContent(configFileContent)
|
||||
configFile := formatTextContent(modelConfig.ConfigFile)
|
||||
|
||||
filesYAML, _ := yaml.Marshal(modelConfig.Files)
|
||||
|
||||
|
||||
40
.github/gallery-agent/testing.go
vendored
40
.github/gallery-agent/testing.go
vendored
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
|
||||
generator := NewSyntheticDataGenerator()
|
||||
|
||||
// Generate a random number of synthetic models (1-3)
|
||||
numModels := generator.rand.IntN(3) + 1
|
||||
numModels := generator.rand.Intn(3) + 1
|
||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||
|
||||
var models []ProcessedModel
|
||||
for range numModels {
|
||||
for i := 0; i < numModels; i++ {
|
||||
model := generator.GenerateProcessedModel()
|
||||
models = append(models, model)
|
||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
|
||||
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||
return &SyntheticDataGenerator{
|
||||
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||
fileTypes := []string{"model", "readme", "other"}
|
||||
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
|
||||
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
||||
|
||||
var path string
|
||||
var isReadme bool
|
||||
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
|
||||
|
||||
return ProcessedModelFile{
|
||||
Path: path,
|
||||
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
|
||||
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
||||
SHA256: g.randomSHA256(),
|
||||
IsReadme: isReadme,
|
||||
FileType: fileType,
|
||||
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||
|
||||
author := authors[g.rand.IntN(len(authors))]
|
||||
modelName := modelNames[g.rand.IntN(len(modelNames))]
|
||||
author := authors[g.rand.Intn(len(authors))]
|
||||
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
||||
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||
|
||||
// Generate files
|
||||
numFiles := g.rand.IntN(5) + 2 // 2-6 files
|
||||
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
||||
files := make([]ProcessedModelFile, numFiles)
|
||||
|
||||
// Ensure at least one model file and one readme
|
||||
hasModelFile := false
|
||||
hasReadme := false
|
||||
|
||||
for i := range numFiles {
|
||||
for i := 0; i < numFiles; i++ {
|
||||
files[i] = g.GenerateProcessedModelFile()
|
||||
if files[i].FileType == "model" {
|
||||
hasModelFile = true
|
||||
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
|
||||
// Generate sample metadata
|
||||
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
||||
license := licenses[g.rand.IntN(len(licenses))]
|
||||
license := licenses[g.rand.Intn(len(licenses))]
|
||||
|
||||
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
||||
numTags := g.rand.IntN(4) + 3 // 3-6 tags
|
||||
numTags := g.rand.Intn(4) + 3 // 3-6 tags
|
||||
tags := make([]string, numTags)
|
||||
for i := range numTags {
|
||||
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
|
||||
for i := 0; i < numTags; i++ {
|
||||
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
|
||||
}
|
||||
// Remove duplicates
|
||||
tags = g.removeDuplicates(tags)
|
||||
|
||||
// Optionally include icon (50% chance)
|
||||
icon := ""
|
||||
if g.rand.IntN(2) == 0 {
|
||||
if g.rand.Intn(2) == 0 {
|
||||
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
||||
}
|
||||
|
||||
return ProcessedModel{
|
||||
ModelID: modelID,
|
||||
Author: author,
|
||||
Downloads: g.rand.IntN(1000000) + 1000,
|
||||
Downloads: g.rand.Intn(1000000) + 1000,
|
||||
LastModified: g.randomDate(),
|
||||
Files: files,
|
||||
PreferredModelFile: preferredModelFile,
|
||||
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
|
||||
const charset = "0123456789abcdef"
|
||||
b := make([]byte, 64)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (g *SyntheticDataGenerator) randomDate() string {
|
||||
now := time.Now()
|
||||
daysAgo := g.rand.IntN(365) // Random date within last year
|
||||
daysAgo := g.rand.Intn(365) // Random date within last year
|
||||
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||
}
|
||||
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
|
||||
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||
}
|
||||
|
||||
return templates[g.rand.IntN(len(templates))]
|
||||
return templates[g.rand.Intn(len(templates))]
|
||||
}
|
||||
|
||||
16
.github/workflows/backend.yml
vendored
16
.github/workflows/backend.yml
vendored
@@ -131,19 +131,6 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-llama-cpp-quantization'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "llama-cpp-quantization"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2425,9 +2412,6 @@ jobs:
|
||||
tag-suffix: "-metal-darwin-arm64-local-store"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "llama-cpp-quantization"
|
||||
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
|
||||
build-type: "mps"
|
||||
with:
|
||||
backend: ${{ matrix.backend }}
|
||||
build-type: ${{ matrix.build-type }}
|
||||
|
||||
48
.github/workflows/bump-inference-defaults.yml
vendored
48
.github/workflows/bump-inference-defaults.yml
vendored
@@ -1,48 +0,0 @@
|
||||
name: Bump inference defaults
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 06:00 UTC
|
||||
- cron: '0 6 * * *'
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
bump:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Re-fetch inference defaults
|
||||
run: make generate-force
|
||||
|
||||
- name: Check for changes
|
||||
id: diff
|
||||
run: |
|
||||
if git diff --quiet core/config/inference_defaults.json; then
|
||||
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.diff.outputs.changed == 'true'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
commit-message: "chore: bump inference defaults from unsloth"
|
||||
title: "chore: bump inference defaults from unsloth"
|
||||
body: |
|
||||
Auto-generated update of `core/config/inference_defaults.json` from
|
||||
[unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json).
|
||||
|
||||
This PR was created automatically by the `bump-inference-defaults` workflow.
|
||||
branch: chore/bump-inference-defaults
|
||||
delete-branch: true
|
||||
labels: automated
|
||||
2
.github/workflows/gallery-agent.yaml
vendored
2
.github/workflows/gallery-agent.yaml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
- name: Run gallery agent
|
||||
env:
|
||||
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
OPENAI_MODEL: Qwen3.5-2B-GGUF
|
||||
OPENAI_MODE: Qwen3.5-2B-GGUF
|
||||
OPENAI_BASE_URL: "http://localhost:8080"
|
||||
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||
|
||||
75
.github/workflows/gh-pages.yml
vendored
75
.github/workflows/gh-pages.yml
vendored
@@ -1,75 +0,0 @@
|
||||
name: Deploy docs to GitHub Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- 'docs/**'
|
||||
- 'gallery/**'
|
||||
- 'images/**'
|
||||
- '.github/ci/modelslist.go'
|
||||
- '.github/workflows/gh-pages.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
concurrency:
|
||||
group: pages
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
HUGO_VERSION: "0.146.3"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # needed for enableGitInfo
|
||||
submodules: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: false
|
||||
|
||||
- name: Setup Hugo
|
||||
uses: peaceiris/actions-hugo@v3
|
||||
with:
|
||||
hugo-version: ${{ env.HUGO_VERSION }}
|
||||
extended: true
|
||||
|
||||
- name: Setup Pages
|
||||
id: pages
|
||||
uses: actions/configure-pages@v6
|
||||
|
||||
- name: Generate gallery
|
||||
run: go run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html
|
||||
|
||||
- name: Build site
|
||||
working-directory: docs
|
||||
run: |
|
||||
mkdir -p layouts/_default
|
||||
hugo --minify --baseURL "${{ steps.pages.outputs.base_url }}/"
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v4
|
||||
with:
|
||||
path: docs/public
|
||||
|
||||
deploy:
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v5
|
||||
84
.github/workflows/test-extra.yml
vendored
84
.github/workflows/test-extra.yml
vendored
@@ -14,37 +14,6 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
detect-changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
run-all: ${{ steps.detect.outputs.run-all }}
|
||||
transformers: ${{ steps.detect.outputs.transformers }}
|
||||
rerankers: ${{ steps.detect.outputs.rerankers }}
|
||||
diffusers: ${{ steps.detect.outputs.diffusers }}
|
||||
coqui: ${{ steps.detect.outputs.coqui }}
|
||||
moonshine: ${{ steps.detect.outputs.moonshine }}
|
||||
pocket-tts: ${{ steps.detect.outputs.pocket-tts }}
|
||||
qwen-tts: ${{ steps.detect.outputs.qwen-tts }}
|
||||
qwen-asr: ${{ steps.detect.outputs.qwen-asr }}
|
||||
nemo: ${{ steps.detect.outputs.nemo }}
|
||||
voxcpm: ${{ steps.detect.outputs.voxcpm }}
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
- name: Install dependencies
|
||||
run: bun add js-yaml @octokit/core
|
||||
- name: Detect changed backends
|
||||
id: detect
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_EVENT_PATH: ${{ github.event_path }}
|
||||
run: bun run scripts/changed-backends.js
|
||||
|
||||
# Requires CUDA
|
||||
# tests-chatterbox-tts:
|
||||
# runs-on: ubuntu-latest
|
||||
@@ -68,8 +37,6 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox
|
||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox test
|
||||
tests-transformers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.transformers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -91,8 +58,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/transformers
|
||||
make --jobs=5 --output-sync=target -C backend/python/transformers test
|
||||
tests-rerankers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.rerankers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -115,8 +80,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/rerankers test
|
||||
|
||||
tests-diffusers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.diffusers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -266,8 +229,6 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/vllm test
|
||||
|
||||
tests-coqui:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.coqui == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -287,8 +248,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/coqui
|
||||
make --jobs=5 --output-sync=target -C backend/python/coqui test
|
||||
tests-moonshine:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.moonshine == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -308,8 +267,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
||||
tests-pocket-tts:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.pocket-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -329,8 +286,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
|
||||
tests-qwen-tts:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.qwen-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -372,8 +327,6 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech
|
||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech test
|
||||
tests-qwen-asr:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.qwen-asr == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -393,8 +346,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr test
|
||||
tests-nemo:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.nemo == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -414,8 +365,6 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/nemo
|
||||
make --jobs=5 --output-sync=target -C backend/python/nemo test
|
||||
tests-voxcpm:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.voxcpm == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -434,38 +383,7 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||
tests-llama-cpp-quantization:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.llama-cpp-quantization == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl git python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Build llama-quantize from llama.cpp
|
||||
run: |
|
||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp
|
||||
cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF
|
||||
cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc)
|
||||
sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/
|
||||
- name: Install backend
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization
|
||||
- name: Test llama-cpp-quantization
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test
|
||||
tests-acestep-cpp:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.acestep-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -496,8 +414,6 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test
|
||||
tests-voxtral:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.voxtral == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
go-version: ['1.25.x']
|
||||
steps:
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
@@ -179,7 +179,7 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
go-version: ['1.25.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
|
||||
@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
|
||||
# The requirements-core target is common to all images. A dependency should not be placed in requirements-core unless every single build will use it.
|
||||
FROM requirements-drivers AS build-requirements
|
||||
|
||||
ARG GO_VERSION=1.26.0
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG TARGETARCH
|
||||
@@ -319,6 +319,7 @@ COPY ./.git ./.git
|
||||
# Some of the Go backends use libs from the main src; we could further optimize the caching by building the CPP backends before this point
|
||||
COPY ./pkg/grpc ./pkg/grpc
|
||||
COPY ./pkg/utils ./pkg/utils
|
||||
COPY ./pkg/langchain ./pkg/langchain
|
||||
|
||||
RUN ls -l ./
|
||||
RUN make protogen-go
|
||||
|
||||
18
Makefile
18
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui
|
||||
|
||||
## Build:
|
||||
|
||||
build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project
|
||||
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project
|
||||
$(info ${GREEN}I local-ai build info:${RESET})
|
||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
||||
@@ -398,16 +398,6 @@ protogen-go: protoc install-go-tools
|
||||
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
backend/backend.proto
|
||||
|
||||
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: generate
|
||||
generate: core/config/inference_defaults.json ## Ensure inference defaults exist
|
||||
|
||||
.PHONY: generate-force
|
||||
generate-force: ## Re-fetch inference defaults from unsloth (always)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: protogen-go-clean
|
||||
protogen-go-clean:
|
||||
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
|
||||
@@ -585,7 +575,6 @@ BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -644,13 +633,12 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
361
README.md
361
README.md
@@ -5,17 +5,35 @@
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/go-skynet/LocalAI/fork" target="blank">
|
||||
<img src="https://img.shields.io/github/forks/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI forks"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
|
||||
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/pulls" target="blank">
|
||||
<img src="https://img.shields.io/github/issues-pr/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI pull-requests"/>
|
||||
</a>
|
||||
<a href='https://github.com/go-skynet/LocalAI/releases'>
|
||||
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE" target="blank">
|
||||
<img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://hub.docker.com/r/localai/localai" target="blank">
|
||||
<img src="https://img.shields.io/badge/dockerhub-images-important.svg?logo=Docker" alt="LocalAI Docker hub"/>
|
||||
</a>
|
||||
<a href="https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest" target="blank">
|
||||
<img src="https://img.shields.io/badge/quay.io-images-important.svg?" alt="LocalAI Quay.io"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://twitter.com/LocalAI_API" target="blank">
|
||||
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/>
|
||||
@@ -29,183 +47,310 @@
|
||||
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
||||
>
|
||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
|
||||
[Telegram bot](https://t.me/localaiofficial_bot)
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **35+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
[tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml) [release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml) [image](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml) [bump deps](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml) [Artifact Hub](https://artifacthub.io/packages/search?repo=localai)
|
||||
|
||||
Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
<p align="center">
|
||||
<a href="https://github.com/mudler/LocalAI-examples" target="blank">
|
||||
<img src="https://img.shields.io/badge/📦_Examples_Repository-Browse_Ready--to--Run_Examples-blue?style=for-the-badge" alt="LocalAI Examples Repository"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
||||
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI acts as a drop-in replacement REST API that's compatible with the OpenAI (ElevenLabs, Anthropic...) API specifications for local AI inferencing. It allows you to run LLMs and generate images, audio, and more, locally or on-prem with consumer-grade hardware, supporting multiple model families. It does not require a GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
|
||||
## Guided tour
|
||||
## Screenshots / Video
|
||||
|
||||
### Chat, Model gallery
|
||||
|
||||
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||
|
||||
<details>
|
||||
|
||||
<summary>
|
||||
Click to see more!
|
||||
</summary>
|
||||
|
||||
#### User and auth
|
||||
|
||||
https://github.com/user-attachments/assets/228fa9ad-81a3-4d43-bfb9-31557e14a36c
|
||||
|
||||
#### Agents
|
||||
### Agents
|
||||
|
||||
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||
|
||||
#### Usage metrics per user
|
||||
### YouTube video
|
||||
|
||||
https://github.com/user-attachments/assets/cbb03379-23b4-4e3d-bd26-d152f057007f
|
||||
<h1 align="center">
|
||||
<br>
|
||||
<a href="https://www.youtube.com/watch?v=PDqYhB9nNHA" target="_blank"> <img width="300" src="https://img.youtube.com/vi/PDqYhB9nNHA/0.jpg"> </a><br>
|
||||
<br>
|
||||
</h1>
|
||||
|
||||
#### Fine-tuning and Quantization
|
||||
## 💻 Quickstart
|
||||
|
||||
https://github.com/user-attachments/assets/5ba4ace9-d3df-4795-b7d4-b0b404ea71ee
|
||||
|
||||
#### WebRTC
|
||||
|
||||
https://github.com/user-attachments/assets/ed88e34c-fed3-4b83-8a67-4716a9feeb7b
|
||||
|
||||
</details>
|
||||
|
||||
## Quickstart
|
||||
|
||||
### macOS
|
||||
### macOS Download:
|
||||
|
||||
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
||||
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
|
||||
</a>
|
||||
|
||||
> **Note:** The DMG is not signed by Apple. After installing, run: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`. See [#6268](https://github.com/mudler/LocalAI/issues/6268) for details.
|
||||
> Note: the DMGs are not signed by Apple and are flagged as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround; the fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
|
||||
> Install the DMG and paste this command into the terminal: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`
|
||||
|
||||
### Containers (Docker, podman, ...)
|
||||
|
||||
> Already ran LocalAI before? Use `docker start -i local-ai` to restart an existing container.
|
||||
> **💡 Docker Run vs Docker Start**
|
||||
>
|
||||
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
|
||||
> - `docker start` starts an existing container that was previously created with `docker run`.
|
||||
>
|
||||
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
||||
|
||||
#### CPU only:
|
||||
#### CPU only image:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
|
||||
```
|
||||
|
||||
#### NVIDIA GPU:
|
||||
#### NVIDIA GPU Images:
|
||||
|
||||
```bash
|
||||
# CUDA 13
|
||||
# CUDA 13.0
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
|
||||
|
||||
# CUDA 12
|
||||
# CUDA 12.0
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
|
||||
|
||||
# NVIDIA Jetson ARM64 (CUDA 12, for AGX Orin and similar)
|
||||
# NVIDIA Jetson (L4T) ARM64
|
||||
# CUDA 12 (for Nvidia AGX Orin and similar platforms)
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
|
||||
|
||||
# NVIDIA Jetson ARM64 (CUDA 13, for DGX Spark)
|
||||
# CUDA 13 (for Nvidia DGX Spark)
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
|
||||
```
|
||||
|
||||
#### AMD GPU (ROCm):
|
||||
#### AMD GPU Images (ROCm):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
|
||||
```
|
||||
|
||||
#### Intel GPU (oneAPI):
|
||||
#### Intel GPU Images (oneAPI):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
||||
```
|
||||
|
||||
#### Vulkan GPU:
|
||||
#### Vulkan GPU Images:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
|
||||
```
|
||||
|
||||
### Loading models
|
||||
To load models:
|
||||
|
||||
```bash
|
||||
# From the model gallery (see available models with `local-ai models list` or at https://models.localai.io)
|
||||
# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io)
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
# From Huggingface
|
||||
# Start LocalAI with the phi-2 model directly from huggingface
|
||||
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
|
||||
# From the Ollama OCI registry
|
||||
# Install and run a model from the Ollama OCI registry
|
||||
local-ai run ollama://gemma:2b
|
||||
# From a YAML config
|
||||
# Run a model from a configuration file
|
||||
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
# From a standard OCI registry (e.g., Docker Hub)
|
||||
# Install and run a model from a standard OCI registry (e.g., Docker Hub)
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
|
||||
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html). If you are interested in our roadmap items and future enhancements, see the [issues labeled as Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap).
|
||||
|
||||
## Latest News
|
||||
## 📰 Latest project news
|
||||
- March 2026: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- February 2026: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- January 2026: **LocalAI 3.10.0** - Major release with Anthropic API support, Open Responses API for stateful agents, video & image generation suite (LTX-2), unified GPU backends, tool streaming & XML parsing, system-aware backend gallery, crash fixes for AVX-only CPUs and AMD VRAM reporting, request tracing, and new backends: **Moonshine** (ultra-fast transcription), **Pocket-TTS** (lightweight TTS). Vulkan arm64 builds now available. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0).
|
||||
- December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUs (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||
- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
||||
- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
|
||||
- September 2025: New Launcher application for macOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements, and Python-based backends now ship portable Python environments.
|
||||
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips (with the `development` suffix in the gallery): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
|
||||
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
||||
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||
- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
|
||||
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), support for Gemma, SmolVLM, and more multimodal models (available in the gallery).
|
||||
- May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)
|
||||
- Apr 2025: Rebrand, WebUI enhancements
|
||||
- Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack.
|
||||
- Apr 2025: WebUI overhaul
|
||||
- Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OuteTTS, faster-whisper), Nvidia L4T images
|
||||
- Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603
|
||||
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
|
||||
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
|
||||
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
|
||||
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
|
||||
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
|
||||
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
|
||||
- May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 👉 Docs https://localai.io/features/distribute/
|
||||
- May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324
|
||||
- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
|
||||
|
||||
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||
- **November 2025**: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245), [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
||||
- **October 2025**: [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support for agentic capabilities
|
||||
- **September 2025**: New Launcher for macOS and Linux, extended backend support for Mac and Nvidia L4T, MLX-Audio, WAN 2.2
|
||||
- **August 2025**: MLX, MLX-VLM, Diffusers, llama.cpp now supported on Apple Silicon
|
||||
- **July 2025**: All backends migrated outside the main binary — [lightweight, modular architecture](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||
Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
||||
|
||||
For older news and full release notes, see [GitHub Releases](https://github.com/mudler/LocalAI/releases) and the [News page](https://localai.io/basics/news/).
|
||||
## 🚀 [Features](https://localai.io/features/)
|
||||
|
||||
## Features
|
||||
- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven.
|
||||
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
|
||||
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||
- 🎨 [Image generation](https://localai.io/features/image-generation)
|
||||
- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
|
||||
- ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
|
||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- 🖼️ [Download Models directly from Huggingface](https://localai.io/models/)
|
||||
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
||||
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
||||
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
|
||||
- 🆕🤖 [Built-in Agents](https://localai.io/features/agents/) - Autonomous AI agents with tool use, knowledge base (RAG), skills, SSE streaming, import/export, and [Agent Hub](https://agenthub.localai.io) — powered by [LocalAGI](https://github.com/mudler/LocalAGI)
|
||||
- 🔊 Voice activity detection (Silero-VAD support)
|
||||
- 🌍 Integrated WebUI!
|
||||
|
||||
- [Text generation](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [and more](https://localai.io/model-compatibility/))
|
||||
- [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||
- [Image generation](https://localai.io/features/image-generation)
|
||||
- [OpenAI-compatible tools API](https://localai.io/features/openai-functions/)
|
||||
- [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||
- [Embeddings generation](https://localai.io/features/embeddings/)
|
||||
- [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- [Download models from Huggingface](https://localai.io/models/)
|
||||
- [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- [Object Detection](https://localai.io/features/object-detection/)
|
||||
- [Reranker API](https://localai.io/features/reranker/)
|
||||
- [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
|
||||
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
||||
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
||||
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
||||
- Voice Activity Detection (Silero-VAD)
|
||||
- Integrated WebUI
|
||||
## 🧩 Supported Backends & Acceleration
|
||||
|
||||
## Supported Backends & Acceleration
|
||||
LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
|
||||
|
||||
LocalAI supports **35+ backends** including llama.cpp, vLLM, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
|
||||
### Text Generation & Language Models
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU |
|
||||
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel |
|
||||
| **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
|
||||
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
|
||||
| **vLLM Omni** | Multimodal vLLM with vision and audio | CUDA 12/13, ROCm, Intel |
|
||||
|
||||
See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
|
||||
### Audio & Speech Processing
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU |
|
||||
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **moonshine** | Ultra-fast transcription engine for low-end devices | CUDA 12/13, Metal, CPU |
|
||||
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **chatterbox** | Production-grade TTS | CUDA 12/13, CPU |
|
||||
| **piper** | Fast neural TTS system | CPU |
|
||||
| **kitten-tts** | Kitten TTS models | CPU |
|
||||
| **silero-vad** | Voice Activity Detection | CPU |
|
||||
| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
|
||||
| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **nemo** | NVIDIA NeMo framework for speech models | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **outetts** | OuteTTS with voice cloning | CUDA 12/13, CPU |
|
||||
| **faster-qwen3-tts** | Faster Qwen3 TTS | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **qwen-asr** | Qwen ASR speech recognition | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **voxcpm** | VoxCPM speech understanding | CUDA 12/13, Metal, CPU |
|
||||
| **whisperx** | Enhanced Whisper transcription | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **ace-step** | Music generation from text descriptions, lyrics, or audio samples | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
||||
|
||||
## Resources
|
||||
### Image & Video Generation
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU |
|
||||
| **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
||||
|
||||
- [Documentation](https://localai.io/)
|
||||
- [LLM fine-tuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
||||
- [Build from source](https://localai.io/basics/build/)
|
||||
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||
### Specialized AI Tasks
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU |
|
||||
| **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **local-store** | Vector database | CPU |
|
||||
| **huggingface** | HuggingFace API integration | API-based |
|
||||
|
||||
## Autonomous Development Team
|
||||
### Hardware Acceleration Matrix
|
||||
|
||||
LocalAI is maintained with the help of a team of autonomous AI agents led by an AI Scrum Master.
|
||||
| Acceleration Type | Supported Backends | Hardware Support |
|
||||
|-------------------|-------------------|------------------|
|
||||
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, neutts, vibevoice, pocket-tts, qwen-tts, ace-step | AMD Graphics |
|
||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, coqui, kokoro, vibevoice, pocket-tts, qwen-tts, ace-step | Intel Arc, Intel iGPUs |
|
||||
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, moonshine, ace-step | Apple M1/M2/M3+ |
|
||||
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
||||
| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr, ace-step | ARM64 embedded AI (AGX Orin, etc.) |
|
||||
| **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) |
|
||||
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
|
||||
|
||||
- **Live Reports**: [reports.localai.io](http://reports.localai.io)
|
||||
- **Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
||||
- **Blog Post**: [Learn about the experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
### 🔗 Community and integrations
|
||||
|
||||
Build and deploy custom containers:
|
||||
- https://github.com/sozercan/aikit
|
||||
|
||||
WebUIs:
|
||||
- https://github.com/Jirubizu/localai-admin
|
||||
- https://github.com/go-skynet/LocalAI-frontend
|
||||
- QA-Pilot (an interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repositories) https://github.com/reid41/QA-Pilot
|
||||
|
||||
Agentic Libraries:
|
||||
- https://github.com/mudler/cogito
|
||||
|
||||
MCPs:
|
||||
- https://github.com/mudler/MCPs
|
||||
|
||||
OS Assistant:
|
||||
|
||||
- https://github.com/mudler/Keygeist - Keygeist is an AI-powered keyboard operator that listens for key combinations and responds with AI-generated text typed directly into your Linux box.
|
||||
|
||||
Model galleries:
|
||||
- https://github.com/go-skynet/model-gallery
|
||||
|
||||
Voice:
|
||||
- https://github.com/richiejp/VoxInput
|
||||
|
||||
Other:
|
||||
- Helm chart https://github.com/go-skynet/helm-charts
|
||||
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
|
||||
- Langchain: https://python.langchain.com/docs/integrations/providers/localai/
|
||||
- Terminal utility https://github.com/djcopley/ShellOracle
|
||||
- Local Smart assistant https://github.com/mudler/LocalAGI
|
||||
- Home Assistant https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-llmvision / https://github.com/loryanstrant/HA-LocalAI-Monitor
|
||||
- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
|
||||
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
||||
- Shell-Pilot (interact with LLMs using LocalAI models via pure shell scripts on your Linux or macOS system) https://github.com/reid41/shell-pilot
|
||||
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
||||
- Another Telegram Bot https://github.com/JackBekket/Hellper
|
||||
- Auto-documentation https://github.com/JackBekket/Reflexia
|
||||
- GitHub bot that answers issues, with code and documentation as context https://github.com/JackBekket/GitHelper
|
||||
- Github Actions: https://github.com/marketplace/actions/start-localai
|
||||
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
||||
|
||||
|
||||
### 🔗 Resources
|
||||
|
||||
- [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
||||
- [How to build locally](https://localai.io/basics/build/index.html)
|
||||
- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
|
||||
- [Projects integrating LocalAI](https://localai.io/docs/integrations/)
|
||||
- [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community)
|
||||
|
||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
||||
|
||||
- 🆕 [LocalAI Autonomous Dev Team Blog Post](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
|
||||
- [Run Visual Studio Code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/)
|
||||
- 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/)
|
||||
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/)
|
||||
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
|
||||
- [Create a Slack bot for teams and OSS projects that answers documentation questions](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
||||
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
|
||||
|
||||
|
||||
## 🤖 Autonomous Development Team
|
||||
|
||||
LocalAI is now maintained (for small tasks!) with the help of a full team of autonomous AI agents led by an AI Scrum Master! This experiment demonstrates how open source projects can leverage AI agents for sustainable, long-term maintenance.
|
||||
|
||||
- **📊 Live Reports**: [Automatically generated reports](http://reports.localai.io)
|
||||
- **📋 Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
||||
- **📝 Blog Post**: [Learn about the autonomous dev team experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -221,7 +366,7 @@ If you utilize this repository, data in a downstream project, please consider ci
|
||||
howpublished = {\url{https://github.com/go-skynet/LocalAI}},
|
||||
```
|
||||
|
||||
## Sponsors
|
||||
## ❤️ Sponsors
|
||||
|
||||
> Do you find LocalAI useful?
|
||||
|
||||
@@ -240,19 +385,19 @@ A huge thank you to our generous sponsors who support this project covering CI e
|
||||
|
||||
### Individual sponsors
|
||||
|
||||
A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||
A special thanks to the individual sponsors who contributed to the project; a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). A special shout-out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||
|
||||
## Star history
|
||||
## 🌟 Star history
|
||||
|
||||
[Star history chart](https://star-history.com/#go-skynet/LocalAI&Date)
|
||||
|
||||
## License
|
||||
## 📖 License
|
||||
|
||||
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
|
||||
|
||||
MIT - Author Ettore Di Giacinto <mudler@localai.io>
|
||||
|
||||
## Acknowledgements
|
||||
## 🙇 Acknowledgements
|
||||
|
||||
LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
|
||||
|
||||
@@ -265,9 +410,9 @@ LocalAI couldn't have been built without the help of great software already avai
|
||||
- https://github.com/rhasspy/piper
|
||||
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
|
||||
|
||||
## Contributors
|
||||
## 🤗 Contributors
|
||||
|
||||
This is a community project, a special thanks to our contributors!
|
||||
This is a community project, a special thanks to our contributors! 🤗
|
||||
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
||||
</a>
|
||||
|
||||
@@ -46,12 +46,6 @@ service Backend {
|
||||
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||
|
||||
// Quantization RPCs
|
||||
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -179,7 +173,6 @@ message PredictOptions {
|
||||
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
||||
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
||||
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
||||
float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled)
|
||||
}
|
||||
|
||||
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
||||
@@ -486,7 +479,7 @@ message ToolFormatMarkers {
|
||||
string id_field = 16; // e.g., "id"
|
||||
bool fun_name_is_key = 17;
|
||||
bool tools_array_wrapped = 18;
|
||||
reserved 19;
|
||||
bool uses_python_dicts = 19;
|
||||
|
||||
// Reasoning markers
|
||||
string reasoning_start = 20; // e.g., "<think>"
|
||||
@@ -644,37 +637,3 @@ message ExportModelRequest {
|
||||
string model = 5; // base model name (for merge operations)
|
||||
map<string, string> extra_options = 6;
|
||||
}
|
||||
|
||||
// Quantization messages
|
||||
|
||||
message QuantizationRequest {
|
||||
string model = 1; // HF model name or local path
|
||||
string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||
string output_dir = 3; // where to write output files
|
||||
string job_id = 4; // client-assigned job ID
|
||||
map<string, string> extra_options = 5; // hf_token, custom flags, etc.
|
||||
}
|
||||
|
||||
message QuantizationJobResult {
|
||||
string job_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
message QuantizationProgressRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
message QuantizationProgressUpdate {
|
||||
string job_id = 1;
|
||||
float progress_percent = 2;
|
||||
string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped
|
||||
string message = 4;
|
||||
string output_file = 5; // set when completed — path to the output GGUF file
|
||||
map<string, float> extra_metrics = 6; // e.g. file_size_mb, compression_ratio
|
||||
}
|
||||
|
||||
message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=d0a6dfeb28a09831d904fc4d910ddb740da82834
|
||||
LLAMA_VERSION?=5744d7ec430e2f875a393770195fda530560773f
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -22,10 +22,8 @@
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
#include <regex>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <signal.h>
|
||||
#include <thread>
|
||||
@@ -39,43 +37,6 @@ using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
// gRPC bearer token auth for distributed mode.
|
||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||
|
||||
// Cached auth token — empty means auth is disabled.
|
||||
static std::string g_grpc_auth_token;
|
||||
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns OK when auth is disabled or the token matches.
|
||||
static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||
if (g_grpc_auth_token.empty()) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
auto metadata = context->client_metadata();
|
||||
auto it = metadata.find("authorization");
|
||||
if (it != metadata.end()) {
|
||||
std::string expected = "Bearer " + g_grpc_auth_token;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
if (expected.size() == got.size() &&
|
||||
ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
// END LocalAI
|
||||
|
||||
|
||||
@@ -175,7 +136,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["mirostat_eta"] = predict->mirostateta();
|
||||
data["n_keep"] = predict->nkeep();
|
||||
data["seed"] = predict->seed();
|
||||
data["min_p"] = predict->minp();
|
||||
|
||||
|
||||
std::string grammar_str = predict->grammar();
|
||||
@@ -284,12 +244,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["ignore_eos"] = predict->ignoreeos();
|
||||
data["embeddings"] = predict->embeddings();
|
||||
|
||||
// Speculative decoding per-request overrides
|
||||
// NDraft maps to speculative.n_max (maximum draft tokens per speculation step)
|
||||
if (predict->ndraft() > 0) {
|
||||
data["speculative.n_max"] = predict->ndraft();
|
||||
}
|
||||
|
||||
// Add the correlationid to json data
|
||||
data["correlation_id"] = predict->correlationid();
|
||||
|
||||
@@ -408,16 +362,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->mmproj().empty()) {
|
||||
params.mmproj.path = request->mmproj();
|
||||
}
|
||||
|
||||
// Draft model for speculative decoding
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.mparams_dft.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
params.model_alias.insert(request->modelfile());
|
||||
if (!request->cachetypekey().empty()) {
|
||||
@@ -625,48 +569,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (8)
|
||||
}
|
||||
}
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
auto type = common_speculative_type_from_name(optval_str);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_min") || !strcmp(optname, "draft_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_min = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_min") || !strcmp(optname, "draft_p_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_min = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_split")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_split = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_n") || !strcmp(optname, "ngram_size_n")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_n = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_m") || !strcmp(optname, "ngram_size_m")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_m = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_min_hits") || !strcmp(optname, "ngram_min_hits")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_min_hits = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_gpu_layers")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_gpu_layers = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_ctx_size")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_ctx = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -811,17 +713,13 @@ private:
|
||||
public:
|
||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||
|
||||
grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
// Implement Health RPC
|
||||
reply->set_message("OK");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
// Implement LoadModel RPC
|
||||
common_params params;
|
||||
params_parse(ctx_server, request, params);
|
||||
@@ -1020,8 +918,6 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -1309,7 +1205,6 @@ public:
|
||||
|
||||
body_json["messages"] = messages_json;
|
||||
body_json["stream"] = true; // PredictStream is always streaming
|
||||
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||
|
||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||
@@ -1617,11 +1512,8 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
// Without this, the PEG parser never produces diffs and the Go side
|
||||
// cannot detect tool calls or separate reasoning from content.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -1646,47 +1538,19 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text;
|
||||
|
||||
if (res_json.contains("choices")) {
|
||||
// OAI chat format — extract content from choices[0].delta
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & delta = choices[0].value("delta", json::object());
|
||||
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||
completion_text = delta.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = res_json.value("content", "");
|
||||
}
|
||||
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
|
||||
// Token counts: native format has top-level fields,
|
||||
// OAI format has them in "usage" (final chunk only)
|
||||
if (res_json.contains("usage")) {
|
||||
const auto & usage = res_json.at("usage");
|
||||
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||
} else {
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
}
|
||||
|
||||
// Timings: present as top-level "timings" in both formats
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
@@ -1695,12 +1559,6 @@ public:
|
||||
return reply;
|
||||
};
|
||||
|
||||
// Attach chat deltas from the autoparser to a Reply.
|
||||
// When diffs are available, populate ChatDeltas on the reply.
|
||||
// The raw message is always preserved so the Go side can use it
|
||||
// for reasoning extraction and tool call parsing as a fallback
|
||||
// (important in distributed mode where ChatDeltas may not be
|
||||
// the primary parsing path).
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
@@ -1763,8 +1621,6 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2385,9 +2241,8 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -2418,48 +2273,25 @@ public:
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||
std::string completion_text;
|
||||
int32_t tokens_predicted = 0;
|
||||
int32_t tokens_evaluated = 0;
|
||||
|
||||
if (result_json.contains("choices")) {
|
||||
// OAI chat format
|
||||
const auto & choices = result_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
completion_text = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
if (result_json.contains("usage")) {
|
||||
const auto & usage = result_json.at("usage");
|
||||
tokens_predicted = usage.value("completion_tokens", 0);
|
||||
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = result_json.value("content", "");
|
||||
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
}
|
||||
reply->set_message(completion_text);
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Timings: present in both formats as a top-level "timings" object
|
||||
if (result_json.contains("timings")) {
|
||||
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply->set_logprobs(logprobs_json.dump());
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
@@ -2475,20 +2307,7 @@ public:
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
// Handle both native and OAI chat formats
|
||||
std::string result_content;
|
||||
if (res_json.contains("choices")) {
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
result_content = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result_content = res_json.value("content", "");
|
||||
}
|
||||
arr.push_back(result_content);
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
@@ -2520,8 +2339,6 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2702,9 +2519,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2872,6 +2687,7 @@ public:
|
||||
tf->set_id_field(ap.tools.format.id_field);
|
||||
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
|
||||
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
|
||||
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
|
||||
tf->set_function_field(ap.tools.format.function_field);
|
||||
|
||||
tf->set_gen_id_field(ap.tools.format.gen_id_field);
|
||||
@@ -2945,18 +2761,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
|
||||
// Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
g_grpc_auth_token = auth_token;
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
||||
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]()
|
||||
|
||||
@@ -24,9 +24,6 @@ if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
@@ -36,9 +33,6 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
ACESTEP_CPP_VERSION?=ab020a9aefcd364423e0665da12babc6b0c7b507
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -106,10 +106,10 @@ func TestLoadModel(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
|
||||
// Get base directory from main model file for relative paths
|
||||
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
|
||||
|
||||
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
@@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
||||
)
|
||||
|
||||
@@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
||||
var textEncoderModel, ditModel, vaeModel string
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
key, value, found := strings.Cut(oo, ":")
|
||||
if !found {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
switch parts[0] {
|
||||
case "text_encoder_model":
|
||||
textEncoderModel = value
|
||||
textEncoderModel = parts[1]
|
||||
case "dit_model":
|
||||
ditModel = value
|
||||
ditModel = parts[1]
|
||||
case "vae_model":
|
||||
vaeModel = value
|
||||
vaeModel = parts[1]
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type LLM struct {
|
||||
draftModel *llama.LLama
|
||||
}
|
||||
|
||||
|
||||
// Free releases GPU resources and frees the llama model
|
||||
// This should be called when the model is being unloaded to properly release VRAM
|
||||
func (llm *LLM) Free() error {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build debug
|
||||
// +build debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !debug
|
||||
// +build !debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
for i := 0; i < len(k1); i++ {
|
||||
dot += k1[i] * k2[i]
|
||||
}
|
||||
|
||||
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
for i := 0; i < len(k1); i++ {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
|
||||
// to one-shot (only difference is resampler batch boundaries).
|
||||
var maxDiff float64
|
||||
var sumDiffSq float64
|
||||
for i := range minLen {
|
||||
for i := 0; i < minLen; i++ {
|
||||
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
|
||||
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
||||
|
||||
var persistentMaxDiff, freshMaxDiff float64
|
||||
for i := range minLen {
|
||||
for i := 0; i < minLen; i++ {
|
||||
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
||||
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
||||
if pd > persistentMaxDiff {
|
||||
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
|
||||
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
||||
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
||||
|
||||
Expect(stddev/mean).To(BeNumerically("<", 0.15),
|
||||
Expect(stddev / mean).To(BeNumerically("<", 0.15),
|
||||
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
||||
|
||||
// Also check frequency is correct
|
||||
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
|
||||
|
||||
// Every sample must be identical — the resampler is deterministic
|
||||
var maxDiff float64
|
||||
for i := range len(oneShot) {
|
||||
for i := 0; i < len(oneShot); i++ {
|
||||
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
|
||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
||||
copy(hdr[8:12], "WAVE")
|
||||
copy(hdr[12:16], "fmt ")
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
copy(hdr[36:40], "data")
|
||||
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
||||
|
||||
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
|
||||
)
|
||||
|
||||
pcm := make([]byte, toneNumSamples*2)
|
||||
for i := range toneNumSamples {
|
||||
for i := 0; i < toneNumSamples; i++ {
|
||||
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
||||
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=8afbeb6ba9702c15d41a38296f2ab1fe5c829fa0
|
||||
STABLEDIFFUSION_GGML_VERSION?=545fac4f3fb0117a4e962b1a04cf933a7e635933
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -27,7 +27,107 @@
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
|
||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
||||
const char* sample_method_str[] = {
|
||||
"euler",
|
||||
"euler_a",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"dpm++2s_a",
|
||||
"dpm++2m",
|
||||
"dpm++2mv2",
|
||||
"ipndm",
|
||||
"ipndm_v",
|
||||
"lcm",
|
||||
"ddim_trailing",
|
||||
"tcd",
|
||||
"res_multistep",
|
||||
"res_2s",
|
||||
};
|
||||
|
||||
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
|
||||
|
||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
||||
const char* schedulers[] = {
|
||||
"discrete",
|
||||
"karras",
|
||||
"exponential",
|
||||
"ays",
|
||||
"gits",
|
||||
"sgm_uniform",
|
||||
"simple",
|
||||
"smoothstep",
|
||||
"kl_optimal",
|
||||
"lcm",
|
||||
"bong_tangent",
|
||||
};
|
||||
|
||||
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
|
||||
|
||||
// New enum string arrays
|
||||
const char* rng_type_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
"cpu",
|
||||
};
|
||||
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
|
||||
|
||||
const char* prediction_str[] = {
|
||||
"epsilon",
|
||||
"v",
|
||||
"edm_v",
|
||||
"flow",
|
||||
"flux_flow",
|
||||
"flux2_flow",
|
||||
};
|
||||
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
|
||||
|
||||
const char* lora_apply_mode_str[] = {
|
||||
"auto",
|
||||
"immediately",
|
||||
"at_runtime",
|
||||
};
|
||||
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
|
||||
|
||||
constexpr const char* sd_type_str[] = {
|
||||
"f32", // 0
|
||||
"f16", // 1
|
||||
"q4_0", // 2
|
||||
"q4_1", // 3
|
||||
nullptr, // 4
|
||||
nullptr, // 5
|
||||
"q5_0", // 6
|
||||
"q5_1", // 7
|
||||
"q8_0", // 8
|
||||
"q8_1", // 9
|
||||
"q2_k", // 10
|
||||
"q3_k", // 11
|
||||
"q4_k", // 12
|
||||
"q5_k", // 13
|
||||
"q6_k", // 14
|
||||
"q8_k", // 15
|
||||
"iq2_xxs", // 16
|
||||
"iq2_xs", // 17
|
||||
"iq3_xxs", // 18
|
||||
"iq1_s", // 19
|
||||
"iq4_nl", // 20
|
||||
"iq3_s", // 21
|
||||
"iq2_s", // 22
|
||||
"iq4_xs", // 23
|
||||
"i8", // 24
|
||||
"i16", // 25
|
||||
"i32", // 26
|
||||
"i64", // 27
|
||||
"f64", // 28
|
||||
"iq1_m", // 29
|
||||
"bf16", // 30
|
||||
nullptr, nullptr, nullptr, // 31-33
|
||||
"tq1_0", // 34
|
||||
"tq2_0", // 35
|
||||
nullptr, nullptr, nullptr, // 36-38
|
||||
"mxfp4" // 39
|
||||
};
|
||||
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
|
||||
|
||||
sd_ctx_params_t ctx_params;
|
||||
sd_ctx_t* sd_c;
|
||||
@@ -496,45 +596,75 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
|
||||
|
||||
if (!strcmp(optname, "rng_type")) {
|
||||
rng_type_t parsed = str_to_rng_type(optval);
|
||||
if (parsed != RNG_TYPE_COUNT) {
|
||||
rng_type = parsed;
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
rng_type = (rng_type_t)found;
|
||||
fprintf(stderr, "Found rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "sampler_rng_type")) {
|
||||
rng_type_t parsed = str_to_rng_type(optval);
|
||||
if (parsed != RNG_TYPE_COUNT) {
|
||||
sampler_rng_type = parsed;
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
sampler_rng_type = (rng_type_t)found;
|
||||
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "prediction")) {
|
||||
prediction_t parsed = str_to_prediction(optval);
|
||||
if (parsed != PREDICTION_COUNT) {
|
||||
prediction = parsed;
|
||||
int found = -1;
|
||||
for (int m = 0; m < PREDICTION_COUNT; m++) {
|
||||
if (!strcmp(optval, prediction_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
prediction = (prediction_t)found;
|
||||
fprintf(stderr, "Found prediction: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "lora_apply_mode")) {
|
||||
lora_apply_mode_t parsed = str_to_lora_apply_mode(optval);
|
||||
if (parsed != LORA_APPLY_MODE_COUNT) {
|
||||
lora_apply_mode = parsed;
|
||||
int found = -1;
|
||||
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
|
||||
if (!strcmp(optval, lora_apply_mode_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
lora_apply_mode = (lora_apply_mode_t)found;
|
||||
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "wtype")) {
|
||||
sd_type_t parsed = str_to_sd_type(optval);
|
||||
if (parsed != SD_TYPE_COUNT) {
|
||||
wtype = parsed;
|
||||
int found = -1;
|
||||
for (int m = 0; m < SD_TYPE_COUNT; m++) {
|
||||
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
wtype = (sd_type_t)found;
|
||||
fprintf(stderr, "Found wtype: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
|
||||
@@ -605,25 +735,27 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
fprintf (stderr, "Created context: OK\n");
|
||||
|
||||
int sample_method_found = -1;
|
||||
sample_method_t sm = str_to_sample_method(sampler);
|
||||
if (sm != SAMPLE_METHOD_COUNT) {
|
||||
sample_method_found = (int)sm;
|
||||
fprintf(stderr, "Found sampler: %s\n", sampler);
|
||||
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
|
||||
if (!strcmp(sampler, sample_method_str[m])) {
|
||||
sample_method_found = m;
|
||||
fprintf(stderr, "Found sampler: %s\n", sampler);
|
||||
}
|
||||
}
|
||||
if (sample_method_found == -1) {
|
||||
sample_method_found = sd_get_default_sample_method(sd_ctx);
|
||||
fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found));
|
||||
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
|
||||
}
|
||||
sample_method = (sample_method_t)sample_method_found;
|
||||
|
||||
scheduler_t sched = str_to_scheduler(scheduler_str);
|
||||
if (sched != SCHEDULER_COUNT) {
|
||||
scheduler = sched;
|
||||
fprintf(stderr, "Found scheduler: %s\n", scheduler_str);
|
||||
for (int d = 0; d < SCHEDULER_COUNT; d++) {
|
||||
if (!strcmp(scheduler_str, schedulers[d])) {
|
||||
scheduler = (scheduler_t)d;
|
||||
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
|
||||
}
|
||||
}
|
||||
if (scheduler == SCHEDULER_COUNT) {
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler));
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
|
||||
}
|
||||
|
||||
sd_c = sd_ctx;
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=95ea8f9bfb03a15db08a8989966fd1ae3361e20d
|
||||
WHISPER_CPP_VERSION?=9386f239401074690479731c1e41683fbbeac557
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -726,7 +726,6 @@
|
||||
- TTS
|
||||
- &opus
|
||||
name: "opus"
|
||||
alias: "opus"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
|
||||
urls:
|
||||
- https://opus-codec.org/
|
||||
@@ -3081,31 +3080,3 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda13-trl
|
||||
## llama.cpp quantization backend
|
||||
- &llama-cpp-quantization
|
||||
name: "llama-cpp-quantization"
|
||||
alias: "llama-cpp-quantization"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
|
||||
description: |
|
||||
Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format,
|
||||
and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.).
|
||||
urls:
|
||||
- https://github.com/ggml-org/llama.cpp
|
||||
tags:
|
||||
- quantization
|
||||
- GGUF
|
||||
- CPU
|
||||
capabilities:
|
||||
default: "cpu-llama-cpp-quantization"
|
||||
metal: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "cpu-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-llama-cpp-quantization
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization
|
||||
|
||||
@@ -19,10 +19,6 @@ import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from acestep.inference import (
|
||||
GenerationParams,
|
||||
GenerationConfig,
|
||||
@@ -448,8 +444,6 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,10 +16,6 @@ import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import tempfile
|
||||
|
||||
def is_float(s):
|
||||
@@ -229,9 +225,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
|
||||
|
||||
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
|
||||
a valid Bearer token in the 'authorization' metadata header are rejected with
|
||||
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
|
||||
performed (backward compatible).
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class _AbortHandler(grpc.RpcMethodHandler):
|
||||
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||
|
||||
def __init__(self):
|
||||
self.request_streaming = False
|
||||
self.response_streaming = False
|
||||
self.request_deserializer = None
|
||||
self.response_serializer = None
|
||||
self.unary_unary = self._abort
|
||||
self.unary_stream = None
|
||||
self.stream_unary = None
|
||||
self.stream_stream = None
|
||||
|
||||
@staticmethod
|
||||
def _abort(request, context):
|
||||
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
|
||||
|
||||
|
||||
class TokenAuthInterceptor(grpc.ServerInterceptor):
|
||||
"""Sync gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
self._abort_handler = _AbortHandler()
|
||||
|
||||
def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return self._abort_handler
|
||||
return continuation(handler_call_details)
|
||||
|
||||
|
||||
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
|
||||
"""Async gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
|
||||
async def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return _AbortHandler()
|
||||
return await continuation(handler_call_details)
|
||||
|
||||
|
||||
def get_auth_interceptors(*, aio: bool = False):
|
||||
"""Return a list of gRPC interceptors for bearer token auth.
|
||||
|
||||
Args:
|
||||
aio: If True, return async-compatible interceptors for grpc.aio.server().
|
||||
If False (default), return sync interceptors for grpc.server().
|
||||
|
||||
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||
"""
|
||||
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||
if not token:
|
||||
return []
|
||||
if aio:
|
||||
return [AsyncTokenAuthInterceptor(token)]
|
||||
return [TokenAuthInterceptor(token)]
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
grpcio-tools
|
||||
@@ -15,10 +15,6 @@ import torch
|
||||
from TTS.api import TTS
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -97,9 +93,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
transformers==4.48.3
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torchaudio==2.4.1
|
||||
coqui-tts
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
@@ -22,10 +22,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
# Import dynamic loader for pipeline discovery
|
||||
from diffusers_dynamic_loader import (
|
||||
@@ -1046,9 +1042,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -15,10 +15,6 @@ import torch
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -169,8 +165,6 @@ def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
]
|
||||
,
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,10 +14,6 @@ import torch
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -74,9 +70,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -19,10 +19,6 @@ import numpy as np
|
||||
import json
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -428,8 +424,6 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,10 +16,6 @@ from kittentts import KittenTTS
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -81,9 +77,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,10 +16,6 @@ from kokoro import KPipeline
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -88,9 +84,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -21,8 +21,3 @@ if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# spaCy is a dependency of misaki (used by kokoro for English phonemization).
|
||||
# Pre-download the model here because at runtime the portable Python environment
|
||||
# has no pip/uv, so spacy's auto-download would fail.
|
||||
python -m spacy download en_core_web_sm
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
|
||||
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||
|
||||
.PHONY: llama-cpp-quantization
|
||||
llama-cpp-quantization:
|
||||
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: llama-cpp-quantization
|
||||
@echo "Running llama-cpp-quantization..."
|
||||
bash run.sh
|
||||
@echo "llama-cpp-quantization run."
|
||||
|
||||
.PHONY: test
|
||||
test: llama-cpp-quantization
|
||||
@echo "Testing llama-cpp-quantization..."
|
||||
bash test.sh
|
||||
@echo "llama-cpp-quantization tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
@@ -1,426 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
llama.cpp quantization backend for LocalAI.
|
||||
|
||||
Downloads HuggingFace models, converts them to GGUF format using
|
||||
convert_hf_to_gguf.py, and quantizes using llama-quantize.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ActiveJob:
|
||||
"""Tracks a running quantization job."""
|
||||
def __init__(self, job_id):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = queue.Queue()
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = None
|
||||
self.process = None # subprocess handle for killing
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.jobs = {} # job_id -> ActiveJob
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=b"OK")
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual work happens in StartQuantization."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartQuantization(self, request, context):
|
||||
job_id = request.job_id
|
||||
if job_id in self.jobs:
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=False,
|
||||
message=f"Job {job_id} already exists",
|
||||
)
|
||||
|
||||
job = ActiveJob(job_id)
|
||||
self.jobs[job_id] = job
|
||||
|
||||
job.thread = threading.Thread(
|
||||
target=self._do_quantization,
|
||||
args=(job, request),
|
||||
daemon=True,
|
||||
)
|
||||
job.thread.start()
|
||||
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Quantization job started",
|
||||
)
|
||||
|
||||
def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
|
||||
update = backend_pb2.QuantizationProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
progress_percent=progress_percent,
|
||||
status=status,
|
||||
message=message,
|
||||
output_file=output_file,
|
||||
extra_metrics=extra_metrics or {},
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
|
||||
def _do_quantization(self, job, request):
|
||||
try:
|
||||
model = request.model
|
||||
quant_type = request.quantization_type or "q4_k_m"
|
||||
output_dir = request.output_dir
|
||||
extra_options = dict(request.extra_options) if request.extra_options else {}
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped before starting")
|
||||
return
|
||||
|
||||
# Step 1: Download / resolve model
|
||||
self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)
|
||||
|
||||
model_path = self._resolve_model(job, model, output_dir, extra_options)
|
||||
if model_path is None:
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during download")
|
||||
return
|
||||
|
||||
# Step 2: Convert to f16 GGUF
|
||||
self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)
|
||||
|
||||
f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
|
||||
if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return
|
||||
|
||||
# Step 3: Quantize
|
||||
# If the user requested f16, skip quantization — the f16 GGUF is the final output
|
||||
if quant_type.lower() in ("f16", "fp16"):
|
||||
output_file = f16_gguf_path
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Model converted to f16 GGUF: {output_file}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
return
|
||||
|
||||
output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
|
||||
self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)
|
||||
|
||||
if not self._quantize(job, f16_gguf_path, output_file, quant_type):
|
||||
return # error already sent
|
||||
|
||||
# Clean up f16 intermediate file to save disk space
|
||||
try:
|
||||
os.remove(f16_gguf_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Quantization complete: {quant_type}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
|
||||
def _resolve_model(self, job, model, output_dir, extra_options):
|
||||
"""Download model from HuggingFace or return local path."""
|
||||
# If it's a local path that exists, use it directly
|
||||
if os.path.isdir(model):
|
||||
return model
|
||||
|
||||
# If it looks like a GGUF file path, use it directly
|
||||
if os.path.isfile(model) and model.endswith(".gguf"):
|
||||
return model
|
||||
|
||||
# Download from HuggingFace
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
cache_dir = os.path.join(output_dir, "hf_cache")
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=model,
|
||||
cache_dir=cache_dir,
|
||||
token=hf_token,
|
||||
ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
|
||||
)
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
|
||||
return local_path
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "gated" in error_msg.lower() or "access" in error_msg.lower():
|
||||
self._send_progress(
|
||||
job, "failed",
|
||||
f"Access denied for {model}. This model may be gated — "
|
||||
f"please accept the license at https://huggingface.co/{model} "
|
||||
f"and provide your HF token in extra_options.",
|
||||
)
|
||||
else:
|
||||
self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
|
||||
return None
|
||||
|
||||
def _convert_to_gguf(self, job, model_path, output_path, extra_options):
|
||||
"""Convert HF model to f16 GGUF using convert_hf_to_gguf.py."""
|
||||
# If the model_path is already a GGUF file, just use it as-is
|
||||
if isinstance(model_path, str) and model_path.endswith(".gguf"):
|
||||
# Copy or symlink the GGUF file
|
||||
import shutil
|
||||
shutil.copy2(model_path, output_path)
|
||||
return True
|
||||
|
||||
# Find convert_hf_to_gguf.py
|
||||
convert_script = self._find_convert_script()
|
||||
if convert_script is None:
|
||||
self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
|
||||
return False
|
||||
|
||||
cmd = [
|
||||
sys.executable, convert_script,
|
||||
model_path,
|
||||
"--outfile", output_path,
|
||||
"--outtype", "f16",
|
||||
]
|
||||
|
||||
self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._send_progress(job, "converting", line, progress_percent=40.0)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _quantize(self, job, input_path, output_path, quant_type):
|
||||
"""Quantize a GGUF file using llama-quantize."""
|
||||
quantize_bin = self._find_quantize_binary()
|
||||
if quantize_bin is None:
|
||||
self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
|
||||
return False
|
||||
|
||||
cmd = [quantize_bin, input_path, output_path, quant_type]
|
||||
|
||||
self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
# Try to parse progress from llama-quantize output
|
||||
progress = self._parse_quantize_progress(line)
|
||||
pct = 55.0 + (progress * 0.40) if progress else 60.0
|
||||
self._send_progress(job, "quantizing", line, progress_percent=pct)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during quantization")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _parse_quantize_progress(self, line):
|
||||
"""Try to parse a progress percentage from llama-quantize output."""
|
||||
# llama-quantize typically outputs lines like:
|
||||
# [ 123/ 1234] quantizing blk.0.attn_k.weight ...
|
||||
match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
|
||||
if match:
|
||||
current = int(match.group(1))
|
||||
total = int(match.group(2))
|
||||
if total > 0:
|
||||
return current / total
|
||||
return None
|
||||
|
||||
def _find_convert_script(self):
|
||||
"""Find convert_hf_to_gguf.py in known locations."""
|
||||
candidates = [
|
||||
# Same directory as this backend
|
||||
os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"),
|
||||
# Installed via install.sh
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
|
||||
]
|
||||
|
||||
# Also check if it's on PATH
|
||||
import shutil
|
||||
path_script = shutil.which("convert_hf_to_gguf.py")
|
||||
if path_script:
|
||||
candidates.append(path_script)
|
||||
|
||||
for candidate in candidates:
|
||||
if os.path.isfile(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def _find_quantize_binary(self):
|
||||
"""Find llama-quantize binary."""
|
||||
import shutil
|
||||
|
||||
# Check common names on PATH
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
path = shutil.which(name)
|
||||
if path:
|
||||
return path
|
||||
|
||||
# Check in the backend directory (built by install.sh)
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
candidate = os.path.join(backend_dir, name)
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def _file_metrics(self, filepath):
|
||||
"""Return file size metrics."""
|
||||
try:
|
||||
size_bytes = os.path.getsize(filepath)
|
||||
return {"file_size_mb": size_bytes / (1024 * 1024)}
|
||||
except OSError:
|
||||
return {}
|
||||
|
||||
def QuantizationProgress(self, request, context):
    """Stream progress updates for the job named by ``request.job_id``.

    Aborts the RPC with NOT_FOUND when the job is unknown. Otherwise yields
    updates from the job's progress queue until one carries a terminal status
    ("completed", "failed" or "stopped"). When the queue stays empty for one
    second, streaming also ends if the job's thread has finished (after
    yielding whatever is left in the queue) or if the client is no longer
    active.
    """
    job_id = request.job_id
    tracked = self.jobs.get(job_id)
    if tracked is None:
        context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
        return

    terminal_states = ("completed", "failed", "stopped")
    while True:
        try:
            event = tracked.progress_queue.get(timeout=1.0)
            yield event
            # A terminal status is the last message of the stream.
            if event.status in terminal_states:
                break
        except queue.Empty:
            # Worker gone without a terminal update: flush leftovers and stop.
            worker = tracked.thread
            if worker and not worker.is_alive():
                while not tracked.progress_queue.empty():
                    leftover = tracked.progress_queue.get_nowait()
                    yield leftover
                break
            # Stop polling once the client has disconnected.
            if context.is_active() is False:
                break
|
||||
|
||||
def StopQuantization(self, request, context):
    """Request cancellation of the job named by ``request.job_id``.

    Sets the job's stop event and kills its subprocess if one is attached.

    Returns:
        A ``backend_pb2.Result``: ``success=False`` when the job id is unknown,
        otherwise ``success=True`` with the message "Stop signal sent".
    """
    job_id = request.job_id
    job = self.jobs.get(job_id)
    if job is None:
        return backend_pb2.Result(success=False, message=f"Job {job_id} not found")

    job.stop_event.set()

    # Read job.process once: the worker thread resets it to None when the
    # subprocess ends, so a separate check and use could hit None in between.
    process = job.process
    if process:
        try:
            process.kill()
        except OSError:
            pass

    return backend_pb2.Result(success=True, message="Stop signal sent")
|
||||
|
||||
|
||||
def serve(address):
    """Start the quantization gRPC server on ``address`` and block.

    The server runs on a thread pool of ``MAX_WORKERS`` workers with the
    interceptors returned by ``get_auth_interceptors()``, listens on an
    insecure port, and is stopped immediately on KeyboardInterrupt.
    """
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    grpc_server = grpc.server(worker_pool, interceptors=get_auth_interceptors())
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), grpc_server)
    grpc_server.add_insecure_port(address)
    grpc_server.start()
    print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)

    # Keep the calling thread parked; the server does its work on the pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the listen address and run the server.
    cli = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
    cli.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    listen_addr = cli.parse_args().addr

    # Exit with status 0 on Ctrl-C instead of raising KeyboardInterrupt.
    signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
    serve(listen_addr)
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/bin/bash
# Installs the Python requirements for this backend, fetches
# convert_hf_to_gguf.py and the matching gguf package from llama.cpp, and
# builds llama-quantize when it is not already available.
set -e

# Quote every path expansion so a checkout path containing spaces still works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
installRequirements

# Fetch convert_hf_to_gguf.py from llama.cpp
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
    # Download to a temporary name and move it into place only on success, so
    # an interrupted transfer cannot leave a truncated script that the -f
    # check above would treat as installed on the next run.
    if curl -L --fail --retry 3 \
        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
        -o "${CONVERT_SCRIPT}.tmp"; then
        mv "${CONVERT_SCRIPT}.tmp" "${CONVERT_SCRIPT}"
    else
        rm -f "${CONVERT_SCRIPT}.tmp"
        echo "Warning: Failed to download convert_hf_to_gguf.py."
    fi
fi

# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
    pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        pip install "gguf>=0.16.0"
    }
else
    uv pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        uv pip install "gguf>=0.16.0"
    }
fi

# Build llama-quantize from llama.cpp if not already present
QUANTIZE_BIN="${EDIR}/llama-quantize"
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
    if command -v cmake &>/dev/null; then
        echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
        LLAMA_CPP_SRC="${EDIR}/llama.cpp"
        if [ ! -d "${LLAMA_CPP_SRC}" ]; then
            # Try the requested ref first; if that clone fails, fall back to
            # the repository's default branch.
            git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
                https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
            git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
        fi
        cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
        cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
        cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
        chmod +x "${QUANTIZE_BIN}"
        echo "Built llama-quantize at ${QUANTIZE_BIN}"
    else
        echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
    fi
fi
|
||||
@@ -1,5 +0,0 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
@@ -1,4 +0,0 @@
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
@@ -1,3 +0,0 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
@@ -1,10 +0,0 @@
|
||||
#!/bin/bash
# Launches the backend through the shared libbackend.sh helpers.

# Quote path expansions so a checkout path containing spaces still works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# "$@" forwards each argument intact; a bare $@ would re-split arguments that
# contain whitespace.
startBackend "$@"
|
||||
@@ -1,99 +0,0 @@
|
||||
"""
|
||||
Test script for the llama-cpp-quantization gRPC backend.
|
||||
|
||||
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
|
||||
and quantizes it to q4_k_m.
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
SERVER_ADDR = "localhost:50051"
|
||||
# Small model for CI testing (~540MB)
|
||||
TEST_MODEL = "unsloth/functiongemma-270m-it"
|
||||
|
||||
|
||||
class TestQuantizationBackend(unittest.TestCase):
    """End-to-end tests for the llama-cpp-quantization gRPC service."""

    @classmethod
    def setUpClass(cls):
        # Launch the backend once for the whole class and give it time to bind.
        server_cmd = ["python3", "backend.py", "--addr", SERVER_ADDR]
        cls.service = subprocess.Popen(server_cmd)
        time.sleep(5)
        cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")

    @classmethod
    def tearDownClass(cls):
        # Stop the backend, then remove everything the tests produced.
        cls.service.kill()
        cls.service.wait()
        if os.path.isdir(cls.output_dir):
            shutil.rmtree(cls.output_dir, ignore_errors=True)

    def _channel(self):
        """Open an insecure channel to the backend under test."""
        return grpc.insecure_channel(SERVER_ADDR)

    def test_01_health(self):
        """Test that the server starts and responds to health checks."""
        with self._channel() as conn:
            client = backend_pb2_grpc.BackendStub(conn)
            reply = client.Health(backend_pb2.HealthMessage())
            self.assertEqual(reply.message, b"OK")

    def test_02_quantize_small_model(self):
        """Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
        with self._channel() as conn:
            client = backend_pb2_grpc.BackendStub(conn)
            job_id = "test-quantize-001"

            # Start quantization
            start_request = backend_pb2.QuantizationRequest(
                model=TEST_MODEL,
                quantization_type="q4_k_m",
                output_dir=self.output_dir,
                job_id=job_id,
            )
            started = client.StartQuantization(start_request)
            self.assertTrue(started.success, f"StartQuantization failed: {started.message}")
            self.assertEqual(started.job_id, job_id)

            # Stream progress until the server closes the stream, remembering
            # the last status and the most recent non-empty output path.
            final_status = None
            output_file = None
            progress_request = backend_pb2.QuantizationProgressRequest(job_id=job_id)
            for update in client.QuantizationProgress(progress_request):
                print(f"  [{update.status}] {update.progress_percent:.1f}% - {update.message}")
                final_status = update.status
                if update.output_file:
                    output_file = update.output_file

            self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
            self.assertIsNotNone(output_file, "No output_file in progress updates")
            self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")

            # Verify the output is a valid GGUF file (starts with "GGUF" magic)
            with open(output_file, "rb") as handle:
                magic = handle.read(4)
            self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")

            # Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
            size_mb = os.path.getsize(output_file) / (1024 * 1024)
            print(f"  Output file size: {size_mb:.1f} MB")
            self.assertGreater(size_mb, 10, "Output file suspiciously small")
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,11 +0,0 @@
|
||||
#!/bin/bash
# Runs this backend's unit tests through the shared libbackend.sh helpers.
set -e

# Quote path expansions so a checkout path containing spaces still works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
@@ -15,10 +15,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_audio.tts.utils import load_model
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
@@ -440,9 +436,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -23,10 +23,6 @@ import tempfile
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
@@ -472,8 +468,6 @@ async def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -12,10 +12,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_vlm import load, generate, stream_generate
|
||||
from mlx_vlm.prompt_utils import apply_chat_template
|
||||
from mlx_vlm.utils import load_config, load_image
|
||||
@@ -450,9 +446,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -12,10 +12,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_lm import load, generate, stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||
@@ -425,9 +421,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -17,10 +17,6 @@ from moonshine_voice import (
|
||||
)
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -132,9 +128,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,10 +14,6 @@ import torch
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -123,9 +119,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -4,4 +4,3 @@ certifi
|
||||
packaging==24.1
|
||||
setuptools
|
||||
pyarrow==20.0.0
|
||||
pybind11
|
||||
|
||||
@@ -15,10 +15,6 @@ from neuttsair.neutts import NeuTTSAir
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -134,9 +130,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,10 +14,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import outetts
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -120,9 +116,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
|
||||
@@ -16,10 +16,6 @@ import torch
|
||||
from pocket_tts import TTSModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -229,9 +225,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,10 +14,6 @@ import torch
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -188,9 +184,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -23,10 +23,6 @@ import hashlib
|
||||
import pickle
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -904,8 +900,6 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,10 +14,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
from rerankers import Reranker
|
||||
|
||||
@@ -101,9 +97,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
@@ -13,10 +13,6 @@ import base64
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import requests
|
||||
|
||||
@@ -143,9 +139,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,22 +16,16 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Backward-compat aliases for model types
|
||||
TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
@@ -58,11 +52,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
A gRPC method that returns the health status of the backend service.
|
||||
|
||||
Args:
|
||||
request: A HealthRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Reply object that contains the health status of the backend service.
|
||||
"""
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
model_name = request.Model
|
||||
"""
|
||||
A gRPC method that loads a model into memory.
|
||||
|
||||
Args:
|
||||
request: A LoadModelRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Result object that contains the result of the LoadModel operation.
|
||||
"""
|
||||
|
||||
model_name = request.Model
|
||||
|
||||
# Check to see if the Model exists in the filesystem already.
|
||||
if os.path.exists(request.ModelFile):
|
||||
model_name = request.ModelFile
|
||||
@@ -73,9 +88,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
self.CUDA = torch.cuda.is_available()
|
||||
self.OV=False
|
||||
self.GenericTTS=False
|
||||
self.DiaTTS=False
|
||||
self.SentenceTransformer = False
|
||||
self.processor = None
|
||||
|
||||
device_map="cpu"
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
@@ -87,7 +101,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Parse options from request.Options
|
||||
self.options = {}
|
||||
options = request.Options
|
||||
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when generating
|
||||
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
||||
@@ -109,7 +123,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print(f"Parsed options: {self.options}", file=sys.stderr)
|
||||
|
||||
if self.CUDA:
|
||||
from transformers import BitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
||||
if request.MainGPU:
|
||||
device_map=request.MainGPU
|
||||
else:
|
||||
@@ -126,31 +140,40 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
quantization = BitsAndBytesConfig(
|
||||
load_in_4bit=False,
|
||||
bnb_4bit_compute_dtype = None,
|
||||
load_in_8bit=True,
|
||||
load_in_8bit=True,
|
||||
)
|
||||
|
||||
try:
|
||||
if XPU and request.Type == "AutoModelForCausalLM":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||
if request.Type == "AutoModelForCausalLM":
|
||||
if XPU:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||
|
||||
device_map="xpu"
|
||||
compute=torch.float16
|
||||
if request.Quantization == "xpu_4bit":
|
||||
xpu_4bit = True
|
||||
xpu_8bit = False
|
||||
elif request.Quantization == "xpu_8bit":
|
||||
xpu_4bit = False
|
||||
xpu_8bit = True
|
||||
device_map="xpu"
|
||||
compute=torch.float16
|
||||
if request.Quantization == "xpu_4bit":
|
||||
xpu_4bit = True
|
||||
xpu_8bit = False
|
||||
elif request.Quantization == "xpu_8bit":
|
||||
xpu_4bit = False
|
||||
xpu_8bit = True
|
||||
else:
|
||||
xpu_4bit = False
|
||||
xpu_8bit = False
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
device_map=device_map,
|
||||
load_in_4bit=xpu_4bit,
|
||||
load_in_8bit=xpu_8bit,
|
||||
torch_dtype=compute)
|
||||
else:
|
||||
xpu_4bit = False
|
||||
xpu_8bit = False
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
device_map=device_map,
|
||||
load_in_4bit=xpu_4bit,
|
||||
load_in_8bit=xpu_8bit,
|
||||
torch_dtype=compute)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute)
|
||||
elif request.Type == "OVModelForCausalLM":
|
||||
from optimum.intel.openvino import OVModelForCausalLM
|
||||
from openvino.runtime import Core
|
||||
@@ -162,12 +185,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
devices = Core().available_devices
|
||||
if "GPU" in " ".join(devices):
|
||||
device_map="AUTO:GPU"
|
||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
||||
if "CPU" or "NPU" in device_map:
|
||||
if "-CPU" or "-NPU" not in device_map:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||
else:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
||||
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
||||
compile=True,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
ov_config=ovconfig,
|
||||
@@ -184,60 +209,59 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
devices = Core().available_devices
|
||||
if "GPU" in " ".join(devices):
|
||||
device_map="AUTO:GPU"
|
||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
||||
if "CPU" or "NPU" in device_map:
|
||||
if "-CPU" or "-NPU" not in device_map:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||
else:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
||||
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
||||
compile=True,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
ov_config=ovconfig,
|
||||
ov_config=ovconfig,
|
||||
export=True,
|
||||
device=device_map)
|
||||
self.OV = True
|
||||
elif request.Type == "MusicgenForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
if self.CUDA:
|
||||
self.model = self.model.to("cuda")
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "SentenceTransformer":
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "Mamba":
|
||||
autoTokenizer = False
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = MambaForCausalLM.from_pretrained(model_name)
|
||||
else:
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
ModelClass = AutoModel # default
|
||||
if model_type and hasattr(transformers_module, model_type):
|
||||
ModelClass = getattr(transformers_module, model_type)
|
||||
print(f"Using model class: {model_type}", file=sys.stderr)
|
||||
else:
|
||||
print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr)
|
||||
|
||||
self.model = ModelClass.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute,
|
||||
)
|
||||
|
||||
# Try to load a processor (needed for TTS/audio models)
|
||||
try:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.GenericTTS = True
|
||||
print(f"Loaded processor for {model_name}", file=sys.stderr)
|
||||
except Exception:
|
||||
self.processor = None
|
||||
|
||||
print("Automodel", file=sys.stderr)
|
||||
self.model = AutoModel.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute)
|
||||
if request.ContextSize > 0:
|
||||
self.max_tokens = request.ContextSize
|
||||
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
||||
self.max_tokens = self.model.config.max_position_embeddings
|
||||
else:
|
||||
self.max_tokens = self.options.get("max_new_tokens", 512)
|
||||
|
||||
|
||||
if autoTokenizer:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
|
||||
self.XPU = False
|
||||
|
||||
if XPU and self.OV == False:
|
||||
@@ -251,9 +275,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as err:
|
||||
print("Error:", err, file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""
|
||||
A gRPC method that calculates embeddings for a given sentence.
|
||||
|
||||
Args:
|
||||
request: An EmbeddingRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
An EmbeddingResult object that contains the calculated embeddings.
|
||||
"""
|
||||
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
max_length = 512
|
||||
@@ -266,13 +303,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||
embeds = self.model.encode(request.Embeddings)
|
||||
else:
|
||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||
|
||||
# Create word embeddings
|
||||
if self.CUDA:
|
||||
encoded_input = encoded_input.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
|
||||
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
||||
@@ -280,11 +317,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
embeds = sentence_embeddings[0]
|
||||
return backend_pb2.EmbeddingResult(embeddings=embeds)
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
set_seed(request.Seed)
|
||||
if request.TopP < 0 or request.TopP > 1:
|
||||
request.TopP = 1
|
||||
|
||||
|
||||
if request.TopK <= 0:
|
||||
request.TopK = 50
|
||||
|
||||
@@ -297,7 +334,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
request.Temperature == None
|
||||
|
||||
prompt = request.Prompt
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||
@@ -326,10 +363,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True)
|
||||
config=dict(inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
@@ -350,18 +387,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
else:
|
||||
if XPU and self.OV == False:
|
||||
outputs = self.model.generate(inputs["input_ids"],
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
pad_token=self.tokenizer.eos_token_id)
|
||||
else:
|
||||
outputs = self.model.generate(**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
@@ -376,11 +413,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
|
||||
async def Predict(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters.
|
||||
|
||||
Args:
|
||||
request: The predict request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Reply: The predict result.
|
||||
"""
|
||||
gen = self._predict(request, context, streaming=False)
|
||||
res = await gen.__anext__()
|
||||
return res
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||
|
||||
Args:
|
||||
request: The predict stream request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Result: The predict stream result.
|
||||
"""
|
||||
iterations = self._predict(request, context, streaming=True)
|
||||
try:
|
||||
async for iteration in iterations:
|
||||
@@ -398,19 +455,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.model is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
# Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration
|
||||
model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
|
||||
ModelClass = getattr(transformers_module, model_type)
|
||||
self.model = ModelClass.from_pretrained(model_name)
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
inputs = None
|
||||
if request.text == "":
|
||||
inputs = self.model.get_unconditional_inputs(num_samples=1)
|
||||
elif request.HasField('src'):
|
||||
# TODO SECURITY CODE GOES HERE LOL
|
||||
# WHO KNOWS IF THIS WORKS???
|
||||
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
|
||||
|
||||
|
||||
if request.HasField('src_divisor'):
|
||||
wsamples = wsamples[: len(wsamples) // request.src_divisor]
|
||||
|
||||
|
||||
inputs = self.processor(
|
||||
audio=wsamples,
|
||||
sampling_rate=sample_rate,
|
||||
@@ -424,7 +480,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
if request.HasField('duration'):
|
||||
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||
guidance = self.options.get("guidance_scale", 3.0)
|
||||
@@ -434,97 +490,92 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.HasField('sample'):
|
||||
dosample = request.sample
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
||||
print("[transformers] SoundGeneration generated!", file=sys.stderr)
|
||||
|
||||
# Save audio output
|
||||
if hasattr(self.processor, 'save_audio'):
|
||||
if hasattr(self.processor, 'batch_decode'):
|
||||
try:
|
||||
audio_values = self.processor.batch_decode(audio_values)
|
||||
except Exception:
|
||||
pass
|
||||
self.processor.save_audio(audio_values, request.dst)
|
||||
else:
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
|
||||
print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
|
||||
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
|
||||
def CallDiaTTS(self, request, context):
|
||||
"""
|
||||
Generates dialogue audio using the Dia model.
|
||||
|
||||
Args:
|
||||
request: A TTSRequest containing text dialogue and generation parameters
|
||||
context: The gRPC context
|
||||
|
||||
Returns:
|
||||
A Result object indicating success or failure
|
||||
"""
|
||||
try:
|
||||
text = request.text
|
||||
print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr)
|
||||
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
|
||||
|
||||
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
|
||||
text = [request.text]
|
||||
|
||||
# Process the input
|
||||
inputs = self.processor(text=text, padding=True, return_tensors="pt")
|
||||
|
||||
# Generate audio with parameters from options or defaults
|
||||
generation_params = {
|
||||
**inputs,
|
||||
"max_new_tokens": self.max_tokens,
|
||||
"guidance_scale": self.options.get("guidance_scale", 3.0),
|
||||
"temperature": self.options.get("temperature", 1.8),
|
||||
"top_p": self.options.get("top_p", 0.90),
|
||||
"top_k": self.options.get("top_k", 45)
|
||||
}
|
||||
|
||||
outputs = self.model.generate(**generation_params)
|
||||
|
||||
# Decode and save audio
|
||||
outputs = self.processor.batch_decode(outputs)
|
||||
self.processor.save_audio(outputs, request.dst)
|
||||
|
||||
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
|
||||
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
|
||||
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
# Build inputs based on processor capabilities
|
||||
if request.voice and os.path.exists(request.voice):
|
||||
# Voice cloning: use chat template with reference audio
|
||||
chat_template = [{
|
||||
"role": "0",
|
||||
"content": [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "audio", "path": request.voice},
|
||||
],
|
||||
}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat_template, tokenize=True, return_dict=True,
|
||||
).to(self.model.device, self.model.dtype)
|
||||
elif hasattr(self.processor, 'apply_chat_template'):
|
||||
# Models that use chat template format (VibeVoice, CSM, etc.)
|
||||
chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}]
|
||||
try:
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat_template, tokenize=True, return_dict=True,
|
||||
).to(self.model.device, self.model.dtype)
|
||||
except Exception:
|
||||
# Fallback if chat template fails (not all processors support it)
|
||||
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||
if self.CUDA:
|
||||
inputs = inputs.to("cuda")
|
||||
else:
|
||||
# Direct processor call (Musicgen, etc.)
|
||||
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||
if self.CUDA:
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Build generation kwargs from self.options
|
||||
gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens}
|
||||
for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]:
|
||||
if key in self.options:
|
||||
gen_kwargs[key] = self.options[key]
|
||||
|
||||
# Add noise scheduler if configured (e.g., for VibeVoice)
|
||||
noise_scheduler_type = self.options.get("noise_scheduler", None)
|
||||
if noise_scheduler_type:
|
||||
import diffusers
|
||||
SchedulerClass = getattr(diffusers, noise_scheduler_type)
|
||||
scheduler_kwargs = {}
|
||||
for key in ["beta_schedule", "prediction_type"]:
|
||||
if key in self.options:
|
||||
scheduler_kwargs[key] = self.options[key]
|
||||
gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs)
|
||||
|
||||
# Generate audio
|
||||
audio = self.model.generate(**gen_kwargs)
|
||||
print("[transformers] TTS generated!", file=sys.stderr)
|
||||
|
||||
# Save audio output
|
||||
if hasattr(self.processor, 'save_audio'):
|
||||
if hasattr(self.processor, 'batch_decode'):
|
||||
try:
|
||||
audio = self.processor.batch_decode(audio)
|
||||
except Exception:
|
||||
pass
|
||||
self.processor.save_audio(audio, request.dst)
|
||||
else:
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy())
|
||||
|
||||
print("[transformers] TTS saved to", request.dst, file=sys.stderr)
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.DiaTTS:
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
if self.processor is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
if self.model is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
inputs = self.processor(
|
||||
text=[request.text],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
|
||||
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
||||
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
|
||||
print("[transformers-musicgen] TTS for", file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
@@ -536,9 +587,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -2,9 +2,7 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,7 @@ torch==2.7.1
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,7 @@
|
||||
torch==2.9.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -1,11 +1,9 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
||||
torch==2.8.0+rocm6.4
|
||||
accelerate
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -3,9 +3,7 @@ torch
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,7 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.0.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.78.1
|
||||
protobuf==6.33.5
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -17,10 +17,6 @@ import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
@@ -836,8 +832,6 @@ def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -20,10 +20,6 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
|
||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -728,9 +724,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -27,10 +27,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
from vllm_omni.entrypoints.omni import Omni
|
||||
from vllm_omni.outputs import OmniRequestOutput
|
||||
@@ -654,9 +650,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -12,10 +12,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -342,9 +338,7 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
@@ -18,10 +18,6 @@ import backend_pb2_grpc
|
||||
import torch
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -301,9 +297,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -8,15 +8,6 @@ else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# The PyTorch CPU/CUDA indexes mirror common packages (e.g. requests) with
|
||||
# limited, often outdated version sets. uv's default "first-index" strategy
|
||||
# locks to the first index that carries a package, so it can pick e.g.
|
||||
# requests==2.28.1 from the PyTorch index instead of a newer version from
|
||||
# PyPI. voxcpm's transitive deps (datasets>=3 → requests>=2.32.2) need the
|
||||
# PyPI versions. "unsafe-best-match" is safe here because we control both
|
||||
# indexes and there is no dependency confusion risk.
|
||||
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||
|
||||
installRequirements
|
||||
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
torchaudio
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm>=1.5.0
|
||||
voxcpm
|
||||
torchcodec
|
||||
@@ -5,3 +5,4 @@ certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
|
||||
@@ -13,10 +13,6 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -141,9 +137,7 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -83,18 +83,8 @@ For documentation and support:
|
||||
cli.CLI.LogLevel = &logLevel
|
||||
}
|
||||
|
||||
// Set xlog logger with the desired level and text format.
|
||||
// xlog auto-enables log deduplication when output is a terminal.
|
||||
var logOpts []xlog.LoggerOption
|
||||
if cli.CLI.LogDedupLogs != nil {
|
||||
if *cli.CLI.LogDedupLogs {
|
||||
logOpts = append(logOpts, xlog.WithDedup())
|
||||
} else {
|
||||
logOpts = append(logOpts, xlog.WithoutDedup())
|
||||
}
|
||||
}
|
||||
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(*cli.CLI.LogLevel), *cli.CLI.LogFormat, logOpts...))
|
||||
// Set xlog logger with the desired level and text format
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(*cli.CLI.LogLevel), *cli.CLI.LogFormat))
|
||||
|
||||
// Run the thing!
|
||||
err = ctx.Run(&cli.CLI.Context)
|
||||
|
||||
@@ -3,7 +3,7 @@ package application
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -22,23 +22,13 @@ func (a *Application) RestartAgentJobService() error {
|
||||
}
|
||||
|
||||
// Create new service instance
|
||||
agentJobService := agentpool.NewAgentJobService(
|
||||
agentJobService := services.NewAgentJobService(
|
||||
a.ApplicationConfig(),
|
||||
a.ModelLoader(),
|
||||
a.ModelConfigLoader(),
|
||||
a.TemplatesEvaluator(),
|
||||
)
|
||||
|
||||
// Re-apply distributed wiring if available (matches startup.go logic)
|
||||
if d := a.Distributed(); d != nil {
|
||||
if d.Dispatcher != nil {
|
||||
agentJobService.SetDistributedBackends(d.Dispatcher)
|
||||
}
|
||||
if d.JobStore != nil {
|
||||
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||
}
|
||||
}
|
||||
|
||||
// Start the service
|
||||
err := agentJobService.Start(a.ApplicationConfig().Context)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,16 +2,12 @@ package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -24,9 +20,9 @@ type Application struct {
|
||||
applicationConfig *config.ApplicationConfig
|
||||
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
|
||||
templatesEvaluator *templates.Evaluator
|
||||
galleryService *galleryop.GalleryService
|
||||
agentJobService *agentpool.AgentJobService
|
||||
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
|
||||
galleryService *services.GalleryService
|
||||
agentJobService *services.AgentJobService
|
||||
agentPoolService atomic.Pointer[services.AgentPoolService]
|
||||
authDB *gorm.DB
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
@@ -34,9 +30,6 @@ type Application struct {
|
||||
p2pCtx context.Context
|
||||
p2pCancel context.CancelFunc
|
||||
agentJobMutex sync.Mutex
|
||||
|
||||
// Distributed mode services (nil when not in distributed mode)
|
||||
distributed *DistributedServices
|
||||
}
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
@@ -71,15 +64,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator {
|
||||
return a.templatesEvaluator
|
||||
}
|
||||
|
||||
func (a *Application) GalleryService() *galleryop.GalleryService {
|
||||
func (a *Application) GalleryService() *services.GalleryService {
|
||||
return a.galleryService
|
||||
}
|
||||
|
||||
func (a *Application) AgentJobService() *agentpool.AgentJobService {
|
||||
func (a *Application) AgentJobService() *services.AgentJobService {
|
||||
return a.agentJobService
|
||||
}
|
||||
|
||||
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||
func (a *Application) AgentPoolService() *services.AgentPoolService {
|
||||
return a.agentPoolService.Load()
|
||||
}
|
||||
|
||||
@@ -93,53 +86,8 @@ func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
}
|
||||
|
||||
// Distributed returns the distributed services, or nil if not in distributed mode.
// The value returned is the Application's own distributed field, not a copy.
|
||||
func (a *Application) Distributed() *DistributedServices {
|
||||
return a.distributed
|
||||
}
|
||||
|
||||
// IsDistributed returns true if the application is running in distributed mode,
// which is determined solely by the distributed field being non-nil.
|
||||
func (a *Application) IsDistributed() bool {
|
||||
return a.distributed != nil
|
||||
}
|
||||
|
||||
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
|
||||
// This prevents the agent pool from failing during startup when workers haven't connected yet.
|
||||
func (a *Application) waitForHealthyWorker() {
|
||||
maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault()
|
||||
const basePoll = 2 * time.Second
|
||||
|
||||
xlog.Info("Waiting for at least one healthy backend worker before starting agent pool")
|
||||
deadline := time.Now().Add(maxWait)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
registered, err := a.distributed.Registry.List(context.Background())
|
||||
if err == nil {
|
||||
for _, n := range registered {
|
||||
if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy {
|
||||
xlog.Info("Healthy backend worker found", "node", n.Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add 0-1s jitter to prevent thundering-herd on the node registry
|
||||
jitter := time.Duration(rand.Int64N(int64(time.Second)))
|
||||
select {
|
||||
case <-a.applicationConfig.Context.Done():
|
||||
return
|
||||
case <-time.After(basePoll + jitter):
|
||||
}
|
||||
}
|
||||
xlog.Warn("No healthy backend worker found after waiting, proceeding anyway")
|
||||
}
|
||||
|
||||
// InstanceID returns the unique identifier for this frontend instance.
// It is read straight from the distributed section of the application config;
// no value is generated here, so the result is whatever the config holds.
|
||||
func (a *Application) InstanceID() string {
|
||||
return a.applicationConfig.Distributed.InstanceID
|
||||
}
|
||||
|
||||
func (a *Application) start() error {
|
||||
galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
||||
galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
||||
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -147,14 +95,19 @@ func (a *Application) start() error {
|
||||
|
||||
a.galleryService = galleryService
|
||||
|
||||
// Initialize agent job service (Start() is deferred to after distributed wiring)
|
||||
agentJobService := agentpool.NewAgentJobService(
|
||||
// Initialize agent job service
|
||||
agentJobService := services.NewAgentJobService(
|
||||
a.ApplicationConfig(),
|
||||
a.ModelLoader(),
|
||||
a.ModelConfigLoader(),
|
||||
a.TemplatesEvaluator(),
|
||||
)
|
||||
|
||||
err = agentJobService.Start(a.ApplicationConfig().Context)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.agentJobService = agentJobService
|
||||
|
||||
return nil
|
||||
@@ -167,56 +120,27 @@ func (a *Application) StartAgentPool() {
|
||||
if !a.applicationConfig.AgentPool.Enabled {
|
||||
return
|
||||
}
|
||||
// Build options struct from available dependencies
|
||||
opts := agentpool.AgentPoolOptions{
|
||||
AuthDB: a.authDB,
|
||||
}
|
||||
if d := a.Distributed(); d != nil {
|
||||
if d.DistStores != nil && d.DistStores.Skills != nil {
|
||||
opts.SkillStore = d.DistStores.Skills
|
||||
}
|
||||
opts.NATSClient = d.Nats
|
||||
opts.EventBridge = d.AgentBridge
|
||||
opts.AgentStore = d.AgentStore
|
||||
}
|
||||
|
||||
aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts)
|
||||
aps, err := services.NewAgentPoolService(a.applicationConfig)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to create agent pool service", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wire distributed mode components
|
||||
if d := a.Distributed(); d != nil {
|
||||
// Wait for at least one healthy backend worker before starting the agent pool.
|
||||
// Collections initialization calls embeddings which require a worker.
|
||||
if d.Registry != nil {
|
||||
a.waitForHealthyWorker()
|
||||
}
|
||||
if a.authDB != nil {
|
||||
aps.SetAuthDB(a.authDB)
|
||||
}
|
||||
|
||||
if err := aps.Start(a.applicationConfig.Context); err != nil {
|
||||
xlog.Error("Failed to start agent pool", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
|
||||
usm := agentpool.NewUserServicesManager(
|
||||
usm := services.NewUserServicesManager(
|
||||
aps.UserStorage(),
|
||||
a.applicationConfig,
|
||||
a.modelLoader,
|
||||
a.backendLoader,
|
||||
a.templatesEvaluator,
|
||||
)
|
||||
// Wire distributed backends to per-user job services
|
||||
if a.agentJobService != nil {
|
||||
if d := a.agentJobService.Dispatcher(); d != nil {
|
||||
usm.SetJobDispatcher(d)
|
||||
}
|
||||
if s := a.agentJobService.DBStore(); s != nil {
|
||||
usm.SetJobDBStore(s)
|
||||
}
|
||||
}
|
||||
aps.SetUserServicesManager(usm)
|
||||
|
||||
a.agentPoolService.Store(aps)
|
||||
|
||||
@@ -199,6 +199,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
|
||||
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
|
||||
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
|
||||
envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
|
||||
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
|
||||
envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold
|
||||
envThreads := appConfig.Threads == startupAppConfig.Threads
|
||||
@@ -270,6 +271,9 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
appConfig.MaxActiveBackends = 0
|
||||
}
|
||||
}
|
||||
if settings.ParallelBackendRequests != nil && !envParallelRequests {
|
||||
appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
|
||||
}
|
||||
if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled {
|
||||
appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
|
||||
if appConfig.MemoryReclaimerEnabled {
|
||||
|
||||
@@ -1,280 +0,0 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/agents"
|
||||
"github.com/mudler/LocalAI/core/services/distributed"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DistributedServices holds all services initialized for distributed mode.
// A value is built by initDistributed and torn down with Shutdown.
|
||||
type DistributedServices struct {
|
||||
// Nats is the messaging client connected to the configured NATS URL; Shutdown closes it.
Nats *messaging.Client
|
||||
// Store is the object store: S3 when a storage URL is configured, otherwise a filesystem store.
Store storage.ObjectStore
|
||||
// Registry is the node registry created on top of the auth database.
Registry *nodes.NodeRegistry
|
||||
// Router is the smart router built over Registry, Unloader and FileStager.
Router *nodes.SmartRouter
|
||||
// Health is the node health monitor; Shutdown stops it.
Health *nodes.HealthMonitor
|
||||
// Reconciler is the replica reconciler used for auto-scaling model replicas.
Reconciler *nodes.ReplicaReconciler
|
||||
// JobStore is the job store created on top of the auth database.
JobStore *jobs.JobStore
|
||||
// Dispatcher is the job dispatcher wired to JobStore and Nats; Shutdown stops it.
Dispatcher *jobs.Dispatcher
|
||||
// AgentStore is the agent store created on top of the auth database.
AgentStore *agents.AgentStore
|
||||
// AgentBridge is the agent event bridge over Nats and AgentStore. It has no Close method.
AgentBridge *agents.EventBridge
|
||||
// DistStores is the result of distributed.InitStores on the auth database.
DistStores *distributed.Stores
|
||||
// FileMgr is the file manager over Store with a local cache directory.
FileMgr *storage.FileManager
|
||||
// FileStager is S3+NATS based when a storage URL is configured, otherwise direct HTTP transfer.
FileStager nodes.FileStager
|
||||
// ModelAdapter wraps Router so it can be wired into the model loader.
ModelAdapter *nodes.ModelRouterAdapter
|
||||
// Unloader is the remote unloader built over Registry and Nats.
Unloader *nodes.RemoteUnloaderAdapter
|
||||
|
||||
// shutdownOnce makes Shutdown idempotent.
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// Shutdown stops all distributed services in reverse initialization order.
|
||||
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
|
||||
func (ds *DistributedServices) Shutdown() {
|
||||
if ds == nil {
|
||||
return
|
||||
}
|
||||
ds.shutdownOnce.Do(func() {
|
||||
if ds.Health != nil {
|
||||
ds.Health.Stop()
|
||||
}
|
||||
if ds.Dispatcher != nil {
|
||||
ds.Dispatcher.Stop()
|
||||
}
|
||||
if closer, ok := ds.Store.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
// AgentBridge has no Close method — its NATS subscriptions are cleaned up
|
||||
// when the NATS client is closed below.
|
||||
if ds.Nats != nil {
|
||||
ds.Nats.Close()
|
||||
}
|
||||
xlog.Info("Distributed services shut down")
|
||||
})
|
||||
}
|
||||
|
||||
// initDistributed validates distributed mode prerequisites and initializes
|
||||
// NATS, object storage, node registry, and instance identity.
|
||||
// Returns nil if distributed mode is not enabled.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
|
||||
if !cfg.Distributed.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
xlog.Info("Distributed mode enabled — validating prerequisites")
|
||||
|
||||
// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
|
||||
if err := cfg.Distributed.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate PostgreSQL is configured (auth DB must be PostgreSQL for distributed mode)
|
||||
if !cfg.Auth.Enabled {
|
||||
return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
|
||||
}
|
||||
if !isPostgresURL(cfg.Auth.DatabaseURL) {
|
||||
return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
|
||||
}
|
||||
|
||||
// Generate instance ID if not set
|
||||
if cfg.Distributed.InstanceID == "" {
|
||||
cfg.Distributed.InstanceID = uuid.New().String()
|
||||
}
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))
|
||||
|
||||
// Ensure NATS is closed if any subsequent initialization step fails.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
natsClient.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize object storage
|
||||
var store storage.ObjectStore
|
||||
if cfg.Distributed.StorageURL != "" {
|
||||
if cfg.Distributed.StorageBucket == "" {
|
||||
return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
|
||||
}
|
||||
s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
|
||||
Endpoint: cfg.Distributed.StorageURL,
|
||||
Region: cfg.Distributed.StorageRegion,
|
||||
Bucket: cfg.Distributed.StorageBucket,
|
||||
AccessKeyID: cfg.Distributed.StorageAccessKey,
|
||||
SecretAccessKey: cfg.Distributed.StorageSecretKey,
|
||||
ForcePathStyle: true, // required for MinIO
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing S3 storage: %w", err)
|
||||
}
|
||||
xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
|
||||
store = s3Store
|
||||
} else {
|
||||
// Fallback to filesystem storage in distributed mode (useful for single-node testing)
|
||||
fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing filesystem storage: %w", err)
|
||||
}
|
||||
xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore")
|
||||
store = fsStore
|
||||
}
|
||||
|
||||
// Initialize node registry (requires the auth DB which is PostgreSQL)
|
||||
if authDB == nil {
|
||||
return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
|
||||
}
|
||||
|
||||
registry, err := nodes.NewNodeRegistry(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing node registry: %w", err)
|
||||
}
|
||||
xlog.Info("Node registry initialized")
|
||||
|
||||
// Collect SmartRouter option values; the router itself is created after all
|
||||
// dependencies (including FileStager and Unloader) are ready.
|
||||
var routerAuthToken string
|
||||
if cfg.Distributed.RegistrationToken != "" {
|
||||
routerAuthToken = cfg.Distributed.RegistrationToken
|
||||
}
|
||||
var routerGalleriesJSON string
|
||||
if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
|
||||
routerGalleriesJSON = string(galleriesJSON)
|
||||
}
|
||||
|
||||
healthMon := nodes.NewHealthMonitor(registry, authDB,
|
||||
cfg.Distributed.HealthCheckIntervalOrDefault(),
|
||||
cfg.Distributed.StaleNodeThresholdOrDefault(),
|
||||
routerAuthToken,
|
||||
cfg.Distributed.PerModelHealthCheck,
|
||||
)
|
||||
|
||||
// Initialize job store
|
||||
jobStore, err := jobs.NewJobStore(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing job store: %w", err)
|
||||
}
|
||||
xlog.Info("Distributed job store initialized")
|
||||
|
||||
// Initialize job dispatcher
|
||||
dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)
|
||||
|
||||
// Initialize agent store
|
||||
agentStore, err := agents.NewAgentStore(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing agent store: %w", err)
|
||||
}
|
||||
xlog.Info("Distributed agent store initialized")
|
||||
|
||||
// Initialize agent event bridge
|
||||
agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)
|
||||
|
||||
// Start observable persister — captures observable_update events from workers
|
||||
// (which have no DB access) and persists them to PostgreSQL.
|
||||
if err := agentBridge.StartObservablePersister(); err != nil {
|
||||
xlog.Warn("Failed to start observable persister", "error", err)
|
||||
} else {
|
||||
xlog.Info("Observable persister started")
|
||||
}
|
||||
|
||||
// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
|
||||
distStores, err := distributed.InitStores(authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing distributed stores: %w", err)
|
||||
}
|
||||
|
||||
// Initialize file manager with local cache
|
||||
cacheDir := cfg.DataPath + "/cache"
|
||||
fileMgr, err := storage.NewFileManager(store, cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing file manager: %w", err)
|
||||
}
|
||||
xlog.Info("File manager initialized", "cacheDir", cacheDir)
|
||||
|
||||
// Create FileStager for distributed file transfer
|
||||
var fileStager nodes.FileStager
|
||||
if cfg.Distributed.StorageURL != "" {
|
||||
fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
|
||||
xlog.Info("File stager initialized (S3+NATS)")
|
||||
} else {
|
||||
fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
node, err := registry.Get(context.Background(), nodeID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if node.HTTPAddress == "" {
|
||||
return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
|
||||
}
|
||||
return node.HTTPAddress, nil
|
||||
}, cfg.Distributed.RegistrationToken)
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas
|
||||
reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: router,
|
||||
Unloader: remoteUnloader,
|
||||
DB: authDB,
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
modelAdapter := nodes.NewModelRouterAdapter(router)
|
||||
|
||||
success = true
|
||||
return &DistributedServices{
|
||||
Nats: natsClient,
|
||||
Store: store,
|
||||
Registry: registry,
|
||||
Router: router,
|
||||
Health: healthMon,
|
||||
Reconciler: reconciler,
|
||||
JobStore: jobStore,
|
||||
Dispatcher: dispatcher,
|
||||
AgentStore: agentStore,
|
||||
AgentBridge: agentBridge,
|
||||
DistStores: distStores,
|
||||
FileMgr: fileMgr,
|
||||
FileStager: fileStager,
|
||||
ModelAdapter: modelAdapter,
|
||||
Unloader: remoteUnloader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isPostgresURL reports whether url starts with one of the PostgreSQL
// connection-URI schemes, "postgres://" or "postgresql://". The check is a
// case-sensitive prefix match; the remainder of the string is not validated.
func isPostgresURL(url string) bool {
	for _, scheme := range []string{"postgres://", "postgresql://"} {
		if strings.HasPrefix(url, scheme) {
			return true
		}
	}
	return false
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
"github.com/mudler/edgevpn/pkg/node"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -146,14 +146,22 @@ func (a *Application) RestartP2P() error {
|
||||
return fmt.Errorf("P2P token is not set")
|
||||
}
|
||||
|
||||
// Create new context for P2P
|
||||
ctx, cancel := context.WithCancel(appConfig.Context)
|
||||
a.p2pCtx = ctx
|
||||
a.p2pCancel = cancel
|
||||
|
||||
// Get API address from config
|
||||
address := appConfig.APIAddress
|
||||
if address == "" {
|
||||
address = "127.0.0.1:8080" // default
|
||||
}
|
||||
|
||||
// Start P2P stack in a goroutine
|
||||
// Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel
|
||||
go func() {
|
||||
if err := a.StartP2P(); err != nil {
|
||||
xlog.Error("Failed to start P2P stack", "error", err)
|
||||
if a.p2pCancel != nil {
|
||||
a.p2pCancel()
|
||||
}
|
||||
cancel() // Cancel context on error
|
||||
}
|
||||
}()
|
||||
xlog.Info("P2P stack restarted with new settings")
|
||||
@@ -220,7 +228,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error {
|
||||
continue
|
||||
}
|
||||
|
||||
app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: model,
|
||||
Galleries: app.ApplicationConfig().Galleries,
|
||||
|
||||
@@ -13,15 +13,11 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -105,7 +101,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
|
||||
}
|
||||
application.authDB = authDB
|
||||
xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL))
|
||||
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
|
||||
|
||||
// Start session and expired API key cleanup goroutine
|
||||
go func() {
|
||||
@@ -127,96 +123,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
dbJobStore, err := jobs.NewJobStore(application.authDB)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to create job store for auth DB", "error", err)
|
||||
} else {
|
||||
application.agentJobService.SetDistributedJobStore(dbJobStore)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize distributed mode services (NATS, object storage, node registry)
|
||||
distSvc, err := initDistributed(options, application.authDB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
|
||||
}
|
||||
if distSvc != nil {
|
||||
application.distributed = distSvc
|
||||
// Wire remote model unloader so ShutdownModel works for remote nodes
|
||||
// Uses NATS to tell serve-backend nodes to Free + kill their backend process
|
||||
application.modelLoader.SetRemoteUnloader(distSvc.Unloader)
|
||||
// Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode
|
||||
application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter())
|
||||
// Wire DistributedModelStore so shutdown/list/watchdog can find remote models
|
||||
distStore := nodes.NewDistributedModelStore(
|
||||
model.NewInMemoryModelStore(),
|
||||
distSvc.Registry,
|
||||
)
|
||||
application.modelLoader.SetModelStore(distStore)
|
||||
// Start health monitor
|
||||
distSvc.Health.Start(options.Context)
|
||||
// Start replica reconciler for auto-scaling model replicas
|
||||
if distSvc.Reconciler != nil {
|
||||
go distSvc.Reconciler.Run(options.Context)
|
||||
}
|
||||
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
|
||||
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
|
||||
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
|
||||
// but does NOT set a workerFn — agent workers consume jobs from the same NATS queue.
|
||||
|
||||
// Wire model config loader so job events include model config for agent workers
|
||||
distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader)
|
||||
|
||||
// Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched
|
||||
if err := distSvc.Dispatcher.Start(options.Context); err != nil {
|
||||
return nil, fmt.Errorf("starting job dispatcher: %w", err)
|
||||
}
|
||||
// Start ephemeral file cleanup
|
||||
storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0)
|
||||
// Wire distributed backends into AgentJobService (before Start)
|
||||
if application.agentJobService != nil {
|
||||
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||
}
|
||||
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||
|
||||
// Wire NATS and gallery store into GalleryService for cross-instance progress/cancel
|
||||
if application.galleryService != nil {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Start AgentJobService (after distributed wiring so it knows whether to use local or NATS)
|
||||
if application.agentJobService != nil {
|
||||
if err := application.agentJobService.Start(options.Context); err != nil {
|
||||
return nil, fmt.Errorf("starting agent job service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
xlog.Error("error installing models", "error", err)
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -242,13 +154,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -272,7 +184,6 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
xlog.Debug("Context canceled, shutting down")
|
||||
application.distributed.Shutdown()
|
||||
err := application.ModelLoader().StopAllGRPC()
|
||||
if err != nil {
|
||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||
@@ -296,7 +207,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
var backendErr error
|
||||
_, backendErr = application.ModelLoader().Load(o...)
|
||||
if backendErr != nil {
|
||||
return nil, backendErr
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -443,6 +354,11 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings.ParallelBackendRequests != nil {
|
||||
if !options.ParallelBackendRequests {
|
||||
options.ParallelBackendRequests = *settings.ParallelBackendRequests
|
||||
}
|
||||
}
|
||||
if settings.MemoryReclaimerEnabled != nil {
|
||||
// Only apply if current value is default (false), suggesting it wasn't set from env var
|
||||
if !options.MemoryReclaimerEnabled {
|
||||
@@ -516,18 +432,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tracing settings
|
||||
if settings.EnableTracing != nil {
|
||||
if !options.EnableTracing {
|
||||
options.EnableTracing = *settings.EnableTracing
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxItems != nil {
|
||||
if options.TracingMaxItems == 0 {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Debug("Runtime settings loaded from runtime_settings.json")
|
||||
}
|
||||
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -27,7 +27,7 @@ type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
AudioOutput string
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
|
||||
}
|
||||
|
||||
@@ -36,27 +36,6 @@ type TokenUsage struct {
|
||||
Completion int
|
||||
TimingPromptProcessing float64
|
||||
TimingTokenGeneration float64
|
||||
ChatDeltas []*proto.ChatDelta // per-chunk deltas from C++ autoparser (only set during streaming)
|
||||
}
|
||||
|
||||
// HasChatDeltaContent returns true if any chat delta carries content or reasoning text.
|
||||
// Used to decide whether to prefer C++ autoparser deltas over Go-side tag extraction.
|
||||
func (t TokenUsage) HasChatDeltaContent() bool {
|
||||
for _, d := range t.ChatDeltas {
|
||||
if d.Content != "" || d.ReasoningContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChatDeltaReasoningAndContent extracts accumulated reasoning and content from chat deltas.
|
||||
func (t TokenUsage) ChatDeltaReasoningAndContent() (reasoning, content string) {
|
||||
for _, d := range t.ChatDeltas {
|
||||
content += d.Content
|
||||
reasoning += d.ReasoningContent
|
||||
}
|
||||
return reasoning, content
|
||||
}
|
||||
|
||||
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
||||
@@ -68,18 +47,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS)
|
||||
modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelName := c.Name
|
||||
if modelName == "" {
|
||||
modelName = c.Model
|
||||
}
|
||||
if !slices.Contains(modelNames, modelName) {
|
||||
if !slices.Contains(modelNames, c.Name) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
if err != nil {
|
||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||
//return nil, err
|
||||
@@ -192,9 +167,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
||||
}
|
||||
|
||||
// Attach per-chunk chat deltas to tokenUsage so the callback can use them
|
||||
tokenUsage.ChatDeltas = reply.ChatDeltas
|
||||
|
||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
@@ -224,9 +196,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if len(msg) == 0 {
|
||||
tokenCallback("", tokenUsage)
|
||||
}
|
||||
|
||||
// Clear per-chunk deltas so they don't leak to the next chunk
|
||||
tokenUsage.ChatDeltas = nil
|
||||
})
|
||||
if len(allChatDeltas) > 0 {
|
||||
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
||||
@@ -283,12 +252,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"function_template": c.TemplateConfig.Functions,
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
. "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -108,111 +107,3 @@ var _ = Describe("LLM tests", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TokenUsage ChatDelta helpers", func() {
|
||||
Describe("HasChatDeltaContent", func() {
|
||||
It("should return false when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when ChatDeltas is empty", func() {
|
||||
usage := TokenUsage{ChatDeltas: []*pb.ChatDelta{}}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when all deltas have empty content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "", ReasoningContent: ""},
|
||||
{Content: ""},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return true when a delta has content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello"},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has reasoning content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has both content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello", ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ChatDeltaReasoningAndContent", func() {
|
||||
It("should return empty strings when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate content from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "Hello"},
|
||||
{Content: " world"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(content).To(Equal("Hello world"))
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate reasoning from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "step 1"},
|
||||
{ReasoningContent: " step 2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("step 1 step 2"))
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should separate reasoning and content from mixed deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking"},
|
||||
{Content: "answer"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("thinking"))
|
||||
Expect(content).To(Equal("answer"))
|
||||
})
|
||||
|
||||
It("should handle deltas with both fields set", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "a", ReasoningContent: "r1"},
|
||||
{Content: "b", ReasoningContent: "r2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("r1r2"))
|
||||
Expect(content).To(Equal("ab"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -59,7 +59,9 @@ func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...mo
|
||||
grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath)
|
||||
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
|
||||
|
||||
defOpts = append(defOpts, model.EnableParallelRequests)
|
||||
if so.ParallelBackendRequests {
|
||||
defOpts = append(defOpts, model.EnableParallelRequests)
|
||||
}
|
||||
|
||||
if c.GRPC.Attempts != 0 {
|
||||
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
|
||||
@@ -84,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 {
|
||||
}
|
||||
|
||||
if seed == config.RAND_SEED {
|
||||
seed = rand.Int32()
|
||||
seed = rand.Int31()
|
||||
}
|
||||
|
||||
return seed
|
||||
@@ -250,7 +252,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
TopP: float32(*c.TopP),
|
||||
NDraft: c.NDraft,
|
||||
TopK: int32(*c.TopK),
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VAD(request *schema.VADRequest,
|
||||
ctx context.Context,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := proto.VADRequest{
|
||||
Audio: request.Audio,
|
||||
}
|
||||
resp, err := vadModel.VAD(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segments := []schema.VADSegment{}
|
||||
for _, s := range resp.Segments {
|
||||
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
|
||||
}
|
||||
|
||||
return &schema.VADResponse{
|
||||
Segments: segments,
|
||||
}, nil
|
||||
}
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VAD(request *schema.VADRequest,
|
||||
ctx context.Context,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := proto.VADRequest{
|
||||
Audio: request.Audio,
|
||||
}
|
||||
resp, err := vadModel.VAD(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segments := []schema.VADSegment{}
|
||||
for _, s := range resp.Segments {
|
||||
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
|
||||
}
|
||||
|
||||
return &schema.VADResponse{
|
||||
Segments: segments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
appConfig := r.buildAppConfig()
|
||||
|
||||
poolService, err := agentpool.NewAgentPoolService(appConfig)
|
||||
poolService, err := services.NewAgentPoolService(appConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent pool service: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,172 +4,211 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
)
|
||||
|
||||
var _ = Describe("AgentRunCMD", func() {
|
||||
Describe("loadAgentConfig", func() {
|
||||
It("loads agent config from file", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
|
||||
// Create a temporary agent config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
cfg := state.AgentConfig{
|
||||
Name: "test-agent",
|
||||
Model: "llama3",
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
|
||||
cfg := state.AgentConfig{
|
||||
Name: "test-agent",
|
||||
Model: "llama3",
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configFile, data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(loaded.Name).To(Equal("test-agent"))
|
||||
Expect(loaded.Model).To(Equal("llama3"))
|
||||
})
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "test-agent" {
|
||||
t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "llama3" {
|
||||
t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
|
||||
}
|
||||
}
|
||||
|
||||
It("loads agent config from pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"my-agent": {
|
||||
Model: "gpt-4",
|
||||
Description: "A test agent",
|
||||
SystemPrompt: "Hello",
|
||||
},
|
||||
"other-agent": {
|
||||
Model: "llama3",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
pool := map[string]state.AgentConfig{
|
||||
"my-agent": {
|
||||
Model: "gpt-4",
|
||||
Description: "A test agent",
|
||||
SystemPrompt: "Hello",
|
||||
},
|
||||
"other-agent": {
|
||||
Model: "llama3",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "my-agent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "my-agent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(loaded.Name).To(Equal("my-agent"))
|
||||
Expect(loaded.Model).To(Equal("gpt-4"))
|
||||
})
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "my-agent" {
|
||||
t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "gpt-4" {
|
||||
t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
|
||||
}
|
||||
}
|
||||
|
||||
It("returns error for missing agent in pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"existing-agent": {Model: "llama3"},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
pool := map[string]state.AgentConfig{
|
||||
"existing-agent": {Model: "llama3"},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "nonexistent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "nonexistent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
_, err = cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
_, err = cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing agent, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
It("returns error when no pool.json exists", func() {
|
||||
cmd := &AgentRunCMD{
|
||||
StateDir: GinkgoT().TempDir(),
|
||||
}
|
||||
func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
|
||||
cmd := &AgentRunCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no pool.json exists, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
It("returns error for config with no name", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
cfg := state.AgentConfig{
|
||||
Model: "llama3",
|
||||
}
|
||||
data, _ := json.MarshalIndent(cfg, "", " ")
|
||||
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
|
||||
cmd := &AgentRunCMD{
|
||||
APIURL: "http://localhost:9090",
|
||||
APIKey: "secret",
|
||||
DefaultModel: "my-model",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
if cfg.APIURL != "http://localhost:9090" {
|
||||
t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
|
||||
}
|
||||
if cfg.APIKey != "secret" {
|
||||
t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
|
||||
}
|
||||
if cfg.Model != "my-model" {
|
||||
t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
Describe("applyOverrides", func() {
|
||||
It("applies overrides to empty fields", func() {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
}
|
||||
func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
Model: "existing-model",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
APIURL: "http://localhost:9090",
|
||||
APIKey: "secret",
|
||||
DefaultModel: "my-model",
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
DefaultModel: "override-model",
|
||||
}
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
Expect(cfg.APIURL).To(Equal("http://localhost:9090"))
|
||||
Expect(cfg.APIKey).To(Equal("secret"))
|
||||
Expect(cfg.Model).To(Equal("my-model"))
|
||||
})
|
||||
if cfg.Model != "existing-model" {
|
||||
t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
It("does not overwrite existing model", func() {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
Model: "existing-model",
|
||||
}
|
||||
func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
DefaultModel: "override-model",
|
||||
}
|
||||
// Agent config with no name
|
||||
cfg := state.AgentConfig{
|
||||
Model: "llama3",
|
||||
}
|
||||
data, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configFile, data, 0644)
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
Expect(cfg.Model).To(Equal("existing-model"))
|
||||
})
|
||||
})
|
||||
})
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for config with no name, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("AgentListCMD", func() {
|
||||
It("runs without error when no pool file exists", func() {
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: GinkgoT().TempDir(),
|
||||
}
|
||||
Expect(cmd.Run(nil)).To(Succeed())
|
||||
})
|
||||
func TestAgentListCMD_NoPoolFile(t *testing.T) {
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
|
||||
It("runs without error with agents in pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
// Should not error, just print "no agents found"
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"agent-a": {Model: "llama3", Description: "First agent"},
|
||||
"agent-b": {Model: "gpt-4"},
|
||||
}
|
||||
data, _ := json.MarshalIndent(pool, "", " ")
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
func TestAgentListCMD_WithAgents(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
Expect(cmd.Run(nil)).To(Succeed())
|
||||
})
|
||||
})
|
||||
pool := map[string]state.AgentConfig{
|
||||
"agent-a": {Model: "llama3", Description: "First agent"},
|
||||
"agent-b": {Model: "gpt-4"},
|
||||
}
|
||||
data, _ := json.MarshalIndent(pool, "", " ")
|
||||
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
|
||||
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user