mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-20 06:35:41 -04:00
Compare commits
83 Commits
feat/turbo
...
feat/backe
| Author | SHA1 | Date |
|---|---|---|
| | 5fe87cb0d5 | |
| | 6dd37a95c4 | |
| | ee00a10836 | |
| | 948f3bfaa4 | |
| | 1e083cd870 | |
| | b19e60d03a | |
| | 4d463e9f0d | |
| | ae4ae5f425 | |
| | 7c1865b307 | |
| | 62a674ce12 | |
| | c39213443b | |
| | 606f462da4 | |
| | 5c35e85fe2 | |
| | 062e0d0d00 | |
| | d4cd6c284f | |
| | 3bb8b65d31 | |
| | 9748a1cbc6 | |
| | 6bc76dda6d | |
| | e1a6010874 | |
| | 706cf5d43c | |
| | 13a6ed709c | |
| | 85be4ff03c | |
| | b0d9ce4905 | |
| | 7081b54c09 | |
| | 2b05420f95 | |
| | b64347b6aa | |
| | e00ce981f0 | |
| | 285f7d4340 | |
| | ea6e850809 | |
| | b7247fc148 | |
| | 39c6b3ed66 | |
| | 0e9d1a6588 | |
| | 510d6759fe | |
| | 154fa000d3 | |
| | 0526e60f8d | |
| | db600fb5b2 | |
| | 9ac1bdc587 | |
| | fdc9f7bf35 | |
| | 8e59346091 | |
| | e6e4e19633 | |
| | 505c417fa7 | |
| | 17215f6fbc | |
| | bccaba1f66 | |
| | 0f9d516a6c | |
| | 33b124c6f1 | |
| | 6b8007e88e | |
| | b3837c2078 | |
| | 92f99b1ec3 | |
| | ad232fdb1a | |
| | 11637b5a1b | |
| | 0dda4fe6f0 | |
| | 773489eeb1 | |
| | 06fbe48b3f | |
| | 232e324a68 | |
| | 39c954764c | |
| | 9b7d5513fc | |
| | 84cd8c0e7f | |
| | d990f2790c | |
| | 53deeb1107 | |
| | c5a840f6af | |
| | 6d9d77d590 | |
| | 6f304d1201 | |
| | 557d0f0f04 | |
| | b7e3589875 | |
| | 716ddd697b | |
| | 223deb908d | |
| | 9f8821bba8 | |
| | 84e51b68ef | |
| | 7962dd16f7 | |
| | a1466b305a | |
| | 57c0026715 | |
| | 1ed6b9e5ed | |
| | e4ee74354f | |
| | 8577bdcebc | |
| | 0d489c7a0d | |
| | 11dc54bda9 | |
| | 7e0b73deaa | |
| | c0a023d13d | |
| | 0d3ae1c295 | |
| | e9f10f2f50 | |
| | b95b0b72ff | |
| | 26f1b94f4d | |
| | 2d40725ca2 | |
111
.agents/adding-gallery-models.md
Normal file
@@ -0,0 +1,111 @@
# Adding GGUF Models from HuggingFace to the Gallery

When adding a GGUF model from HuggingFace to the LocalAI model gallery, follow this guide.

## Gallery file

All models are defined in `gallery/index.yaml`. Find the appropriate section (embedding models near other embeddings, chat models near similar chat models) and add a new entry.

## Getting the SHA256

GGUF files on HuggingFace expose their SHA256 via the `x-linked-etag` HTTP header. Fetch it with:

```bash
curl -sI "https://huggingface.co/<org>/<repo>/resolve/main/<filename>.gguf" | grep -i x-linked-etag
```

The value (without quotes) is the SHA256 hash. Example:

```bash
curl -sI "https://huggingface.co/ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/resolve/main/embeddinggemma-300m-qat-Q8_0.gguf" | grep -i x-linked-etag
# x-linked-etag: "6fa0c02a9c302be6f977521d399b4de3a46310a4f2621ee0063747881b673f67"
```

**Important**: Pay attention to exact filename casing: HuggingFace filenames are case-sensitive (e.g., `Q8_0` vs `q8_0`). Check the repo's file listing to get the exact name.

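The same lookup can be sketched in Go, for cases where the hash has to be fetched from a program rather than a shell. This is a minimal sketch and not part of the guide or of LocalAI: the helper names `cleanEtag`, `looksLikeSHA256`, and `linkedEtag` are hypothetical, and it assumes the header sits on the first response (which is what `curl -sI` without `-L` reads), so redirects are deliberately not followed.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"net/http"
	"os"
	"strings"
)

// cleanEtag strips an optional weak-validator prefix and the surrounding
// double quotes from an ETag-style header value.
func cleanEtag(raw string) string {
	v := strings.TrimSpace(raw)
	v = strings.TrimPrefix(v, "W/")
	return strings.Trim(v, `"`)
}

// looksLikeSHA256 reports whether s is exactly 64 hexadecimal characters.
func looksLikeSHA256(s string) bool {
	if len(s) != 64 {
		return false
	}
	_, err := hex.DecodeString(s)
	return err == nil
}

// linkedEtag sends a HEAD request without following redirects and returns
// the cleaned x-linked-etag value from that first response.
func linkedEtag(url string) (string, error) {
	client := &http.Client{
		// Stop at the first response instead of following the redirect.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	resp, err := client.Head(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	sum := cleanEtag(resp.Header.Get("X-Linked-Etag"))
	if !looksLikeSHA256(sum) {
		return "", fmt.Errorf("no SHA256-shaped x-linked-etag in response (status %s)", resp.Status)
	}
	return sum, nil
}

func main() {
	if len(os.Args) < 2 {
		fmt.Println("usage: etag <resolve-url>")
		return
	}
	sum, err := linkedEtag(os.Args[1])
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println(sum)
}
```

Checking that the value is 64 hex characters guards against pasting an ordinary ETag into the `sha256` field when the header is missing.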
## Entry format: Embedding models

Embedding models use `gallery/virtual.yaml` as the base config and set `embeddings: true`:

```yaml
- name: "model-name"
  url: github:mudler/LocalAI/gallery/virtual.yaml@master
  urls:
    - https://huggingface.co/<original-model-org>/<original-model-name>
    - https://huggingface.co/<gguf-org>/<gguf-repo-name>
  description: |
    Short description of the model, its size, and capabilities.
  tags:
    - embeddings
  overrides:
    backend: llama-cpp
    embeddings: true
    parameters:
      model: <filename>.gguf
  files:
    - filename: <filename>.gguf
      uri: huggingface://<gguf-org>/<gguf-repo-name>/<filename>.gguf
      sha256: <sha256-hash>
```

## Entry format: Chat/LLM models

Chat models typically reference a template config (e.g., `gallery/gemma.yaml`, `gallery/chatml.yaml`) that defines the prompt format. Use YAML anchors (`&name` / `*name`) if adding multiple quantization variants of the same model:

```yaml
- &model-anchor
  url: "github:mudler/LocalAI/gallery/<template>.yaml@master"
  name: "model-name"
  icon: https://example.com/icon.png
  license: <license>
  urls:
    - https://huggingface.co/<org>/<model>
    - https://huggingface.co/<gguf-org>/<gguf-repo>
  description: |
    Model description.
  tags:
    - llm
    - gguf
    - gpu
    - cpu
  overrides:
    parameters:
      model: <filename>-Q4_K_M.gguf
  files:
    - filename: <filename>-Q4_K_M.gguf
      sha256: <sha256>
      uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q4_K_M.gguf
```

To add a variant (e.g., different quantization), use YAML merge:

```yaml
- !!merge <<: *model-anchor
  name: "model-name-q8"
  overrides:
    parameters:
      model: <filename>-Q8_0.gguf
  files:
    - filename: <filename>-Q8_0.gguf
      sha256: <sha256>
      uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q8_0.gguf
```

## Available template configs

Look at existing `.yaml` files in `gallery/` to find the right prompt template for your model architecture:

- `gemma.yaml`: Gemma-family models (gemma, embeddinggemma, etc.)
- `chatml.yaml`: ChatML format (many Mistral/OpenHermes models)
- `deepseek.yaml`: DeepSeek models
- `virtual.yaml`: Minimal base (good for embedding models that don't need chat templates)

## Checklist

1. **Find the GGUF file** on HuggingFace and note the exact filename (case-sensitive)
2. **Get the SHA256** using the `curl -sI` + `x-linked-etag` method above
3. **Choose the right template** config from `gallery/` based on model architecture
4. **Add the entry** to `gallery/index.yaml` near similar models
5. **Set `embeddings: true`** if it's an embedding model
6. **Include both URLs**: the original model page and the GGUF repo
7. **Write a description** that mentions model size, capabilities, and quantization type
1
.github/gallery-agent/agent.go
vendored
@@ -133,6 +133,7 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||
result, err := cogito.ExecuteTools(llm, fragment,
|
||||
cogito.WithIterations(3),
|
||||
cogito.WithMaxAttempts(3),
|
||||
cogito.DisableSinkState,
|
||||
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
15
.github/gallery-agent/gallery.go
vendored
@@ -79,7 +79,20 @@ func generateYAMLEntry(model ProcessedModel, quantization string) string {
|
||||
description = cleanTextContent(description)
|
||||
formattedDescription := formatTextContent(description)
|
||||
|
||||
configFile := formatTextContent(modelConfig.ConfigFile)
|
||||
// Strip name and description from config file since they are
|
||||
// already present at the gallery entry level and should not
|
||||
// appear under overrides.
|
||||
configFileContent := modelConfig.ConfigFile
|
||||
var cfgMap map[string]any
|
||||
if err := yaml.Unmarshal([]byte(configFileContent), &cfgMap); err == nil {
|
||||
delete(cfgMap, "name")
|
||||
delete(cfgMap, "description")
|
||||
if cleaned, err := yaml.Marshal(cfgMap); err == nil {
|
||||
configFileContent = string(cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
configFile := formatTextContent(configFileContent)
|
||||
|
||||
filesYAML, _ := yaml.Marshal(modelConfig.Files)
|
||||
|
||||
|
||||
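The hunk above decodes the config file into a generic map, deletes `name` and `description`, and re-encodes it, falling back to the original text if either step fails. The sketch below shows the same decode, delete, re-encode pattern in isolation. It is an illustration only: the YAML package used by the agent is not in the standard library, so `encoding/json` is swapped in here as a stand-in, and `stripKeys` is a hypothetical helper name.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stripKeys decodes doc into a generic map, deletes the given top-level
// keys, and re-encodes the result. On any error the input is returned
// unchanged, mirroring the fallback behaviour in the diff.
func stripKeys(doc string, keys ...string) string {
	var m map[string]any
	if err := json.Unmarshal([]byte(doc), &m); err != nil {
		return doc
	}
	for _, k := range keys {
		delete(m, k)
	}
	out, err := json.Marshal(m)
	if err != nil {
		return doc
	}
	return string(out)
}

func main() {
	in := `{"name":"demo","description":"text","backend":"llama-cpp","embeddings":true}`
	fmt.Println(stripKeys(in, "name", "description"))
	// Undecodable input is passed through untouched.
	fmt.Println(stripKeys("not a document", "name"))
}
```

One trade-off of round-tripping through a map is that the original key order is not preserved, and with YAML any comments in the source are typically lost as well. That is usually acceptable for generated gallery entries but worth knowing when diffing the output.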
2
.github/gallery-agent/testing.go
vendored
@@ -17,7 +17,7 @@ func runSyntheticMode() error {
|
||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||
|
||||
var models []ProcessedModel
|
||||
for i := range numModels {
|
||||
for range numModels {
|
||||
model := generator.GenerateProcessedModel()
|
||||
models = append(models, model)
|
||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||
|
||||
170
.github/workflows/backend.yml
vendored
@@ -105,6 +105,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -561,6 +574,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -965,6 +991,32 @@ jobs:
|
||||
backend: "mlx-distributed"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1108,6 +1160,32 @@ jobs:
|
||||
backend: "stablediffusion-ggml"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-sam3-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1644,6 +1722,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# SYCL additional backends
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
@@ -1842,6 +1946,59 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sam3-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1894,6 +2051,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-sam3-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# whisper
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
@@ -34,6 +34,10 @@ jobs:
|
||||
variable: "ACESTEP_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/acestep-cpp/Makefile"
|
||||
- repository: "PABannier/sam3.cpp"
|
||||
variable: "SAM3_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/sam3-cpp/Makefile"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/gallery-agent.yaml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
- name: Run gallery agent
|
||||
env:
|
||||
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
OPENAI_MODE: Qwen3.5-2B-GGUF
|
||||
OPENAI_MODEL: Qwen3.5-2B-GGUF
|
||||
OPENAI_BASE_URL: "http://localhost:8080"
|
||||
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||
|
||||
23
.github/workflows/test-extra.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -528,3 +529,25 @@ jobs:
|
||||
- name: Test voxtral
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
||||
tests-kokoros:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.kokoros == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake pkg-config protobuf-compiler clang libclang-dev
|
||||
sudo apt-get install -y espeak-ng libespeak-ng-dev libsonic-dev libpcaudio-dev libopus-dev libssl-dev
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
- name: Build kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros kokoros-grpc
|
||||
- name: Test kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros test
|
||||
|
||||
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
||||
[submodule "docs/themes/hugo-theme-relearn"]
|
||||
path = docs/themes/hugo-theme-relearn
|
||||
url = https://github.com/McShelby/hugo-theme-relearn.git
|
||||
[submodule "backend/rust/kokoros/sources/Kokoros"]
|
||||
path = backend/rust/kokoros/sources/Kokoros
|
||||
url = https://github.com/lucasjinreal/Kokoros
|
||||
|
||||
@@ -13,6 +13,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
17
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -148,7 +148,6 @@ test-models/testmodel.ggml:
|
||||
mkdir -p test-dir
|
||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
||||
wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
||||
cp tests/models_fixtures/* test-models
|
||||
|
||||
@@ -429,9 +428,11 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/qwen-asr
|
||||
$(MAKE) -C backend/python/nemo
|
||||
$(MAKE) -C backend/python/voxcpm
|
||||
$(MAKE) -C backend/python/faster-whisper
|
||||
$(MAKE) -C backend/python/whisperx
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -449,9 +450,11 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/qwen-asr test
|
||||
$(MAKE) -C backend/python/nemo test
|
||||
$(MAKE) -C backend/python/voxcpm test
|
||||
$(MAKE) -C backend/python/faster-whisper test
|
||||
$(MAKE) -C backend/python/whisperx test
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
IMAGE_TYPE?=core
|
||||
@@ -587,6 +590,12 @@ BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
|
||||
# Rust backends
|
||||
BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
|
||||
# C++ backends (Go wrapper with purego)
|
||||
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
define docker-build-backend
|
||||
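The `BACKEND_*` values in this hunk are five pipe-separated fields, for example `kokoros|rust|.|false|true`. The sketch below parses one such value into named fields. It is an illustration, not code from the repository: `BackendSpec` and `parseBackendSpec` are hypothetical names, and the mapping of the five fields to name, Dockerfile type, build context, progress flag, and needs-backend-arg is an assumption based on the argument order in the `docker-build-backend` usage comment.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// BackendSpec names the five fields of a BACKEND_* value.
// The field meanings are assumed from the Makefile usage comment.
type BackendSpec struct {
	Name            string
	DockerfileType  string
	BuildContext    string
	ProgressFlag    bool
	NeedsBackendArg bool
}

// parseBackendSpec splits "name|type|context|progress|needsArg" and
// converts the last two fields to booleans.
func parseBackendSpec(s string) (BackendSpec, error) {
	parts := strings.Split(s, "|")
	if len(parts) != 5 {
		return BackendSpec{}, fmt.Errorf("expected 5 fields, got %d", len(parts))
	}
	progress, err := strconv.ParseBool(parts[3])
	if err != nil {
		return BackendSpec{}, fmt.Errorf("progress flag: %w", err)
	}
	needsArg, err := strconv.ParseBool(parts[4])
	if err != nil {
		return BackendSpec{}, fmt.Errorf("needs-backend-arg: %w", err)
	}
	return BackendSpec{
		Name:            parts[0],
		DockerfileType:  parts[1],
		BuildContext:    parts[2],
		ProgressFlag:    progress,
		NeedsBackendArg: needsArg,
	}, nil
}

func main() {
	for _, raw := range []string{
		"kokoros|rust|.|false|true",
		"sam3-cpp|golang|.|false|true",
	} {
		spec, err := parseBackendSpec(raw)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Printf("%+v\n", spec)
	}
}
```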
@@ -645,12 +654,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
31
README.md
@@ -42,16 +42,38 @@ Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
|
||||
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
||||
|
||||
## Screenshots
|
||||
|
||||
### Chat, Model gallery
|
||||
## Guided tour
|
||||
|
||||
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||
|
||||
### Agents
|
||||
<details>
|
||||
|
||||
<summary>
|
||||
Click to see more!
|
||||
</summary>
|
||||
|
||||
#### User and auth
|
||||
|
||||
https://github.com/user-attachments/assets/228fa9ad-81a3-4d43-bfb9-31557e14a36c
|
||||
|
||||
#### Agents
|
||||
|
||||
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||
|
||||
#### Usage metrics per user
|
||||
|
||||
https://github.com/user-attachments/assets/cbb03379-23b4-4e3d-bd26-d152f057007f
|
||||
|
||||
#### Fine-tuning and Quantization
|
||||
|
||||
https://github.com/user-attachments/assets/5ba4ace9-d3df-4795-b7d4-b0b404ea71ee
|
||||
|
||||
#### WebRTC
|
||||
|
||||
https://github.com/user-attachments/assets/ed88e34c-fed3-4b83-8a67-4716a9feeb7b
|
||||
|
||||
</details>
|
||||
|
||||
## Quickstart
|
||||
|
||||
### macOS
|
||||
@@ -174,6 +196,7 @@ See the full [Backend & Model Compatibility Table](https://localai.io/model-comp
|
||||
- [Build from source](https://localai.io/basics/build/)
|
||||
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||
|
||||
|
||||
39
backend/Dockerfile.rust
Normal file
@@ -0,0 +1,39 @@
ARG BASE_IMAGE=ubuntu:24.04

FROM ${BASE_IMAGE} AS builder
ARG BACKEND=kokoros
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETARCH
ARG TARGETVARIANT

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        git ccache \
        ca-certificates \
        make cmake wget \
        curl unzip \
        clang \
        pkg-config \
        libssl-dev \
        espeak-ng libespeak-ng-dev \
        libsonic-dev libpcaudio-dev \
        libopus-dev \
        protobuf-compiler && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"

COPY . /LocalAI

RUN git config --global --add safe.directory /LocalAI

RUN make -C /LocalAI/backend/rust/${BACKEND} build

FROM scratch
ARG BACKEND=kokoros

COPY --from=builder /LocalAI/backend/rust/${BACKEND}/package/. ./
@@ -444,6 +444,10 @@ message Message {
|
||||
|
||||
message DetectOptions {
|
||||
string src = 1;
|
||||
string prompt = 2; // Text prompt (for SAM 3 PCS mode)
|
||||
repeated float points = 3; // Point coordinates as [x1, y1, label1, x2, y2, label2, ...] (label: 1=pos, 0=neg)
|
||||
repeated float boxes = 4; // Box coordinates as [x1, y1, x2, y2, ...]
|
||||
float threshold = 5; // Detection confidence threshold
|
||||
}
|
||||
|
||||
message Detection {
|
||||
@@ -453,6 +457,7 @@ message Detection {
|
||||
float height = 4;
|
||||
float confidence = 5;
|
||||
string class_name = 6;
|
||||
bytes mask = 7; // PNG-encoded binary segmentation mask
|
||||
}
|
||||
|
||||
message DetectResponse {
|
||||
|
||||
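The new `points` field in `DetectOptions` is a flat list of floats laid out as `[x1, y1, label1, x2, y2, label2, ...]`, with label 1 for a positive point and 0 for a negative one. The sketch below shows one way a caller might pack and unpack that layout. It is an assumption-laden illustration: the `Point` type and both helper functions are hypothetical and are not part of the generated gRPC code.

```go
package main

import "fmt"

// Point is a hypothetical helper type: one prompt point with a
// positive (label 1) or negative (label 0) marker.
type Point struct {
	X, Y     float32
	Positive bool
}

// flattenPoints packs points as [x1, y1, label1, x2, y2, label2, ...].
func flattenPoints(pts []Point) []float32 {
	out := make([]float32, 0, len(pts)*3)
	for _, p := range pts {
		label := float32(0)
		if p.Positive {
			label = 1
		}
		out = append(out, p.X, p.Y, label)
	}
	return out
}

// unflattenPoints reverses flattenPoints and rejects slices whose
// length is not a multiple of three.
func unflattenPoints(flat []float32) ([]Point, error) {
	if len(flat)%3 != 0 {
		return nil, fmt.Errorf("points length %d is not a multiple of 3", len(flat))
	}
	pts := make([]Point, 0, len(flat)/3)
	for i := 0; i < len(flat); i += 3 {
		pts = append(pts, Point{X: flat[i], Y: flat[i+1], Positive: flat[i+2] == 1})
	}
	return pts, nil
}

func main() {
	flat := flattenPoints([]Point{
		{X: 120, Y: 80, Positive: true},
		{X: 30, Y: 200},
	})
	fmt.Println(flat) // [120 80 1 30 200 0]

	pts, err := unflattenPoints(flat)
	fmt.Println(len(pts), err)
}
```

Validating the length on the receiving side matters because a flat `repeated float` carries no structure of its own, so a truncated triple would otherwise be read as a misaligned point.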
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0fcb3760b2b9a3a496ef14621a7e4dad7a8df90f
|
||||
LLAMA_VERSION?=e62fa13c2497b2cd1958cb496e9489e86bbd5182
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -40,45 +40,41 @@ using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
||||
// gRPC bearer token auth for distributed mode.
|
||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
||||
public:
|
||||
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
||||
|
||||
bool IsBlocking() const override { return false; }
|
||||
// Cached auth token — empty means auth is disabled.
|
||||
static std::string g_grpc_auth_token;
|
||||
|
||||
grpc::Status Process(const InputMetadata& auth_metadata,
|
||||
grpc::AuthContext* /*context*/,
|
||||
OutputMetadata* /*consumed_auth_metadata*/,
|
||||
OutputMetadata* /*response_metadata*/) override {
|
||||
auto it = auth_metadata.find("authorization");
|
||||
if (it != auth_metadata.end()) {
|
||||
std::string expected = "Bearer " + token_;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
// Constant-time comparison
|
||||
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string token_;
|
||||
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
// Returns OK when auth is disabled or the token matches.
|
||||
static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||
if (g_grpc_auth_token.empty()) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
};
|
||||
auto metadata = context->client_metadata();
|
||||
auto it = metadata.find("authorization");
|
||||
if (it != metadata.end()) {
|
||||
std::string expected = "Bearer " + g_grpc_auth_token;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
if (expected.size() == got.size() &&
|
||||
ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
// END LocalAI
|
||||
|
||||
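The token check in the hunk above relies on a comparison that does not stop at the first differing byte. A minimal self-contained sketch of that idea follows; `token_matches` is a hypothetical helper written for illustration and is not part of the diff. The length check still exits early, so only the token contents, not the token length, are protected against timing differences.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Compare two equal-length buffers without exiting early on the first mismatch.
// Every byte is visited, and the differences are OR-ed into a single accumulator.
static int ct_memcmp(const void* a, const void* b, std::size_t n) {
    const unsigned char* pa = static_cast<const unsigned char*>(a);
    const unsigned char* pb = static_cast<const unsigned char*>(b);
    unsigned char result = 0;
    for (std::size_t i = 0; i < n; i++) {
        result |= static_cast<unsigned char>(pa[i] ^ pb[i]);
    }
    return result;
}

// Hypothetical helper (not in the diff): returns true when auth is disabled
// (empty configured token) or when the header value equals "Bearer <token>".
static bool token_matches(const std::string& configured_token,
                          const std::string& header_value) {
    if (configured_token.empty()) {
        return true;
    }
    const std::string expected = "Bearer " + configured_token;
    return expected.size() == header_value.size() &&
           ct_memcmp(expected.data(), header_value.data(), expected.size()) == 0;
}
```

A caller would pass the value of the `authorization` metadata entry as `header_value` and map a `false` result to an `UNAUTHENTICATED` status.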
@@ -288,6 +284,12 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["ignore_eos"] = predict->ignoreeos();
|
||||
data["embeddings"] = predict->embeddings();
|
||||
|
||||
// Speculative decoding per-request overrides
|
||||
// NDraft maps to speculative.n_max (maximum draft tokens per speculation step)
|
||||
if (predict->ndraft() > 0) {
|
||||
data["speculative.n_max"] = predict->ndraft();
|
||||
}
|
||||
|
||||
// Add the correlationid to json data
|
||||
data["correlation_id"] = predict->correlationid();
|
||||
|
||||
@@ -406,6 +408,16 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->mmproj().empty()) {
|
||||
params.mmproj.path = request->mmproj();
|
||||
}
|
||||
|
||||
// Draft model for speculative decoding
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.mparams_dft.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
params.model_alias.insert(request->modelfile());
|
||||
if (!request->cachetypekey().empty()) {
|
||||
@@ -613,6 +625,48 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (8)
|
||||
}
|
||||
}
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
auto type = common_speculative_type_from_name(optval_str);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_min") || !strcmp(optname, "draft_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_min = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_min") || !strcmp(optname, "draft_p_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_min = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_split")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_split = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_n") || !strcmp(optname, "ngram_size_n")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_n = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_m") || !strcmp(optname, "ngram_size_m")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_m = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_min_hits") || !strcmp(optname, "ngram_min_hits")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_min_hits = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_gpu_layers")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_gpu_layers = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_ctx_size")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_ctx = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
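The speculative-decoding options above all follow one pattern: match an option name (or its alias), convert the value, and silently keep the previous value if the conversion throws. The sketch below isolates that pattern. `SpecParams`, `apply_spec_option`, the placeholder defaults, and the `name:value` splitting are assumptions made for illustration; the diff only shows `optname` and `optval_str` after they have already been separated.

```cpp
#include <cassert>
#include <cstddef>
#include <exception>
#include <string>

// Hypothetical stand-in for a few of the speculative fields set in the diff.
// The defaults are placeholders, not llama.cpp's actual defaults.
struct SpecParams {
    int   n_max = 16;
    int   n_min = 0;
    float p_min = 0.75f;
};

// Applies one "name:value" option and returns true if the name was recognised.
// A value that cannot be converted leaves the existing field untouched.
static bool apply_spec_option(SpecParams& p, const std::string& opt) {
    const std::size_t sep = opt.find(':');
    if (sep == std::string::npos) {
        return false;
    }
    const std::string name  = opt.substr(0, sep);
    const std::string value = opt.substr(sep + 1);

    try {
        if (name == "spec_n_max" || name == "draft_max") {
            p.n_max = std::stoi(value);
        } else if (name == "spec_n_min" || name == "draft_min") {
            p.n_min = std::stoi(value);
        } else if (name == "spec_p_min" || name == "draft_p_min") {
            p.p_min = std::stof(value);
        } else {
            return false;
        }
    } catch (const std::exception&) {
        // std::stoi / std::stof throw std::invalid_argument or std::out_of_range;
        // the assignment never happened, so the previous value is kept.
    }
    return true;
}
```

Note that `std::stoi` accepts a numeric prefix, so a value such as `12abc` converts to `12` rather than throwing; stricter validation would need to inspect the `pos` output argument.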
@@ -757,13 +811,17 @@ private:
|
||||
public:
|
||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement Health RPC
|
||||
reply->set_message("OK");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement LoadModel RPC
|
||||
common_params params;
|
||||
params_parse(ctx_server, request, params);
|
||||
@@ -962,6 +1020,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -1249,6 +1309,7 @@ public:
|
||||
|
||||
body_json["messages"] = messages_json;
|
||||
body_json["stream"] = true; // PredictStream is always streaming
|
||||
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||
|
||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||
@@ -1553,11 +1614,15 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
// Without this, the PEG parser never produces diffs and the Go side
|
||||
// cannot detect tool calls or separate reasoning from content.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -1582,19 +1647,47 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
std::string completion_text;
|
||||
|
||||
if (res_json.contains("choices")) {
|
||||
// OAI chat format — extract content from choices[0].delta
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & delta = choices[0].value("delta", json::object());
|
||||
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||
completion_text = delta.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = res_json.value("content", "");
|
||||
}
|
||||
|
||||
reply.set_message(completion_text);
|
||||
|
||||
// Token counts: native format has top-level fields,
|
||||
// OAI format has them in "usage" (final chunk only)
|
||||
if (res_json.contains("usage")) {
|
||||
const auto & usage = res_json.at("usage");
|
||||
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||
} else {
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
}
|
||||
|
||||
// Timings: present as top-level "timings" in both formats
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
@@ -1603,6 +1696,12 @@ public:
|
||||
return reply;
|
||||
};
|
||||
|
||||
// Attach chat deltas from the autoparser to a Reply.
|
||||
// When diffs are available, populate ChatDeltas on the reply.
|
||||
// The raw message is always preserved so the Go side can use it
|
||||
// for reasoning extraction and tool call parsing as a fallback
|
||||
// (important in distributed mode where ChatDeltas may not be
|
||||
// the primary parsing path).
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
@@ -1617,12 +1716,23 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Process first result
|
||||
// Process first result.
|
||||
// When TASK_RESPONSE_TYPE_OAI_CHAT is used, the first token may
|
||||
// produce a JSON array with a role-init element followed by the
|
||||
// actual content element. We must only attach chat deltas to the
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
// delta but no content/reasoning diffs of their own).
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
res["choices"][0].value("delta", json::object()).contains("role");
|
||||
if (!is_role_init) {
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
}
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
@@ -1646,7 +1756,11 @@ public:
|
||||
if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
attach_chat_deltas(reply, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
res["choices"][0].value("delta", json::object()).contains("role");
|
||||
if (!is_role_init) {
|
||||
attach_chat_deltas(reply, result.get());
|
||||
}
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
@@ -1665,6 +1779,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2282,11 +2398,13 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -2317,25 +2435,48 @@ public:
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||
std::string completion_text;
|
||||
int32_t tokens_predicted = 0;
|
||||
int32_t tokens_evaluated = 0;
|
||||
|
||||
if (result_json.contains("choices")) {
|
||||
// OAI chat format
|
||||
const auto & choices = result_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
completion_text = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
if (result_json.contains("usage")) {
|
||||
const auto & usage = result_json.at("usage");
|
||||
tokens_predicted = usage.value("completion_tokens", 0);
|
||||
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = result_json.value("content", "");
|
||||
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
}
|
||||
reply->set_message(completion_text);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Timings: present in both formats as a top-level "timings" object
|
||||
if (result_json.contains("timings")) {
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
reply->set_logprobs(logprobs_json.dump());
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
@@ -2351,7 +2492,20 @@ public:
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
// Handle both native and OAI chat formats
|
||||
std::string result_content;
|
||||
if (res_json.contains("choices")) {
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
result_content = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result_content = res_json.value("content", "");
|
||||
}
|
||||
arr.push_back(result_content);
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
@@ -2383,6 +2537,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2563,7 +2719,9 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2803,19 +2961,14 @@ int main(int argc, char** argv) {
|
||||
BackendServiceImpl service(ctx_server);
|
||||
|
||||
ServerBuilder builder;
|
||||
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
std::shared_ptr<grpc::ServerCredentials> creds;
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
creds->SetAuthMetadataProcessor(
|
||||
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
} else {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
}
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
|
||||
builder.AddListeningPort(server_address, creds);
|
||||
// Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
g_grpc_auth_token = auth_token;
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=6f35c874ee11e86d511b860019b84976f5b52d3a
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
sources/
build*/
package/
libgosam3*.so
sam3-cpp
test-models/
test-data/
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.14)
project(gosam3 LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Build ggml as static libraries to avoid runtime .so dependencies
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)

set(SAM3_BUILD_EXAMPLES OFF CACHE BOOL "Disable sam3.cpp examples" FORCE)
set(SAM3_BUILD_TESTS OFF CACHE BOOL "Disable sam3.cpp tests" FORCE)

add_subdirectory(./sources/sam3.cpp)

add_library(gosam3 MODULE gosam3.cpp)
target_link_libraries(gosam3 PRIVATE sam3 ggml)

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    target_link_libraries(gosam3 PRIVATE stdc++fs)
endif()

target_include_directories(gosam3 PUBLIC
    sources/sam3.cpp
    sources/sam3.cpp/ggml/include
)

set_property(TARGET gosam3 PROPERTY CXX_STANDARD 14)
set_target_properties(gosam3 PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
122
backend/go/sam3-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
CMAKE_ARGS?=
BUILD_TYPE?=
NATIVE?=false

GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)

# sam3.cpp
SAM3_REPO?=https://github.com/PABannier/sam3.cpp
SAM3_VERSION?=01832ef85fcc8eb6488f1d01cd247f07e96ff5a9

ifeq ($(NATIVE),false)
	CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif

# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
	CMAKE_ARGS+=-DGGML_CUDA=ON
else ifeq ($(BUILD_TYPE),openblas)
	CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
else ifeq ($(BUILD_TYPE),clblas)
	CMAKE_ARGS+=-DGGML_CLBLAST=ON
else ifeq ($(BUILD_TYPE),hipblas)
	ROCM_HOME ?= /opt/rocm
	ROCM_PATH ?= /opt/rocm
	export CXX=$(ROCM_HOME)/llvm/bin/clang++
	export CC=$(ROCM_HOME)/llvm/bin/clang
	AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
	CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
else ifeq ($(BUILD_TYPE),vulkan)
	CMAKE_ARGS+=-DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)
	ifneq ($(BUILD_TYPE),metal)
		CMAKE_ARGS+=-DGGML_METAL=OFF
	else
		CMAKE_ARGS+=-DGGML_METAL=ON
		CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
	endif
endif

ifeq ($(BUILD_TYPE),sycl_f16)
	CMAKE_ARGS+=-DGGML_SYCL=ON \
		-DCMAKE_C_COMPILER=icx \
		-DCMAKE_CXX_COMPILER=icpx \
		-DGGML_SYCL_F16=ON
endif

ifeq ($(BUILD_TYPE),sycl_f32)
	CMAKE_ARGS+=-DGGML_SYCL=ON \
		-DCMAKE_C_COMPILER=icx \
		-DCMAKE_CXX_COMPILER=icpx
endif

sources/sam3.cpp:
	git clone --recursive $(SAM3_REPO) sources/sam3.cpp && \
	cd sources/sam3.cpp && \
	git checkout $(SAM3_VERSION) && \
	git submodule update --init --recursive --depth 1 --single-branch

# Detect OS
UNAME_S := $(shell uname -s)

# Only build CPU variants on Linux
ifeq ($(UNAME_S),Linux)
	VARIANT_TARGETS = libgosam3-avx.so libgosam3-avx2.so libgosam3-avx512.so libgosam3-fallback.so
else
	# On non-Linux (e.g., Darwin), build only fallback variant
	VARIANT_TARGETS = libgosam3-fallback.so
endif

sam3-cpp: main.go gosam3.go $(VARIANT_TARGETS)
	CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o sam3-cpp ./

package: sam3-cpp
	bash package.sh

build: package

clean: purge
	rm -rf libgosam3*.so sam3-cpp package sources

purge:
	rm -rf build*

# Build all variants (Linux only)
ifeq ($(UNAME_S),Linux)
libgosam3-avx.so: sources/sam3.cpp
	$(MAKE) purge
	$(info ${GREEN}I sam3-cpp build info:avx${RESET})
	SO_TARGET=libgosam3-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
	rm -rfv build*

libgosam3-avx2.so: sources/sam3.cpp
	$(MAKE) purge
	$(info ${GREEN}I sam3-cpp build info:avx2${RESET})
	SO_TARGET=libgosam3-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
	rm -rfv build*

libgosam3-avx512.so: sources/sam3.cpp
	$(MAKE) purge
	$(info ${GREEN}I sam3-cpp build info:avx512${RESET})
	SO_TARGET=libgosam3-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
	rm -rfv build*
endif

# Build fallback variant (all platforms)
libgosam3-fallback.so: sources/sam3.cpp
	$(MAKE) purge
	$(info ${GREEN}I sam3-cpp build info:fallback${RESET})
	SO_TARGET=libgosam3-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
	rm -rfv build*

libgosam3-custom: CMakeLists.txt gosam3.cpp gosam3.h
	mkdir -p build-$(SO_TARGET) && \
	cd build-$(SO_TARGET) && \
	cmake .. $(CMAKE_ARGS) && \
	cmake --build . --config Release -j$(JOBS) && \
	cd .. && \
	mv build-$(SO_TARGET)/libgosam3.so ./$(SO_TARGET)

all: sam3-cpp package
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
#include "sam3.h"
|
||||
#include "gosam3.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
// Static state
|
||||
static std::shared_ptr<sam3_model> g_model;
|
||||
static sam3_state_ptr g_state;
|
||||
static sam3_result g_result;
|
||||
static std::vector<std::vector<unsigned char>> g_mask_pngs;
|
||||
|
||||
// Callback for stbi_write_png_to_mem via stbi_write_png_to_func
|
||||
static void png_write_callback(void *context, void *data, int size) {
|
||||
auto *buf = static_cast<std::vector<unsigned char>*>(context);
|
||||
auto *bytes = static_cast<unsigned char*>(data);
|
||||
buf->insert(buf->end(), bytes, bytes + size);
|
||||
}
|
||||
|
||||
// Encode all masks as PNGs after segmentation
|
||||
static void encode_masks_as_png() {
|
||||
g_mask_pngs.clear();
|
||||
g_mask_pngs.resize(g_result.detections.size());
|
||||
|
||||
for (size_t i = 0; i < g_result.detections.size(); i++) {
|
||||
const auto &mask = g_result.detections[i].mask;
|
||||
if (mask.width > 0 && mask.height > 0 && !mask.data.empty()) {
|
||||
stbi_write_png_to_func(png_write_callback, &g_mask_pngs[i],
|
||||
mask.width, mask.height, 1,
|
||||
mask.data.data(), mask.width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int sam3_cpp_load_model(const char *model_path, int threads) {
|
||||
sam3_params params;
|
||||
params.model_path = model_path;
|
||||
params.n_threads = threads;
|
||||
params.use_gpu = true;
|
||||
|
||||
g_model = sam3_load_model(params);
|
||||
if (!g_model) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load model: %s\n", model_path);
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_state = sam3_create_state(*g_model, params);
|
||||
if (!g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to create state\n");
|
||||
g_model.reset();
|
||||
return 2;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[sam3-cpp] Model loaded: %s (threads=%d)\n", model_path, threads);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_encode_image(const char *image_path) {
|
||||
if (!g_model || !g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Model not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
sam3_image img = sam3_load_image(image_path);
|
||||
if (img.data.empty()) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load image: %s\n", image_path);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (!sam3_encode_image(*g_state, *g_model, img)) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to encode image\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||
float *boxes, int n_box_quads,
|
||||
float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
sam3_pvs_params pvs_params;
|
||||
|
||||
// Parse points: each triple is [x, y, label]
|
||||
for (int i = 0; i < n_point_triples; i++) {
|
||||
float x = points[i * 3];
|
||||
float y = points[i * 3 + 1];
|
||||
float label = points[i * 3 + 2];
|
||||
sam3_point pt = {x, y};
|
||||
if (label > 0.5f) {
|
||||
pvs_params.pos_points.push_back(pt);
|
||||
} else {
|
||||
pvs_params.neg_points.push_back(pt);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boxes: each quad is [x1, y1, x2, y2], use only first box
|
||||
if (n_box_quads > 0) {
|
||||
pvs_params.box = {boxes[0], boxes[1], boxes[2], boxes[3]};
|
||||
pvs_params.use_box = true;
|
||||
}
|
||||
|
||||
g_result = sam3_segment_pvs(*g_state, *g_model, pvs_params);
|
||||
encode_masks_as_png();
|
||||
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// PCS mode requires SAM 3 (full model with text encoder)
|
||||
if (sam3_is_visual_only(*g_model) ||
|
||||
        sam3_get_model_type(*g_model) != SAM3_MODEL_SAM3) {
        fprintf(stderr, "[sam3-cpp] PCS mode requires full SAM 3 model\n");
        return -1;
    }

    sam3_pcs_params pcs_params;
    pcs_params.text_prompt = text_prompt;
    pcs_params.score_threshold = threshold > 0 ? threshold : 0.5f;

    g_result = sam3_segment_pcs(*g_state, *g_model, pcs_params);
    encode_masks_as_png();

    return static_cast<int>(g_result.detections.size());
}

int sam3_cpp_get_n_detections(void) {
    return static_cast<int>(g_result.detections.size());
}

float sam3_cpp_get_detection_x(int i) {
    if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
    return g_result.detections[i].box.x0;
}

float sam3_cpp_get_detection_y(int i) {
    if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
    return g_result.detections[i].box.y0;
}

float sam3_cpp_get_detection_w(int i) {
    if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
    const auto &box = g_result.detections[i].box;
    return box.x1 - box.x0;
}

float sam3_cpp_get_detection_h(int i) {
    if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
    const auto &box = g_result.detections[i].box;
    return box.y1 - box.y0;
}

float sam3_cpp_get_detection_score(int i) {
    if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
    return g_result.detections[i].score;
}

int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size) {
    if (i < 0 || i >= static_cast<int>(g_mask_pngs.size())) return 0;

    const auto &png = g_mask_pngs[i];
    int size = static_cast<int>(png.size());

    if (buf == nullptr) {
        return size;
    }

    int to_copy = size < buf_size ? size : buf_size;
    memcpy(buf, png.data(), to_copy);
    return to_copy;
}

void sam3_cpp_free_results(void) {
    g_result.detections.clear();
    g_mask_pngs.clear();
}

} // extern "C"
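The getters above expose each detection's corner-form box (`x0`, `y0`, `x1`, `y1`) as an origin plus width and height, and return 0 for an out-of-range index. A minimal Python sketch of that conversion, for illustration only; `Box`, `to_xywh`, and `detection_x` are hypothetical names and are not part of the backend:

```python
from typing import List, NamedTuple, Tuple


class Box(NamedTuple):
    """Corner-form box, mirroring the x0/y0/x1/y1 fields read by the getters."""
    x0: float
    y0: float
    x1: float
    y1: float


def to_xywh(box: Box) -> Tuple[float, float, float, float]:
    """Convert (x0, y0, x1, y1) to (x, y, width, height)."""
    return (box.x0, box.y0, box.x1 - box.x0, box.y1 - box.y0)


def detection_x(boxes: List[Box], i: int) -> float:
    """Like the C getters, an out-of-range index yields 0 instead of an error."""
    if i < 0 or i >= len(boxes):
        return 0.0
    return boxes[i].x0
```

Because an invalid index yields 0 rather than an error, a caller cannot distinguish "no such detection" from a box at the origin, so it should bound `i` by the detection count first.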
143
backend/go/sam3-cpp/gosam3.go
Normal file
@@ -0,0 +1,143 @@
package main

import (
	"encoding/base64"
	"fmt"
	"os"
	"path/filepath"
	"unsafe"

	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

type SAM3 struct {
	base.SingleThread
}

var (
	CppLoadModel           func(modelPath string, threads int) int
	CppEncodeImage         func(imagePath string) int
	CppSegmentPVS          func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int
	CppSegmentPCS          func(textPrompt string, threshold float32) int
	CppGetNDetections      func() int
	CppGetDetectionX       func(i int) float32
	CppGetDetectionY       func(i int) float32
	CppGetDetectionW       func(i int) float32
	CppGetDetectionH       func(i int) float32
	CppGetDetectionScore   func(i int) float32
	CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int
	CppFreeResults         func()
)

func (s *SAM3) Load(opts *pb.ModelOptions) error {
	modelFile := opts.ModelFile
	if modelFile == "" {
		modelFile = opts.Model
	}

	var modelPath string
	if filepath.IsAbs(modelFile) {
		modelPath = modelFile
	} else {
		modelPath = filepath.Join(opts.ModelPath, modelFile)
	}

	threads := int(opts.Threads)
	if threads <= 0 {
		threads = 4
	}

	ret := CppLoadModel(modelPath, threads)
	if ret != 0 {
		return fmt.Errorf("failed to load SAM3 model (error %d): %s", ret, modelPath)
	}

	return nil
}

func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
	// Decode base64 image and write to temp file
	imgData, err := base64.StdEncoding.DecodeString(opts.Src)
	if err != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
	}

	tmpFile, err := os.CreateTemp("", "sam3-*.png")
	if err != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
	}
	defer os.Remove(tmpFile.Name())

	if _, err := tmpFile.Write(imgData); err != nil {
		tmpFile.Close()
		return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", err)
	}
	tmpFile.Close()

	// Encode image
	ret := CppEncodeImage(tmpFile.Name())
	if ret != 0 {
		return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
	}

	threshold := opts.Threshold
	if threshold <= 0 {
		threshold = 0.5
	}

	// Determine segmentation mode
	var nDetections int
	if opts.Prompt != "" {
		// Text-prompted segmentation (PCS mode, SAM 3 only)
		nDetections = CppSegmentPCS(opts.Prompt, threshold)
	} else {
		// Point/box-prompted segmentation (PVS mode)
		var pointsPtr uintptr
		var boxesPtr uintptr
		nPointTriples := len(opts.Points) / 3
		nBoxQuads := len(opts.Boxes) / 4

		if nPointTriples > 0 {
			pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
		}
		if nBoxQuads > 0 {
			boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
		}

		nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
	}

	if nDetections < 0 {
		return pb.DetectResponse{}, fmt.Errorf("segmentation failed")
	}

	defer CppFreeResults()

	// Build response
	detections := make([]*pb.Detection, nDetections)
	for i := 0; i < nDetections; i++ {
		det := &pb.Detection{
			X:          CppGetDetectionX(i),
			Y:          CppGetDetectionY(i),
			Width:      CppGetDetectionW(i),
			Height:     CppGetDetectionH(i),
			Confidence: CppGetDetectionScore(i),
			ClassName:  "segment",
		}

		// Get mask PNG
		maskSize := CppGetDetectionMaskPNG(i, 0, 0)
		if maskSize > 0 {
			maskBuf := make([]byte, maskSize)
			CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
			det.Mask = maskBuf
		}

		detections[i] = det
	}

	return pb.DetectResponse{
		Detections: detections,
	}, nil
}
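`Detect` picks PCS mode whenever a text prompt is present and otherwise passes the flat `Points` and `Boxes` arrays through, counting prompts by integer division (`len / 3` and `len / 4`). The Python sketch below mirrors that bookkeeping only; the function names are hypothetical and nothing here calls the real backend:

```python
from typing import List, Sequence, Tuple


def flatten_points(points: Sequence[Tuple[float, float, int]]) -> List[float]:
    """Flatten (x, y, label) triples into the [x, y, label, ...] layout."""
    flat: List[float] = []
    for x, y, label in points:
        flat.extend([float(x), float(y), float(label)])
    return flat


def count_prompts(flat_points: Sequence[float], flat_boxes: Sequence[float]) -> Tuple[int, int]:
    """Integer division, as in Detect: trailing partial triples/quads are ignored."""
    return len(flat_points) // 3, len(flat_boxes) // 4


def choose_mode(prompt: str) -> str:
    """A non-empty text prompt selects PCS; otherwise point/box prompts (PVS) apply."""
    return "pcs" if prompt != "" else "pvs"
```

One consequence worth noting: a `Points` array whose length is not a multiple of 3 is silently truncated rather than rejected, and any points or boxes sent alongside a text prompt are not used.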
51
backend/go/sam3-cpp/gosam3.h
Normal file
@@ -0,0 +1,51 @@
#ifndef GOSAM3_H
#define GOSAM3_H

#ifdef __cplusplus
extern "C" {
#endif

// Load model from file. Returns 0 on success, non-zero on failure.
int sam3_cpp_load_model(const char *model_path, int threads);

// Encode an image from file path. Must be called before segmentation.
// Returns 0 on success.
int sam3_cpp_encode_image(const char *image_path);

// Segment with point/box prompts (PVS mode).
// points: flat array of [x, y, label] triples (label: 1=positive, 0=negative)
// boxes: flat array of [x1, y1, x2, y2] quads
// Returns number of detections, or -1 on error.
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
                         float *boxes, int n_box_quads,
                         float threshold);

// Segment with text prompt (PCS mode, SAM 3 only).
// Returns number of detections, or -1 on error.
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold);

// Access detection results (valid after a segment call).
int sam3_cpp_get_n_detections(void);

// Get bounding box for detection i (as x, y, width, height).
float sam3_cpp_get_detection_x(int i);
float sam3_cpp_get_detection_y(int i);
float sam3_cpp_get_detection_w(int i);
float sam3_cpp_get_detection_h(int i);

// Get confidence score for detection i.
float sam3_cpp_get_detection_score(int i);

// Get mask as PNG-encoded bytes.
// If buf is NULL, returns the required buffer size.
// Otherwise writes up to buf_size bytes and returns bytes written.
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size);

// Free current detection results.
void sam3_cpp_free_results(void);

#ifdef __cplusplus
}
#endif

#endif // GOSAM3_H
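The mask accessor follows the usual two-call pattern: call once with a NULL buffer to learn the size, allocate, then call again to copy. The Python sketch below models that contract with a stand-in function; `get_mask_png` here is a hypothetical pure-Python mock of the documented behaviour, not a binding to the shared library:

```python
from typing import Optional


def get_mask_png(png: bytes, buf: Optional[bytearray]) -> int:
    """Mock of the documented contract: with no buffer, report the required
    size; otherwise copy at most len(buf) bytes and report how many were written."""
    if buf is None:
        return len(png)
    to_copy = min(len(png), len(buf))
    buf[:to_copy] = png[:to_copy]
    return to_copy


def fetch_mask(png: bytes) -> bytes:
    """Two-call pattern: query the size, allocate, then copy."""
    size = get_mask_png(png, None)
    if size <= 0:
        return b""
    buf = bytearray(size)
    written = get_mask_png(png, buf)
    return bytes(buf[:written])
```

If the buffer is smaller than the mask, the copy is truncated and the return value is the truncated length, so a caller that skips the size query should compare the result against the expected size.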
56
backend/go/sam3-cpp/main.go
Normal file
@@ -0,0 +1,56 @@
package main

import (
	"flag"
	"os"

	"github.com/ebitengine/purego"
	grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
	addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

type LibFuncs struct {
	FuncPtr any
	Name    string
}

func main() {
	// Get library name from environment variable, default to fallback
	libName := os.Getenv("SAM3_LIBRARY")
	if libName == "" {
		libName = "./libgosam3-fallback.so"
	}

	gosamLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		panic(err)
	}

	libFuncs := []LibFuncs{
		{&CppLoadModel, "sam3_cpp_load_model"},
		{&CppEncodeImage, "sam3_cpp_encode_image"},
		{&CppSegmentPVS, "sam3_cpp_segment_pvs"},
		{&CppSegmentPCS, "sam3_cpp_segment_pcs"},
		{&CppGetNDetections, "sam3_cpp_get_n_detections"},
		{&CppGetDetectionX, "sam3_cpp_get_detection_x"},
		{&CppGetDetectionY, "sam3_cpp_get_detection_y"},
		{&CppGetDetectionW, "sam3_cpp_get_detection_w"},
		{&CppGetDetectionH, "sam3_cpp_get_detection_h"},
		{&CppGetDetectionScore, "sam3_cpp_get_detection_score"},
		{&CppGetDetectionMaskPNG, "sam3_cpp_get_detection_mask_png"},
		{&CppFreeResults, "sam3_cpp_free_results"},
	}

	for _, lf := range libFuncs {
		purego.RegisterLibFunc(lf.FuncPtr, gosamLib, lf.Name)
	}

	flag.Parse()

	if err := grpc.StartServer(*addr, &SAM3{}); err != nil {
		panic(err)
	}
}
59
backend/go/sam3-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
#!/bin/bash

# Script to copy the appropriate libraries based on architecture

set -e

CURDIR=$(dirname "$(realpath $0)")
REPO_ROOT="${CURDIR}/../../.."

# Create lib directory
mkdir -p $CURDIR/package/lib

cp -avf $CURDIR/libgosam3-*.so $CURDIR/package/
cp -avf $CURDIR/sam3-cpp $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/

# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
    # x86_64 architecture
    echo "Detected x86_64 architecture, copying x86_64 libraries..."
    cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
    cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
    cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
    cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
    cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
    cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
    cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
    cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
    cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    # ARM64 architecture
    echo "Detected ARM64 architecture, copying ARM64 libraries..."
    cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
    cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
    cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
    cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
    cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
    cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
    cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
    cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
    cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
    echo "Detected Darwin"
else
    echo "Error: Could not detect architecture"
    exit 1
fi

# Package GPU libraries based on BUILD_TYPE
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
if [ -f "$GPU_LIB_SCRIPT" ]; then
    echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
    source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
    package_gpu_libs
fi

echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
52
backend/go/sam3-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
#!/bin/bash
set -ex

# Get the absolute current dir where the script is located
CURDIR=$(dirname "$(realpath $0)")

cd /

echo "CPU info:"
if [ "$(uname)" != "Darwin" ]; then
    grep -e "model\sname" /proc/cpuinfo | head -1
    grep -e "flags" /proc/cpuinfo | head -1
fi

LIBRARY="$CURDIR/libgosam3-fallback.so"

if [ "$(uname)" != "Darwin" ]; then
    if grep -q -e "\savx\s" /proc/cpuinfo ; then
        echo "CPU: AVX found OK"
        if [ -e $CURDIR/libgosam3-avx.so ]; then
            LIBRARY="$CURDIR/libgosam3-avx.so"
        fi
    fi

    if grep -q -e "\savx2\s" /proc/cpuinfo ; then
        echo "CPU: AVX2 found OK"
        if [ -e $CURDIR/libgosam3-avx2.so ]; then
            LIBRARY="$CURDIR/libgosam3-avx2.so"
        fi
    fi

    # Check avx 512
    if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
        echo "CPU: AVX512F found OK"
        if [ -e $CURDIR/libgosam3-avx512.so ]; then
            LIBRARY="$CURDIR/libgosam3-avx512.so"
        fi
    fi
fi

export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
export SAM3_LIBRARY=$LIBRARY

# If there is a lib/ld.so, use it
if [ -f $CURDIR/lib/ld.so ]; then
    echo "Using lib/ld.so"
    echo "Using library: $LIBRARY"
    exec $CURDIR/lib/ld.so $CURDIR/sam3-cpp "$@"
fi

echo "Using library: $LIBRARY"
exec $CURDIR/sam3-cpp "$@"
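The selection logic in run.sh starts from the fallback library and upgrades in order (AVX, then AVX2, then AVX512F), taking each variant only if the CPU flag is present and the matching `.so` was actually packaged. A Python sketch of the same decision, with `pick_library` and its `exists` callback as hypothetical stand-ins for the `grep` and `[ -e ... ]` checks:

```python
import os
from typing import Callable, Set


def pick_library(cpu_flags: Set[str], curdir: str,
                 exists: Callable[[str], bool] = os.path.exists) -> str:
    """Return the most specific library whose CPU flag is set and whose file exists."""
    library = f"{curdir}/libgosam3-fallback.so"
    # Checked in the same order as the script, so later matches override earlier ones.
    variants = (
        ("avx", "libgosam3-avx.so"),
        ("avx2", "libgosam3-avx2.so"),
        ("avx512f", "libgosam3-avx512.so"),
    )
    for flag, name in variants:
        candidate = f"{curdir}/{name}"
        if flag in cpu_flags and exists(candidate):
            library = candidate
    return library
```

For example, a CPU reporting `avx` and `avx2` on a package that ships only the AVX build resolves to `libgosam3-avx.so`, because the AVX2 step is skipped when its file is missing.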
50
backend/go/sam3-cpp/test.sh
Executable file
@@ -0,0 +1,50 @@
#!/bin/bash
set -e

CURDIR=$(dirname "$(realpath $0)")

echo "Running sam3-cpp backend tests..."

# The test requires a SAM model in GGML format.
# Uses EdgeTAM Q4_0 (~15MB) for fast CI testing.
SAM3_MODEL_DIR="${SAM3_MODEL_DIR:-$CURDIR/test-models}"
SAM3_MODEL_FILE="${SAM3_MODEL_FILE:-edgetam_q4_0.ggml}"
SAM3_MODEL_URL="${SAM3_MODEL_URL:-https://huggingface.co/PABannier/sam3.cpp/resolve/main/edgetam_q4_0.ggml}"

# Download model if not present
if [ ! -f "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" ]; then
    echo "Downloading EdgeTAM Q4_0 model for testing..."
    mkdir -p "$SAM3_MODEL_DIR"
    curl -L -o "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" "$SAM3_MODEL_URL" --progress-bar
    echo "Model downloaded."
fi

# Create a test image (64x64 solid red PNG)
# This is a minimal valid PNG for testing the pipeline
TEST_IMAGE_DIR="$CURDIR/test-data"
mkdir -p "$TEST_IMAGE_DIR"

# Generate a simple test image using Python if available (no image is created otherwise)
if command -v python3 &> /dev/null; then
    python3 -c "
import struct, zlib, base64
def create_png(width, height, r, g, b):
    raw = b''
    for y in range(height):
        raw += b'\x00'  # filter byte
        for x in range(width):
            raw += bytes([r, g, b])
    def chunk(ctype, data):
        c = ctype + data
        return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)
    ihdr = struct.pack('>IIBBBBB', width, height, 8, 2, 0, 0, 0)
    return b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', zlib.compress(raw)) + chunk(b'IEND', b'')
with open('$TEST_IMAGE_DIR/test.png', 'wb') as f:
    f.write(create_png(64, 64, 255, 0, 0))
"
    echo "Test image created."
fi

echo "sam3-cpp test setup complete."
echo "Model: $SAM3_MODEL_DIR/$SAM3_MODEL_FILE"
echo "Note: Full integration tests run via the LocalAI test-extra target."
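The inline generator in test.sh writes the PNG signature followed by IHDR, IDAT, and IEND chunks. One way to sanity-check such a file is to read the dimensions back from the IHDR chunk. The sketch below repeats the generator in standalone form and adds a reader; `read_png_size` is a hypothetical helper and is not part of the test script:

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def create_png(width: int, height: int, r: int, g: int, b: int) -> bytes:
    """Build a solid-colour 8-bit RGB PNG (same approach as the script)."""
    # Each scanline is a filter byte (0 = none) followed by width RGB pixels.
    raw = b"".join(b"\x00" + bytes([r, g, b]) * width for _ in range(height))

    def chunk(ctype: bytes, data: bytes) -> bytes:
        body = ctype + data
        return struct.pack(">I", len(data)) + body + struct.pack(">I", zlib.crc32(body) & 0xFFFFFFFF)

    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    return PNG_SIGNATURE + chunk(b"IHDR", ihdr) + chunk(b"IDAT", zlib.compress(raw)) + chunk(b"IEND", b"")


def read_png_size(data: bytes) -> tuple:
    """Return (width, height) from the IHDR chunk, or raise ValueError."""
    if data[:8] != PNG_SIGNATURE:
        raise ValueError("not a PNG file")
    length, ctype = struct.unpack(">I4s", data[8:16])
    if ctype != b"IHDR" or length != 13:
        raise ValueError("missing or malformed IHDR chunk")
    return struct.unpack(">II", data[16:24])
```

Reading back `create_png(64, 64, 255, 0, 0)` this way confirms that the script's test image is 64x64.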
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

 # stablediffusion.cpp (ggml)
 STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
-STABLEDIFFUSION_GGML_VERSION?=09b12d5f6d51d862749e8e0ee8baac8f012089e2
+STABLEDIFFUSION_GGML_VERSION?=e8323cabb0e4511ba18a50b1cb34cf1f87fc71ef

 CMAKE_ARGS+=-DGGML_MAX_NAME=128
@@ -125,6 +125,31 @@
    nvidia-cuda-13: "cuda13-rfdetr"
    nvidia-cuda-12: "cuda12-rfdetr"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
- &sam3cpp
  name: "sam3-cpp"
  alias: "sam3-cpp"
  license: mit
  description: |
    Segment Anything Model (SAM 3/2/EdgeTAM) in C/C++ using GGML.
    Supports text-prompted and point/box-prompted image segmentation.
  urls:
    - https://github.com/PABannier/sam3.cpp
  tags:
    - image-segmentation
    - object-detection
    - sam3
    - gpu
    - cpu
  capabilities:
    default: "cpu-sam3-cpp"
    nvidia: "cuda12-sam3-cpp"
    nvidia-cuda-12: "cuda12-sam3-cpp"
    nvidia-cuda-13: "cuda13-sam3-cpp"
    nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp"
    nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
    intel: "intel-sycl-f32-sam3-cpp"
    vulkan: "vulkan-sam3-cpp"
- &vllm
  name: "vllm"
  license: apache-2.0
@@ -400,12 +425,15 @@
  license: MIT
  name: "faster-whisper"
  capabilities:
    default: "cpu-faster-whisper"
    nvidia: "cuda12-faster-whisper"
    intel: "intel-faster-whisper"
    amd: "rocm-faster-whisper"
    metal: "metal-faster-whisper"
    nvidia-cuda-13: "cuda13-faster-whisper"
    nvidia-cuda-12: "cuda12-faster-whisper"
    nvidia-l4t: "nvidia-l4t-arm64-faster-whisper"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-faster-whisper"
- &moonshine
  description: |
    Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
@@ -438,6 +466,7 @@
    - whisperx
  license: BSD-4-Clause
  name: "whisperx"
  alias: "whisperx"
  capabilities:
    nvidia: "cuda12-whisperx"
    amd: "rocm-whisperx"
@@ -445,6 +474,8 @@
    default: "cpu-whisperx"
    nvidia-cuda-13: "cuda13-whisperx"
    nvidia-cuda-12: "cuda12-whisperx"
    nvidia-l4t: "nvidia-l4t-arm64-whisperx"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisperx"
- &kokoro
  icon: https://avatars.githubusercontent.com/u/166769057?v=4
  description: |
@@ -468,6 +499,26 @@
    nvidia-cuda-13: "cuda13-kokoro"
    nvidia-cuda-12: "cuda12-kokoro"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
- &kokoros
  icon: https://avatars.githubusercontent.com/u/166769057?v=4
  description: |
    Kokoros is a pure Rust TTS backend using the Kokoro ONNX model (82M parameters).
    It provides fast, high-quality text-to-speech with streaming support, built on
    ONNX Runtime for efficient CPU inference. Supports English, Japanese, Mandarin
    Chinese, and German.
  urls:
    - https://huggingface.co/hexgrad/Kokoro-82M
    - https://github.com/lucasjinreal/Kokoros
  tags:
    - text-to-speech
    - TTS
    - Rust
    - ONNX
  license: apache-2.0
  alias: "kokoros"
  name: "kokoros"
  capabilities:
    default: "cpu-kokoros"
- &coqui
  urls:
    - https://github.com/idiap/coqui-ai-TTS
@@ -1602,6 +1653,89 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-rfdetr
## sam3-cpp
- !!merge <<: *sam3cpp
  name: "sam3-cpp-development"
  capabilities:
    default: "cpu-sam3-cpp-development"
    nvidia: "cuda12-sam3-cpp-development"
    nvidia-cuda-12: "cuda12-sam3-cpp-development"
    nvidia-cuda-13: "cuda13-sam3-cpp-development"
    nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp-development"
    nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp-development"
    nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
    intel: "intel-sycl-f32-sam3-cpp-development"
    vulkan: "vulkan-sam3-cpp-development"
- !!merge <<: *sam3cpp
  name: "cpu-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-cpu-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cpu-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-cpu-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda12-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda12-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-cuda-12-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda13-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda13-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-cuda-13-sam3-cpp
- !!merge <<: *sam3cpp
  name: "nvidia-l4t-arm64-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-nvidia-l4t-arm64-sam3-cpp
- !!merge <<: *sam3cpp
  name: "nvidia-l4t-arm64-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-nvidia-l4t-arm64-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda13-nvidia-l4t-arm64-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp
- !!merge <<: *sam3cpp
  name: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp
- !!merge <<: *sam3cpp
  name: "intel-sycl-f32-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-gpu-intel-sycl-f32-sam3-cpp
- !!merge <<: *sam3cpp
  name: "intel-sycl-f32-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-gpu-intel-sycl-f32-sam3-cpp
- !!merge <<: *sam3cpp
  name: "vulkan-sam3-cpp"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-sam3-cpp"
  mirrors:
    - localai/localai-backends:latest-gpu-vulkan-sam3-cpp
- !!merge <<: *sam3cpp
  name: "vulkan-sam3-cpp-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
  mirrors:
    - localai/localai-backends:master-gpu-vulkan-sam3-cpp
## Rerankers
- !!merge <<: *rerankers
  name: "rerankers-development"
@@ -2042,15 +2176,32 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-kokoro
## kokoros (Rust)
- !!merge <<: *kokoros
  name: "kokoros-development"
  capabilities:
    default: "cpu-kokoros-development"
- !!merge <<: *kokoros
  name: "cpu-kokoros"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-kokoros"
  mirrors:
    - localai/localai-backends:latest-cpu-kokoros
- !!merge <<: *kokoros
  name: "cpu-kokoros-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-kokoros"
  mirrors:
    - localai/localai-backends:master-cpu-kokoros
## faster-whisper
- !!merge <<: *faster-whisper
  name: "faster-whisper-development"
  capabilities:
    default: "cpu-faster-whisper-development"
    nvidia: "cuda12-faster-whisper-development"
    intel: "intel-faster-whisper-development"
    amd: "rocm-faster-whisper-development"
    metal: "metal-faster-whisper-development"
    nvidia-cuda-13: "cuda13-faster-whisper-development"
    nvidia-l4t: "nvidia-l4t-arm64-faster-whisper-development"
- !!merge <<: *faster-whisper
  name: "cuda12-faster-whisper-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
@@ -2091,6 +2242,36 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-faster-whisper
- !!merge <<: *faster-whisper
  name: "cuda12-faster-whisper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-whisper"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-whisper
- !!merge <<: *faster-whisper
  name: "rocm-faster-whisper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-faster-whisper"
  mirrors:
    - localai/localai-backends:latest-gpu-rocm-hipblas-faster-whisper
- !!merge <<: *faster-whisper
  name: "cpu-faster-whisper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-faster-whisper"
  mirrors:
    - localai/localai-backends:latest-cpu-faster-whisper
- !!merge <<: *faster-whisper
  name: "cpu-faster-whisper-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-faster-whisper"
  mirrors:
    - localai/localai-backends:master-cpu-faster-whisper
- !!merge <<: *faster-whisper
  name: "nvidia-l4t-arm64-faster-whisper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-whisper"
  mirrors:
    - localai/localai-backends:latest-nvidia-l4t-faster-whisper
- !!merge <<: *faster-whisper
  name: "nvidia-l4t-arm64-faster-whisper-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-whisper"
  mirrors:
    - localai/localai-backends:master-nvidia-l4t-faster-whisper
## moonshine
- !!merge <<: *moonshine
  name: "moonshine-development"
@@ -2149,6 +2330,7 @@
    default: "cpu-whisperx-development"
    nvidia-cuda-13: "cuda13-whisperx-development"
    nvidia-cuda-12: "cuda12-whisperx-development"
    nvidia-l4t: "nvidia-l4t-arm64-whisperx-development"
- !!merge <<: *whisperx
  name: "cpu-whisperx"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx"
@@ -2199,6 +2381,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-whisperx
- !!merge <<: *whisperx
  name: "nvidia-l4t-arm64-whisperx"
  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-whisperx"
  mirrors:
    - localai/localai-backends:latest-nvidia-l4t-whisperx
- !!merge <<: *whisperx
  name: "nvidia-l4t-arm64-whisperx-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-whisperx"
  mirrors:
    - localai/localai-backends:master-nvidia-l4t-whisperx
## coqui

- !!merge <<: *coqui
@@ -16,4 +16,14 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
     EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
 fi

+if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
+    PYTHON_VERSION="3.12"
+    PYTHON_PATCH="12"
+    PY_STANDALONE_TAG="20251120"
+fi
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+    USE_PIP=true
+fi
+
 installRequirements
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
torch
faster-whisper
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu130
torch
faster-whisper
@@ -147,7 +147,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if request.language and request.language.strip():
                 language = request.language.strip()

-            results = self.model.transcribe(audio=audio_path, language=language)
+            context = ""
+            if request.prompt and request.prompt.strip():
+                context = request.prompt.strip()
+
+            results = self.model.transcribe(audio=audio_path, language=language, context=context)

             if not results:
                 return backend_pb2.TranscriptResult(segments=[], text="")
@@ -8,8 +8,21 @@ else
     source $backend_dir/../common/libbackend.sh
 fi

-if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
-    EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
+if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
+    PYTHON_VERSION="3.12"
+    PYTHON_PATCH="12"
+    PY_STANDALONE_TAG="20251120"
 fi

+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+    USE_PIP=true
+fi
+
+# --index-strategy is a uv-only flag; skip it when using pip
+if [ "x${USE_PIP}" != "xtrue" ]; then
+    if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
+        EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
+    fi
+fi
+
 installRequirements
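After this hunk, the whisperx install script applies `--index-strategy unsafe-best-match` only when uv is the installer and the profile is neither `metal` nor `mps`. The `l4t12` profile forces pip and so never receives the flag. A small Python sketch of that decision, covering only what is visible in the hunk; `resolve_install_flags` is a hypothetical name, and `USE_PIP` may also be set elsewhere (for example by `libbackend.sh`), which this sketch models as a parameter:

```python
from typing import Tuple


def resolve_install_flags(build_profile: str, use_pip: bool = False) -> Tuple[bool, str]:
    """Return (use_pip, extra_flags) as the hunk would compute them."""
    if build_profile == "l4t12":
        use_pip = True
    extra_flags = ""
    # --index-strategy is understood by uv only, so skip it when pip is used.
    if not use_pip and build_profile not in ("metal", "mps"):
        extra_flags = " --index-strategy unsafe-best-match"
    return use_pip, extra_flags
```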
3
backend/python/whisperx/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
torch
whisperx @ git+https://github.com/m-bain/whisperX.git
3
backend/python/whisperx/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu130
torch
whisperx @ git+https://github.com/m-bain/whisperX.git
3
backend/rust/kokoros/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
/target/
/proto/
/package/
3074
backend/rust/kokoros/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
26
backend/rust/kokoros/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
[package]
name = "kokoros-grpc"
version = "0.1.0"
edition = "2021"

[[bin]]
name = "kokoros-grpc"
path = "src/main.rs"

[dependencies]
kokoros = { path = "sources/Kokoros/kokoros" }

tonic = "0.13"
prost = "0.13"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
clap = { version = "4", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[build-dependencies]
tonic-build = "0.13"

[features]
default = ["cpu"]
cpu = ["kokoros/cpu"]
25
backend/rust/kokoros/Makefile
Normal file
@@ -0,0 +1,25 @@
CURRENT_DIR=$(abspath ./)

.PHONY: kokoros-grpc
kokoros-grpc:
	mkdir -p $(CURRENT_DIR)/proto
	cp $(CURRENT_DIR)/../../backend.proto $(CURRENT_DIR)/proto/backend.proto
	cd $(CURRENT_DIR) && \
		BACKEND_PROTO_PATH=$(CURRENT_DIR)/proto/backend.proto \
		cargo build --release

.PHONY: package
package:
	bash package.sh

.PHONY: test
test: kokoros-grpc
	cd $(CURRENT_DIR) && cargo test

.PHONY: build
build: kokoros-grpc package

.PHONY: clean
clean:
	cargo clean
	rm -rf package proto
15
backend/rust/kokoros/build.rs
Normal file
@@ -0,0 +1,15 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let proto_path = std::env::var("BACKEND_PROTO_PATH")
        .unwrap_or_else(|_| "proto/backend.proto".to_string());

    let proto_dir = std::path::Path::new(&proto_path)
        .parent()
        .unwrap_or(std::path::Path::new("."));

    tonic_build::configure()
        .build_server(true)
        .build_client(false)
        .compile_protos(&[&proto_path], &[proto_dir])?;

    Ok(())
}
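One detail of the `proto_dir` computation in `build.rs` may be worth spelling out: for a bare relative file name such as `backend.proto`, `Path::parent` returns `Some("")` rather than `None`, so a `"."` fallback on `None` would not apply in that case. The following is a minimal standard-library sketch of a helper that normalizes the empty parent; the name `proto_include_dir` is hypothetical and is not part of this diff.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper (not in the diff): directory to pass to the protobuf
/// compiler as the include path for `proto_path`.
fn proto_include_dir(proto_path: &str) -> PathBuf {
    match Path::new(proto_path).parent() {
        // Normal case: the path has a non-empty directory component.
        Some(dir) if !dir.as_os_str().is_empty() => dir.to_path_buf(),
        // A bare file name has the empty path as its parent; a root or an
        // empty input has no parent at all. Fall back to the current directory.
        _ => PathBuf::from("."),
    }
}
```

With the default `proto/backend.proto` set by the Makefile, both approaches yield `proto`; the difference only shows up when `BACKEND_PROTO_PATH` is a bare file name.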
42
backend/rust/kokoros/package.sh
Normal file
@@ -0,0 +1,42 @@
#!/bin/bash
set -e

CURDIR=$(dirname "$(realpath $0)")
mkdir -p $CURDIR/package/lib

# Copy the binary and run script
cp -avf $CURDIR/target/release/kokoros-grpc $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
chmod +x $CURDIR/package/run.sh

# Copy espeak-ng data
if [ -d "/usr/share/espeak-ng-data" ]; then
    cp -rf /usr/share/espeak-ng-data $CURDIR/package/
elif [ -d "/usr/lib/x86_64-linux-gnu/espeak-ng-data" ]; then
    cp -rf /usr/lib/x86_64-linux-gnu/espeak-ng-data $CURDIR/package/
fi

# Bundle all dynamic library dependencies
echo "Bundling dynamic library dependencies..."
ldd $CURDIR/target/release/kokoros-grpc | grep "=>" | awk '{print $3}' | while read lib; do
    if [ -n "$lib" ] && [ -f "$lib" ]; then
        cp -avfL "$lib" $CURDIR/package/lib/
    fi
done

# Copy CA certificates for HTTPS (needed for model auto-download)
if [ -d "/etc/ssl/certs" ]; then
    mkdir -p $CURDIR/package/etc/ssl
    cp -rf /etc/ssl/certs $CURDIR/package/etc/ssl/
fi

# Copy the dynamic linker
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
    cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
fi

echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
23
backend/rust/kokoros/run.sh
Executable file
@@ -0,0 +1,23 @@
#!/bin/bash
set -ex

CURDIR=$(dirname "$(realpath $0)")

export LD_LIBRARY_PATH=$CURDIR/lib:${LD_LIBRARY_PATH:-}

# SSL certificates for model auto-download
if [ -d "$CURDIR/etc/ssl/certs" ]; then
    export SSL_CERT_DIR=$CURDIR/etc/ssl/certs
fi

# espeak-ng data directory
if [ -d "$CURDIR/espeak-ng-data" ]; then
    export ESPEAK_NG_DATA=$CURDIR/espeak-ng-data
fi

# Use bundled ld.so if present (portability)
if [ -f $CURDIR/lib/ld.so ]; then
    exec $CURDIR/lib/ld.so $CURDIR/kokoros-grpc "$@"
fi

exec $CURDIR/kokoros-grpc "$@"
1
backend/rust/kokoros/sources/Kokoros
Submodule
Submodule backend/rust/kokoros/sources/Kokoros added at 7089168f0c
26
backend/rust/kokoros/src/auth.rs
Normal file
@@ -0,0 +1,26 @@
use tonic::{Request, Status};

/// Returns an interceptor function if LOCALAI_GRPC_AUTH_TOKEN is set to a non-empty value.
pub fn make_auth_interceptor(
) -> Option<impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone> {
    let token = std::env::var("LOCALAI_GRPC_AUTH_TOKEN").ok()?;
    if token.is_empty() {
        return None;
    }
    let expected = format!("Bearer {}", token);
    Some(
        move |req: Request<()>| -> Result<Request<()>, Status> {
            let meta = req.metadata();
            match meta.get("authorization") {
                Some(val) => {
                    if val.as_bytes() == expected.as_bytes() {
                        Ok(req)
                    } else {
                        Err(Status::unauthenticated("invalid token"))
                    }
                }
                None => Err(Status::unauthenticated("missing authorization")),
            }
        },
    )
}
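The interceptor compares the `authorization` value with `==`, which may return as soon as a byte differs. Where timing side channels are a concern, a comparison that always walks the full input is a common alternative. The sketch below uses only the standard library and illustrates the idea; it is not what this diff does, the names `constant_time_eq` and `check_bearer` are hypothetical, and a production service would more likely rely on a vetted crate than on hand-written code like this.

```rust
/// Compare two byte strings without returning early on the first mismatch.
/// The length check still reveals whether the lengths differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; the result is 0 only if all match.
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }
    diff == 0
}

/// Hypothetical check mirroring the interceptor's branches: `header` is the raw
/// `authorization` metadata value, `token` is the configured secret.
fn check_bearer(header: Option<&[u8]>, token: &str) -> Result<(), &'static str> {
    let expected = format!("Bearer {}", token);
    match header {
        Some(value) if constant_time_eq(value, expected.as_bytes()) => Ok(()),
        Some(_) => Err("invalid token"),
        None => Err("missing authorization"),
    }
}
```

The three outcomes map onto the interceptor's `Ok(req)`, `invalid token`, and `missing authorization` paths.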
53
backend/rust/kokoros/src/main.rs
Normal file
@@ -0,0 +1,53 @@
use clap::Parser;
use tonic::transport::Server;

mod auth;
mod service;

pub mod backend {
    tonic::include_proto!("backend");
}

#[derive(Parser, Debug)]
#[command(name = "kokoros-grpc")]
struct Cli {
    /// gRPC listen address (ip:port); parsed as a socket address, so host names are not accepted
    #[arg(long, default_value = "127.0.0.1:50051")]
    addr: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .without_time()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    let cli = Cli::parse();
    let addr = cli.addr.parse()?;

    tracing::info!("Starting kokoros gRPC server on {}", addr);

    let mut builder = Server::builder();

    if let Some(interceptor) = auth::make_auth_interceptor() {
        tracing::info!("Bearer token authentication enabled");
        let svc = backend::backend_server::BackendServer::with_interceptor(
            service::KokorosService::default(),
            interceptor,
        );
        builder.add_service(svc).serve(addr).await?;
    } else {
        let svc = backend::backend_server::BackendServer::new(service::KokorosService::default())
            .max_decoding_message_size(50 * 1024 * 1024)
            .max_encoding_message_size(50 * 1024 * 1024);
        builder.add_service(svc).serve(addr).await?;
    }

    Ok(())
}
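The listen address in `main.rs` is turned into a socket address with `str::parse`, and the standard library's `SocketAddr` parser accepts only literal IP addresses, so a value such as `localhost:50051` is rejected. If host names should be accepted as well, one option is to fall back to the system resolver. The following is a sketch under that assumption; `resolve_listen_addr` is a hypothetical helper and not part of this diff.

```rust
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};

/// Hypothetical helper (not in the diff): turn a `host:port` string into a
/// `SocketAddr`, accepting literal IPs directly and resolving host names.
fn resolve_listen_addr(addr: &str) -> io::Result<SocketAddr> {
    // Fast path: literal IPv4/IPv6 address, no resolver involved.
    if let Ok(parsed) = addr.parse::<SocketAddr>() {
        return Ok(parsed);
    }
    // Fallback: ask the system resolver and take the first result.
    addr.to_socket_addrs()?
        .next()
        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "no address resolved"))
}
```

Taking the first resolved address is a simplification; a host name may resolve to several addresses (for example IPv4 and IPv6), and which one is preferable depends on the deployment.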
652
backend/rust/kokoros/src/service.rs
Normal file
@@ -0,0 +1,652 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use kokoros::tts::koko::TTSKoko;
|
||||
|
||||
use crate::backend;
|
||||
use crate::backend::backend_server::Backend;
|
||||
|
||||
/// Write f32 samples as a mono 16-bit PCM WAV file with the standard 44-byte header.
/// LocalAI's audio pipeline assumes this exact header layout.
|
||||
fn write_pcm16_wav(
|
||||
path: &str,
|
||||
samples: &[f32],
|
||||
sample_rate: u32,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
|
||||
let num_samples = samples.len() as u32;
|
||||
let data_size = num_samples * 2; // 16-bit = 2 bytes per sample
|
||||
let file_size = 36 + data_size;
|
||||
|
||||
let mut f = File::create(path)?;
|
||||
|
||||
// RIFF header
|
||||
f.write_all(b"RIFF")?;
|
||||
f.write_all(&file_size.to_le_bytes())?;
|
||||
f.write_all(b"WAVE")?;
|
||||
|
||||
// fmt chunk — standard 16-byte PCM format
|
||||
f.write_all(b"fmt ")?;
|
||||
f.write_all(&16u32.to_le_bytes())?; // chunk size
|
||||
f.write_all(&1u16.to_le_bytes())?; // audio format = PCM
|
||||
f.write_all(&1u16.to_le_bytes())?; // channels = mono
|
||||
f.write_all(&sample_rate.to_le_bytes())?;
|
||||
f.write_all(&(sample_rate * 2).to_le_bytes())?; // byte rate
|
||||
f.write_all(&2u16.to_le_bytes())?; // block align
|
||||
f.write_all(&16u16.to_le_bytes())?; // bits per sample
|
||||
|
||||
// data chunk
|
||||
f.write_all(b"data")?;
|
||||
f.write_all(&data_size.to_le_bytes())?;
|
||||
|
||||
for &s in samples {
|
||||
let clamped = s.clamp(-1.0, 1.0);
|
||||
let pcm = (clamped * 32767.0) as i16;
|
||||
f.write_all(&pcm.to_le_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct KokorosService {
|
||||
tts: Arc<TokioMutex<Option<TTSKoko>>>,
|
||||
language: Arc<Mutex<String>>,
|
||||
speed: Arc<Mutex<f32>>,
|
||||
}
|
||||
|
||||
impl Default for KokorosService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tts: Arc::new(TokioMutex::new(None)),
|
||||
language: Arc::new(Mutex::new("en-us".to_string())),
|
||||
speed: Arc::new(Mutex::new(1.0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl Backend for KokorosService {
|
||||
async fn health(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::Reply>, Status> {
|
||||
Ok(Response::new(backend::Reply {
|
||||
message: b"OK".to_vec(),
|
||||
..Default::default()
|
||||
}))
|
||||
}
|
||||
|
||||
async fn load_model(
|
||||
&self,
|
||||
req: Request<backend::ModelOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
let opts = req.into_inner();
|
||||
|
||||
// Model path: join ModelPath + Model, or just Model
|
||||
let model_path = if !opts.model_path.is_empty() && !opts.model.is_empty() {
|
||||
format!("{}/{}", opts.model_path, opts.model)
|
||||
} else if !opts.model.is_empty() {
|
||||
opts.model.clone()
|
||||
} else {
|
||||
"checkpoints/kokoro-v1.0.onnx".to_string()
|
||||
};
|
||||
|
||||
// Voices data path from AudioPath, or derive from model dir
|
||||
let voices_path = if !opts.audio_path.is_empty() {
|
||||
opts.audio_path.clone()
|
||||
} else {
|
||||
let model_dir = std::path::Path::new(&model_path)
|
||||
.parent()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| ".".to_string());
|
||||
format!("{}/voices-v1.0.bin", model_dir)
|
||||
};
|
||||
|
||||
// Parse options (key:value pairs)
|
||||
for opt in &opts.options {
|
||||
if let Some((key, value)) = opt.split_once(':') {
|
||||
match key {
|
||||
"lang_code" => *self.language.lock().unwrap() = value.to_string(),
|
||||
"speed" => {
|
||||
if let Ok(s) = value.parse::<f32>() {
|
||||
*self.speed.lock().unwrap() = s;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Loading Kokoros model from: {}", model_path);
|
||||
tracing::info!("Loading voices from: {}", voices_path);
|
||||
tracing::info!("Language: {}", self.language.lock().unwrap());
|
||||
|
||||
let tts = TTSKoko::new(&model_path, &voices_path).await;
|
||||
*self.tts.lock().await = Some(tts);
|
||||
|
||||
tracing::info!("Kokoros TTS model loaded successfully");
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: "Kokoros TTS model loaded".into(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn tts(
|
||||
&self,
|
||||
req: Request<backend::TtsRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
let req = req.into_inner();
|
||||
let tts_guard = self.tts.lock().await;
|
||||
let tts = tts_guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?;
|
||||
|
||||
let voice = if req.voice.is_empty() {
|
||||
"af_heart"
|
||||
} else {
|
||||
&req.voice
|
||||
};
|
||||
let lang = req
|
||||
.language
|
||||
.filter(|l| !l.is_empty())
|
||||
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||
let speed = *self.speed.lock().unwrap();
|
||||
|
||||
tracing::info!(
|
||||
text = req.text,
|
||||
voice = voice,
|
||||
lang = lang.as_str(),
|
||||
dst = req.dst,
|
||||
"TTS request received"
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
match tts.tts_raw_audio(&req.text, &lang, voice, speed, None, None, None, None) {
|
||||
Ok(samples) => {
|
||||
let duration_secs = samples.len() as f64 / 24000.0;
|
||||
tracing::info!(
|
||||
num_samples = samples.len(),
|
||||
audio_duration = format!("{:.2}s", duration_secs),
|
||||
inference_time = format!("{:.2}s", start.elapsed().as_secs_f64()),
|
||||
dst = req.dst,
|
||||
"TTS inference complete"
|
||||
);
|
||||
if let Err(e) = write_pcm16_wav(&req.dst, &samples, 24000) {
|
||||
tracing::error!("Failed to write WAV to {}: {}", req.dst, e);
|
||||
return Ok(Response::new(backend::Result {
|
||||
success: false,
|
||||
message: format!("Failed to write WAV: {}", e),
|
||||
}));
|
||||
}
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: String::new(),
|
||||
}))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("TTS error: {}", e);
|
||||
Ok(Response::new(backend::Result {
|
||||
success: false,
|
||||
message: format!("TTS error: {}", e),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TTSStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||
|
||||
async fn tts_stream(
|
||||
&self,
|
||||
req: Request<backend::TtsRequest>,
|
||||
) -> Result<Response<Self::TTSStreamStream>, Status> {
|
||||
let req = req.into_inner();
|
||||
let tts_guard = self.tts.lock().await;
|
||||
let tts = tts_guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?
|
||||
.clone();
|
||||
|
||||
let voice = if req.voice.is_empty() {
|
||||
"af_heart".to_string()
|
||||
} else {
|
||||
req.voice
|
||||
};
|
||||
let lang = req
|
||||
.language
|
||||
.filter(|l| !l.is_empty())
|
||||
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||
let speed = *self.speed.lock().unwrap();
|
||||
let text = req.text;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(32);
|
||||
|
||||
// Send sample rate info as first message
|
||||
let tx_clone = tx.clone();
|
||||
let _ = tx_clone
|
||||
.send(Ok(backend::Reply {
|
||||
message: br#"{"sample_rate":24000}"#.to_vec(),
|
||||
..Default::default()
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = tts.tts_raw_audio_streaming(
|
||||
&text,
|
||||
&lang,
|
||||
&voice,
|
||||
speed,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
|audio_chunk: Vec<f32>| -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Convert f32 PCM to 16-bit PCM bytes (what LocalAI expects for streaming)
|
||||
let bytes: Vec<u8> = audio_chunk
|
||||
.iter()
|
||||
.flat_map(|&s| {
|
||||
let clamped = s.clamp(-1.0, 1.0);
|
||||
let i16_val = (clamped * 32767.0) as i16;
|
||||
i16_val.to_le_bytes()
|
||||
})
|
||||
.collect();
|
||||
tx.blocking_send(Ok(backend::Reply {
|
||||
audio: bytes,
|
||||
..Default::default()
|
||||
}))
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
|
||||
},
|
||||
);
|
||||
if let Err(e) = result {
|
||||
tracing::error!("TTSStream error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn status(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::StatusResponse>, Status> {
|
||||
let tts = self.tts.lock().await;
|
||||
let state = if tts.is_some() {
|
||||
backend::status_response::State::Ready as i32
|
||||
} else {
|
||||
backend::status_response::State::Uninitialized as i32
|
||||
};
|
||||
Ok(Response::new(backend::StatusResponse {
|
||||
state,
|
||||
memory: None,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn free(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
*self.tts.lock().await = None;
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: "Model freed".into(),
|
||||
}))
|
||||
}
|
||||
|
||||
// --- Unimplemented RPCs ---
|
||||
|
||||
async fn predict(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::Reply>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type PredictStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||
|
||||
async fn predict_stream(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<Self::PredictStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn embedding(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::EmbeddingResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn generate_image(
|
||||
&self,
|
||||
_: Request<backend::GenerateImageRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn generate_video(
|
||||
&self,
|
||||
_: Request<backend::GenerateVideoRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_transcription(
|
||||
&self,
|
||||
_: Request<backend::TranscriptRequest>,
|
||||
) -> Result<Response<backend::TranscriptResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn tokenize_string(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::TokenizationResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn detect(
|
||||
&self,
|
||||
_: Request<backend::DetectOptions>,
|
||||
) -> Result<Response<backend::DetectResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_set(
|
||||
&self,
|
||||
_: Request<backend::StoresSetOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_delete(
|
||||
&self,
|
||||
_: Request<backend::StoresDeleteOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_get(
|
||||
&self,
|
||||
_: Request<backend::StoresGetOptions>,
|
||||
) -> Result<Response<backend::StoresGetResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_find(
|
||||
&self,
|
||||
_: Request<backend::StoresFindOptions>,
|
||||
) -> Result<Response<backend::StoresFindResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn rerank(
|
||||
&self,
|
||||
_: Request<backend::RerankRequest>,
|
||||
) -> Result<Response<backend::RerankResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn get_metrics(
|
||||
&self,
|
||||
_: Request<backend::MetricsRequest>,
|
||||
) -> Result<Response<backend::MetricsResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn vad(
|
||||
&self,
|
||||
_: Request<backend::VadRequest>,
|
||||
) -> Result<Response<backend::VadResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_encode(
|
||||
&self,
|
||||
_: Request<backend::AudioEncodeRequest>,
|
||||
) -> Result<Response<backend::AudioEncodeResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_decode(
|
||||
&self,
|
||||
_: Request<backend::AudioDecodeRequest>,
|
||||
) -> Result<Response<backend::AudioDecodeResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn model_metadata(
|
||||
&self,
|
||||
_: Request<backend::ModelOptions>,
|
||||
) -> Result<Response<backend::ModelMetadataResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn start_fine_tune(
|
||||
&self,
|
||||
_: Request<backend::FineTuneRequest>,
|
||||
) -> Result<Response<backend::FineTuneJobResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type FineTuneProgressStream = ReceiverStream<Result<backend::FineTuneProgressUpdate, Status>>;
|
||||
|
||||
async fn fine_tune_progress(
|
||||
&self,
|
||||
_: Request<backend::FineTuneProgressRequest>,
|
||||
) -> Result<Response<Self::FineTuneProgressStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stop_fine_tune(
|
||||
&self,
|
||||
_: Request<backend::FineTuneStopRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn list_checkpoints(
|
||||
&self,
|
||||
_: Request<backend::ListCheckpointsRequest>,
|
||||
) -> Result<Response<backend::ListCheckpointsResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn export_model(
|
||||
&self,
|
||||
_: Request<backend::ExportModelRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn start_quantization(
|
||||
&self,
|
||||
_: Request<backend::QuantizationRequest>,
|
||||
) -> Result<Response<backend::QuantizationJobResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type QuantizationProgressStream =
|
||||
ReceiverStream<Result<backend::QuantizationProgressUpdate, Status>>;
|
||||
|
||||
async fn quantization_progress(
|
||||
&self,
|
||||
_: Request<backend::QuantizationProgressRequest>,
|
||||
) -> Result<Response<Self::QuantizationProgressStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stop_quantization(
|
||||
&self,
|
||||
_: Request<backend::QuantizationStopRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn wav_header_is_standard_pcm16() {
|
||||
let samples = vec![0.0f32, 0.5, -0.5, 1.0, -1.0];
|
||||
let path = std::env::temp_dir().join("kokoros_test.wav");
|
||||
let path_str = path.to_str().unwrap();
|
||||
|
||||
write_pcm16_wav(path_str, &samples, 24000).unwrap();
|
||||
|
||||
let data = std::fs::read(&path).unwrap();
|
||||
std::fs::remove_file(&path).unwrap();
|
||||
|
||||
// Must be exactly 44-byte header + data
|
||||
assert_eq!(data.len(), 44 + samples.len() * 2);
|
||||
|
||||
// RIFF header
|
||||
assert_eq!(&data[0..4], b"RIFF");
|
||||
assert_eq!(&data[8..12], b"WAVE");
|
||||
|
||||
// fmt chunk: 16 bytes, format=1 (PCM), channels=1, 16-bit
|
||||
assert_eq!(&data[12..16], b"fmt ");
|
||||
assert_eq!(u32::from_le_bytes(data[16..20].try_into().unwrap()), 16); // chunk size
|
||||
assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM format
|
||||
assert_eq!(u16::from_le_bytes(data[22..24].try_into().unwrap()), 1); // mono
|
||||
assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
|
||||
assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // bits per sample
|
||||
|
||||
// data chunk
|
||||
assert_eq!(&data[36..40], b"data");
|
||||
assert_eq!(
|
||||
u32::from_le_bytes(data[40..44].try_into().unwrap()),
|
||||
(samples.len() * 2) as u32
|
||||
);
|
||||
|
||||
// Verify sample values: 0.5 -> 16383, -0.5 -> -16383, 1.0 -> 32767, -1.0 -> -32767
|
||||
let s1 = i16::from_le_bytes(data[46..48].try_into().unwrap());
|
||||
assert_eq!(s1, 16383); // 0.5 * 32767
|
||||
let s3 = i16::from_le_bytes(data[50..52].try_into().unwrap());
|
||||
assert_eq!(s3, 32767); // 1.0 clamped
|
||||
let s4 = i16::from_le_bytes(data[52..54].try_into().unwrap());
|
||||
assert_eq!(s4, -32767); // -1.0 clamped
|
||||
}
|
||||
|
||||
/// Integration test: runs actual TTS inference and validates the output audio.
|
||||
/// Skipped unless KOKOROS_MODEL_PATH is set to a directory containing
|
||||
/// kokoro-v1.0.onnx and voices-v1.0.bin.
|
||||
#[tokio::test]
|
||||
async fn tts_produces_valid_speech() {
|
||||
let model_dir = match std::env::var("KOKOROS_MODEL_PATH") {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
eprintln!("KOKOROS_MODEL_PATH not set, skipping integration test");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let model_path = format!("{}/kokoro-v1.0.onnx", model_dir);
|
||||
let voices_path = format!("{}/voices-v1.0.bin", model_dir);
|
||||
|
||||
if !std::path::Path::new(&model_path).exists() {
|
||||
eprintln!("Model file not found at {}, skipping", model_path);
|
||||
return;
|
||||
}
|
||||
|
||||
let tts = TTSKoko::new(&model_path, &voices_path).await;
|
||||
|
||||
let input_text = "Hello world, this is a test of speech synthesis.";
|
||||
let out_path = std::env::temp_dir().join("kokoros_integration_test.wav");
|
||||
let out_str = out_path.to_str().unwrap();
|
||||
|
||||
let samples = tts
|
||||
.tts_raw_audio(input_text, "en-us", "af_heart", 1.0, None, None, None, None)
|
||||
.expect("tts_raw_audio failed");
|
||||
|
||||
write_pcm16_wav(out_str, &samples, 24000).unwrap();
|
||||
|
||||
let data = std::fs::read(&out_path).unwrap();
|
||||
std::fs::remove_file(&out_path).unwrap();
|
||||
|
||||
// --- WAV header sanity ---
|
||||
assert_eq!(&data[0..4], b"RIFF");
|
||||
assert_eq!(&data[8..12], b"WAVE");
|
||||
assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM
|
||||
assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
|
||||
assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // 16-bit
|
||||
|
||||
let num_samples = samples.len();
|
||||
let duration_secs = num_samples as f64 / 24000.0;
|
||||
|
||||
// --- Duration check ---
|
||||
// ~10 words should produce roughly 2-8 seconds of speech
|
||||
assert!(
|
||||
duration_secs > 1.0,
|
||||
"Audio too short: {:.2}s for {} words",
|
||||
duration_secs,
|
||||
input_text.split_whitespace().count()
|
||||
);
|
||||
assert!(
|
||||
duration_secs < 15.0,
|
||||
"Audio too long: {:.2}s for {} words",
|
||||
duration_secs,
|
||||
input_text.split_whitespace().count()
|
||||
);
|
||||
|
||||
// --- Energy check: not silence ---
|
||||
let rms = (samples.iter().map(|s| s * s).sum::<f32>() / num_samples as f32).sqrt();
|
||||
assert!(
|
||||
rms > 0.01,
|
||||
"Audio is near-silence: RMS = {:.6}",
|
||||
rms
|
||||
);
|
||||
|
||||
// --- Not clipped/saturated: should have dynamic range ---
|
||||
let max_abs = samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
|
||||
assert!(
|
||||
max_abs < 1.0,
|
||||
"Audio is fully saturated (max |sample| = {:.4})",
|
||||
max_abs
|
||||
);
|
||||
assert!(
|
||||
max_abs > 0.05,
|
||||
"Audio has very low amplitude (max |sample| = {:.4})",
|
||||
max_abs
|
||||
);
|
||||
|
||||
// --- Speech-like spectral check ---
|
||||
// Speech should have significant energy variation (not white noise or DC).
|
||||
// Check that the zero-crossing rate falls in a speech-like range. A pure tone
// at f Hz crosses zero about 2f times per second; the bounds below are 50-10000/s.
|
||||
let zero_crossings: usize = samples
|
||||
.windows(2)
|
||||
.filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
|
||||
.count();
|
||||
let crossings_per_sec = zero_crossings as f64 / duration_secs;
|
||||
// White noise at 24kHz would have ~12000 crossings/sec.
|
||||
// Speech is typically 100-4000 crossings/sec.
|
||||
assert!(
|
||||
crossings_per_sec < 10000.0,
|
||||
"Too many zero crossings ({:.0}/s) — likely noise, not speech",
|
||||
crossings_per_sec
|
||||
);
|
||||
assert!(
|
||||
crossings_per_sec > 50.0,
|
||||
"Too few zero crossings ({:.0}/s) — likely DC or silence, not speech",
|
||||
crossings_per_sec
|
||||
);
|
||||
|
||||
eprintln!(
|
||||
"Integration test passed: duration={:.2}s, rms={:.4}, max={:.4}, zero_crossings={:.0}/s",
|
||||
duration_secs, rms, max_abs, crossings_per_sec
|
||||
);
|
||||
}
|
||||
}
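The 44-byte header layout that `write_pcm16_wav` emits, and that the unit test checks by byte offset, can also be exercised entirely in memory. The sketch below re-creates the same layout with the standard library and reads the fields back. `WavInfo`, `f32_to_pcm16`, `encode_wav` and `parse_wav_header` are hypothetical names introduced for illustration, and the parser assumes the canonical layout with no extra chunks.

```rust
/// Fields read back from the 44-byte header (hypothetical struct, not in the diff).
#[derive(Debug, PartialEq)]
struct WavInfo {
    channels: u16,
    sample_rate: u32,
    bits_per_sample: u16,
    data_len: u32,
}

/// Convert one f32 sample to 16-bit PCM with the same clamp-and-scale rule as the diff.
fn f32_to_pcm16(sample: f32) -> i16 {
    (sample.clamp(-1.0, 1.0) * 32767.0) as i16
}

/// Build the 44-byte header followed by little-endian samples, in memory.
fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> {
    let data_len = (samples.len() * 2) as u32;
    let mut out = Vec::with_capacity(44 + samples.len() * 2);
    out.extend_from_slice(b"RIFF");
    out.extend_from_slice(&(36 + data_len).to_le_bytes()); // RIFF chunk size
    out.extend_from_slice(b"WAVE");
    out.extend_from_slice(b"fmt ");
    out.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
    out.extend_from_slice(&1u16.to_le_bytes()); // audio format = PCM
    out.extend_from_slice(&1u16.to_le_bytes()); // channels = mono
    out.extend_from_slice(&sample_rate.to_le_bytes());
    out.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
    out.extend_from_slice(&2u16.to_le_bytes()); // block align
    out.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
    out.extend_from_slice(b"data");
    out.extend_from_slice(&data_len.to_le_bytes());
    for &s in samples {
        out.extend_from_slice(&f32_to_pcm16(s).to_le_bytes());
    }
    out
}

/// Read the header back; returns None unless the bytes follow the canonical layout.
fn parse_wav_header(bytes: &[u8]) -> Option<WavInfo> {
    if bytes.len() < 44 || &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
        return None;
    }
    if &bytes[12..16] != b"fmt " || &bytes[36..40] != b"data" {
        return None;
    }
    Some(WavInfo {
        channels: u16::from_le_bytes([bytes[22], bytes[23]]),
        sample_rate: u32::from_le_bytes([bytes[24], bytes[25], bytes[26], bytes[27]]),
        bits_per_sample: u16::from_le_bytes([bytes[34], bytes[35]]),
        data_len: u32::from_le_bytes([bytes[40], bytes[41], bytes[42], bytes[43]]),
    })
}
```

Building the buffer in memory avoids one small write per sample; whether that matters for this backend has not been measured here.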
|
||||
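The zero-crossing check in the integration test relies on a simple relationship: a pure tone at f Hz changes sign about 2f times per second, while a constant (DC) signal never does. A minimal sketch of that measurement, using the same sign-change rule as the test, is given below; `zero_crossings_per_sec` and `sine` are hypothetical helper names, and the sine generator exists only to produce a test signal.

```rust
/// Count sign changes between consecutive samples and normalize by duration.
/// Returns 0.0 for inputs too short to contain a crossing.
fn zero_crossings_per_sec(samples: &[f32], sample_rate: u32) -> f64 {
    if samples.len() < 2 || sample_rate == 0 {
        return 0.0;
    }
    let crossings = samples
        .windows(2)
        .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
        .count();
    let duration_secs = samples.len() as f64 / sample_rate as f64;
    crossings as f64 / duration_secs
}

/// Generate `secs` seconds of a sine tone (test signal only).
fn sine(freq_hz: f64, sample_rate: u32, secs: f64) -> Vec<f32> {
    let n = (sample_rate as f64 * secs) as usize;
    (0..n)
        .map(|i| {
            let t = i as f64 / sample_rate as f64;
            (2.0 * std::f64::consts::PI * freq_hz * t).sin() as f32
        })
        .collect()
}
```

For example, `zero_crossings_per_sec(&sine(440.0, 24000, 1.0), 24000)` lands close to 880, well inside the 50 to 10000 window the test asserts.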
@@ -37,6 +37,9 @@ type Application struct {
|
||||
|
||||
// Distributed mode services (nil when not in distributed mode)
|
||||
distributed *DistributedServices
|
||||
|
||||
// Upgrade checker (background service for detecting backend upgrades)
|
||||
upgradeChecker *UpgradeChecker
|
||||
}
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
@@ -79,6 +82,19 @@ func (a *Application) AgentJobService() *agentpool.AgentJobService {
|
||||
return a.agentJobService
|
||||
}
|
||||
|
||||
func (a *Application) UpgradeChecker() *UpgradeChecker {
|
||||
return a.upgradeChecker
|
||||
}
|
||||
|
||||
// distributedDB returns the PostgreSQL database for distributed coordination,
|
||||
// or nil in standalone mode.
|
||||
func (a *Application) distributedDB() *gorm.DB {
|
||||
if a.distributed != nil {
|
||||
return a.authDB
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||
return a.agentPoolService.Load()
|
||||
}
|
||||
|
||||
@@ -335,6 +335,9 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
|
||||
appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
||||
}
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
appConfig.AutoUpgradeBackends = *settings.AutoUpgradeBackends
|
||||
}
|
||||
if settings.ApiKeys != nil {
|
||||
// API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys
|
||||
// If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys
|
||||
|
||||
@@ -231,6 +231,15 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
xlog.Error("error registering external backends", "error", err)
|
||||
}
|
||||
|
||||
// Start background upgrade checker for backends.
|
||||
// In distributed mode, uses PostgreSQL advisory lock so only one frontend
|
||||
// instance runs periodic checks (avoids duplicate upgrades across replicas).
|
||||
if len(options.BackendGalleries) > 0 {
|
||||
uc := NewUpgradeChecker(options, application.ModelLoader(), application.distributedDB())
|
||||
application.upgradeChecker = uc
|
||||
go uc.Run(options.Context)
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
xlog.Error("error loading config file", "error", err)
|
||||
|
||||
198
core/application/upgrade_checker.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package application

import (
	"context"
	"sync"
	"time"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/services/advisorylock"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
	"gorm.io/gorm"
)

// UpgradeChecker periodically checks for backend upgrades and optionally
// auto-upgrades them. It caches the last check results for API queries.
//
// In standalone mode it runs a simple ticker loop.
// In distributed mode it uses a PostgreSQL advisory lock so that only one
// frontend instance performs periodic checks and auto-upgrades at a time.
type UpgradeChecker struct {
	appConfig   *config.ApplicationConfig
	modelLoader *model.ModelLoader
	galleries   []config.Gallery
	systemState *system.SystemState
	db          *gorm.DB // non-nil in distributed mode

	checkInterval time.Duration
	stop          chan struct{}
	done          chan struct{}
	triggerCh     chan struct{}

	mu            sync.RWMutex
	lastUpgrades  map[string]gallery.UpgradeInfo
	lastCheckTime time.Time
}

// NewUpgradeChecker creates a new UpgradeChecker service.
// Pass db=nil for standalone mode, or a *gorm.DB for distributed mode
// (uses advisory locks so only one instance runs periodic checks).
func NewUpgradeChecker(appConfig *config.ApplicationConfig, ml *model.ModelLoader, db *gorm.DB) *UpgradeChecker {
	return &UpgradeChecker{
		appConfig:     appConfig,
		modelLoader:   ml,
		galleries:     appConfig.BackendGalleries,
		systemState:   appConfig.SystemState,
		db:            db,
		checkInterval: 6 * time.Hour,
		stop:          make(chan struct{}),
		done:          make(chan struct{}),
		triggerCh:     make(chan struct{}, 1),
		lastUpgrades:  make(map[string]gallery.UpgradeInfo),
	}
}

// Run starts the upgrade checker loop. It waits 30 seconds after startup,
// performs an initial check, then re-checks every 6 hours.
//
// In distributed mode, periodic checks are guarded by a PostgreSQL advisory
// lock so only one frontend instance runs them. On-demand triggers (TriggerCheck)
// and the initial check always run locally for fast API response cache warming.
func (uc *UpgradeChecker) Run(ctx context.Context) {
	defer close(uc.done)

	// Initial delay: don't slow down startup
	select {
	case <-ctx.Done():
		return
	case <-uc.stop:
		return
	case <-time.After(30 * time.Second):
	}

	// First check always runs locally (to warm the cache on this instance)
	uc.runCheck(ctx)

	if uc.db != nil {
		// Distributed mode: use advisory lock for periodic checks.
		// RunLeaderLoop ticks every checkInterval; only the lock holder executes.
		go advisorylock.RunLeaderLoop(ctx, uc.db, advisorylock.KeyBackendUpgradeCheck, uc.checkInterval, func() {
			uc.runCheck(ctx)
		})

		// Still listen for on-demand triggers (from API / settings change)
		// and stop signal — these run on every instance.
		for {
			select {
			case <-ctx.Done():
				return
			case <-uc.stop:
				return
			case <-uc.triggerCh:
				uc.runCheck(ctx)
			}
		}
	} else {
		// Standalone mode: simple ticker loop
		ticker := time.NewTicker(uc.checkInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-uc.stop:
				return
			case <-ticker.C:
				uc.runCheck(ctx)
			case <-uc.triggerCh:
				uc.runCheck(ctx)
			}
		}
	}
}

// Shutdown stops the upgrade checker loop.
func (uc *UpgradeChecker) Shutdown() {
	close(uc.stop)
	<-uc.done
}

// TriggerCheck forces an immediate upgrade check on this instance.
func (uc *UpgradeChecker) TriggerCheck() {
	select {
	case uc.triggerCh <- struct{}{}:
	default:
		// Already triggered, skip
	}
}

// GetAvailableUpgrades returns the cached upgrade check results.
func (uc *UpgradeChecker) GetAvailableUpgrades() map[string]gallery.UpgradeInfo {
	uc.mu.RLock()
	defer uc.mu.RUnlock()

	// Return a copy to avoid races
	result := make(map[string]gallery.UpgradeInfo, len(uc.lastUpgrades))
	for k, v := range uc.lastUpgrades {
		result[k] = v
	}
	return result
}

func (uc *UpgradeChecker) runCheck(ctx context.Context) {
	upgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState)

	uc.mu.Lock()
	uc.lastCheckTime = time.Now()
	if err != nil {
		xlog.Debug("Backend upgrade check failed", "error", err)
		uc.mu.Unlock()
		return
	}
	uc.lastUpgrades = upgrades
	uc.mu.Unlock()

	if len(upgrades) == 0 {
		xlog.Debug("All backends up to date")
		return
	}

	// Log available upgrades
	for name, info := range upgrades {
		if info.AvailableVersion != "" {
			xlog.Info("Backend upgrade available",
				"backend", name,
				"installed", info.InstalledVersion,
				"available", info.AvailableVersion)
		} else {
			xlog.Info("Backend upgrade available (new build)",
				"backend", name)
		}
	}

	// Auto-upgrade if enabled
	if uc.appConfig.AutoUpgradeBackends {
		for name, info := range upgrades {
			xlog.Info("Auto-upgrading backend", "backend", name,
				"from", info.InstalledVersion, "to", info.AvailableVersion)
			if err := gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
				uc.galleries, name, nil); err != nil {
				xlog.Error("Failed to auto-upgrade backend",
					"backend", name, "error", err)
			} else {
				xlog.Info("Backend upgraded successfully", "backend", name,
					"version", info.AvailableVersion)
			}
		}
		// Re-check to update cache after upgrades
		if freshUpgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState); err == nil {
			uc.mu.Lock()
			uc.lastUpgrades = freshUpgrades
			uc.mu.Unlock()
		}
	}
}
|
||||
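Note on `TriggerCheck` in `upgrade_checker.go`: the channel is created with capacity 1 and written with a non-blocking send, so any number of trigger requests that arrive while one is already queued collapse into a single pending check. A minimal standalone sketch of that pattern follows. The `coalescer` type and its `Trigger`/`Drain` methods are hypothetical names made up for this illustration and are not part of LocalAI.

```go
package main

import "fmt"

// coalescer is a hypothetical stand-in for the trigger half of UpgradeChecker.
type coalescer struct {
	triggerCh chan struct{}
}

func newCoalescer() *coalescer {
	// Capacity 1: at most one trigger can be pending at any time.
	return &coalescer{triggerCh: make(chan struct{}, 1)}
}

// Trigger queues a run unless one is already pending. It never blocks.
// It reports whether this call actually queued a new run.
func (c *coalescer) Trigger() bool {
	select {
	case c.triggerCh <- struct{}{}:
		return true
	default:
		// A run is already queued, so this request is dropped.
		return false
	}
}

// Drain consumes every pending trigger without blocking and returns the count.
func (c *coalescer) Drain() int {
	n := 0
	for {
		select {
		case <-c.triggerCh:
			n++
		default:
			return n
		}
	}
}

func main() {
	c := newCoalescer()
	for i := 0; i < 5; i++ {
		c.Trigger()
	}
	fmt.Println("pending runs after 5 triggers:", c.Drain())
}
```

Five back-to-back triggers leave exactly one pending run, which is why a burst of API or settings-change requests should not queue up repeated gallery checks.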
@@ -5,6 +5,27 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// SyncPinnedModelsToWatchdog reads pinned status from all model configs and updates the watchdog
|
||||
func (a *Application) SyncPinnedModelsToWatchdog() {
|
||||
cl := a.ModelConfigLoader()
|
||||
if cl == nil {
|
||||
return
|
||||
}
|
||||
wd := a.modelLoader.GetWatchDog()
|
||||
if wd == nil {
|
||||
return
|
||||
}
|
||||
configs := cl.GetAllModelsConfigs()
|
||||
var pinned []string
|
||||
for _, cfg := range configs {
|
||||
if cfg.IsPinned() {
|
||||
pinned = append(pinned, cfg.Name)
|
||||
}
|
||||
}
|
||||
wd.SetPinnedModels(pinned)
|
||||
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
|
||||
}
|
||||
|
||||
func (a *Application) StopWatchdog() error {
|
||||
if a.watchdogStop != nil {
|
||||
close(a.watchdogStop)
|
||||
@@ -44,6 +65,9 @@ func (a *Application) startWatchdog() error {
|
||||
// Set the watchdog on the model loader
|
||||
a.modelLoader.SetWatchDog(wd)
|
||||
|
||||
// Sync pinned models from config to the watchdog
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
// But memory reclaimer needs the Run() loop for periodic checking
|
||||
@@ -124,5 +148,8 @@ func (a *Application) RestartWatchdog() error {
|
||||
newWD.RestoreState(oldState)
|
||||
}
|
||||
|
||||
// Re-sync pinned models after restart
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,10 @@ import (
|
||||
|
||||
 func Detection(
 	sourceFile string,
+	prompt string,
+	points []float32,
+	boxes []float32,
+	threshold float32,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
 	modelConfig config.ModelConfig,
|
||||
@@ -35,7 +39,11 @@ func Detection(
|
||||
 	}

 	res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
-		Src: sourceFile,
+		Src:       sourceFile,
+		Prompt:    prompt,
+		Points:    points,
+		Boxes:     boxes,
+		Threshold: threshold,
 	})

 	if appConfig.EnableTracing {
|
||||
|
||||
@@ -36,6 +36,27 @@ type TokenUsage struct {
|
||||
Completion int
|
||||
TimingPromptProcessing float64
|
||||
TimingTokenGeneration float64
|
||||
ChatDeltas []*proto.ChatDelta // per-chunk deltas from C++ autoparser (only set during streaming)
|
||||
}
|
||||
|
||||
// HasChatDeltaContent returns true if any chat delta carries content or reasoning text.
|
||||
// Used to decide whether to prefer C++ autoparser deltas over Go-side tag extraction.
|
||||
func (t TokenUsage) HasChatDeltaContent() bool {
|
||||
for _, d := range t.ChatDeltas {
|
||||
if d.Content != "" || d.ReasoningContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChatDeltaReasoningAndContent extracts accumulated reasoning and content from chat deltas.
|
||||
func (t TokenUsage) ChatDeltaReasoningAndContent() (reasoning, content string) {
|
||||
for _, d := range t.ChatDeltas {
|
||||
content += d.Content
|
||||
reasoning += d.ReasoningContent
|
||||
}
|
||||
return reasoning, content
|
||||
}
|
||||
|
||||
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
||||
@@ -171,6 +192,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
||||
}
|
||||
|
||||
// Attach per-chunk chat deltas to tokenUsage so the callback can use them
|
||||
tokenUsage.ChatDeltas = reply.ChatDeltas
|
||||
|
||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
@@ -200,6 +224,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if len(msg) == 0 {
|
||||
tokenCallback("", tokenUsage)
|
||||
}
|
||||
|
||||
// Clear per-chunk deltas so they don't leak to the next chunk
|
||||
tokenUsage.ChatDeltas = nil
|
||||
})
|
||||
if len(allChatDeltas) > 0 {
|
||||
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
. "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -107,3 +108,111 @@ var _ = Describe("LLM tests", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TokenUsage ChatDelta helpers", func() {
|
||||
Describe("HasChatDeltaContent", func() {
|
||||
It("should return false when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when ChatDeltas is empty", func() {
|
||||
usage := TokenUsage{ChatDeltas: []*pb.ChatDelta{}}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when all deltas have empty content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "", ReasoningContent: ""},
|
||||
{Content: ""},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return true when a delta has content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello"},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has reasoning content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has both content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello", ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ChatDeltaReasoningAndContent", func() {
|
||||
It("should return empty strings when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate content from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "Hello"},
|
||||
{Content: " world"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(content).To(Equal("Hello world"))
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate reasoning from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "step 1"},
|
||||
{ReasoningContent: " step 2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("step 1 step 2"))
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should separate reasoning and content from mixed deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking"},
|
||||
{Content: "answer"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("thinking"))
|
||||
Expect(content).To(Equal("answer"))
|
||||
})
|
||||
|
||||
It("should handle deltas with both fields set", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "a", ReasoningContent: "r1"},
|
||||
{Content: "b", ReasoningContent: "r2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("r1r2"))
|
||||
Expect(content).To(Equal("ab"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
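The specs above pin down the contract of the two `TokenUsage` helpers: reasoning and content are accumulated independently, in delta order, and an all-empty delta list counts as "no content". Below is a standalone sketch of the same contract. It makes three assumptions that differ from the diff: a local `chatDelta` struct replaces the generated `pb.ChatDelta` type, `strings.Builder` replaces `+=`, and nil entries are skipped (the original helpers do not guard against nil).

```go
package main

import (
	"fmt"
	"strings"
)

// chatDelta is a hypothetical stand-in for the generated pb.ChatDelta message.
type chatDelta struct {
	Content          string
	ReasoningContent string
}

// hasContent reports whether any delta carries content or reasoning text.
func hasContent(deltas []*chatDelta) bool {
	for _, d := range deltas {
		if d != nil && (d.Content != "" || d.ReasoningContent != "") {
			return true
		}
	}
	return false
}

// reasoningAndContent concatenates both fields across all deltas, in order.
func reasoningAndContent(deltas []*chatDelta) (reasoning, content string) {
	var r, c strings.Builder
	for _, d := range deltas {
		if d == nil {
			continue // nil guard added for this sketch only
		}
		r.WriteString(d.ReasoningContent)
		c.WriteString(d.Content)
	}
	return r.String(), c.String()
}

func main() {
	deltas := []*chatDelta{
		{Content: "a", ReasoningContent: "r1"},
		{Content: "b", ReasoningContent: "r2"},
	}
	reasoning, content := reasoningAndContent(deltas)
	fmt.Println(hasContent(deltas), reasoning, content)
}
```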
@@ -40,10 +40,17 @@ type BackendsUninstall struct {
|
||||
BackendsCMDFlags `embed:""`
|
||||
}
|
||||
|
||||
type BackendsUpgrade struct {
|
||||
BackendArgs []string `arg:"" optional:"" name:"backends" help:"Backend names to upgrade (empty = upgrade all)"`
|
||||
|
||||
BackendsCMDFlags `embed:""`
|
||||
}
|
||||
|
||||
type BackendsCMD struct {
|
||||
List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"`
|
||||
Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"`
|
||||
Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"`
|
||||
Upgrade BackendsUpgrade `cmd:"" help:"Upgrade backends to latest versions"`
|
||||
}
|
||||
|
||||
func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
||||
@@ -64,11 +71,27 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
||||
 	if err != nil {
 		return err
 	}
+
+	// Check for upgrades
+	upgrades, _ := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)
+
 	for _, backend := range backends {
+		versionStr := ""
+		if backend.Version != "" {
+			versionStr = " v" + backend.Version
+		}
 		if backend.Installed {
-			fmt.Printf(" * %s@%s (installed)\n", backend.Gallery.Name, backend.Name)
+			if info, ok := upgrades[backend.Name]; ok {
+				upgradeStr := info.AvailableVersion
+				if upgradeStr == "" {
+					upgradeStr = "new build"
+				}
+				fmt.Printf(" * %s@%s%s (installed, upgrade available: %s)\n", backend.Gallery.Name, backend.Name, versionStr, upgradeStr)
+			} else {
+				fmt.Printf(" * %s@%s%s (installed)\n", backend.Gallery.Name, backend.Name, versionStr)
+			}
 		} else {
-			fmt.Printf(" - %s@%s\n", backend.Gallery.Name, backend.Name)
+			fmt.Printf(" - %s@%s%s\n", backend.Gallery.Name, backend.Name, versionStr)
 		}
 	}
 	return nil
|
||||
@@ -111,6 +134,79 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bu *BackendsUpgrade) Run(ctx *cliContext.Context) error {
|
||||
var galleries []config.Gallery
|
||||
if err := json.Unmarshal([]byte(bu.BackendGalleries), &galleries); err != nil {
|
||||
xlog.Error("unable to load galleries", "error", err)
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendSystemPath(bu.BackendsSystemPath),
|
||||
system.WithBackendPath(bu.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
upgrades, err := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for upgrades: %w", err)
|
||||
}
|
||||
|
||||
if len(upgrades) == 0 {
|
||||
fmt.Println("All backends are up to date.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter to specified backends if args given
|
||||
toUpgrade := upgrades
|
||||
if len(bu.BackendArgs) > 0 {
|
||||
toUpgrade = make(map[string]gallery.UpgradeInfo)
|
||||
for _, name := range bu.BackendArgs {
|
||||
if info, ok := upgrades[name]; ok {
|
||||
toUpgrade[name] = info
|
||||
} else {
|
||||
fmt.Printf("Backend %s: no upgrade available\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(toUpgrade) == 0 {
|
||||
fmt.Println("No upgrades to apply.")
|
||||
return nil
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
for name, info := range toUpgrade {
|
||||
versionStr := ""
|
||||
if info.AvailableVersion != "" {
|
||||
versionStr = " to v" + info.AvailableVersion
|
||||
}
|
||||
fmt.Printf("Upgrading %s%s...\n", name, versionStr)
|
||||
|
||||
progressBar := progressbar.NewOptions(
|
||||
1000,
|
||||
progressbar.OptionSetDescription(fmt.Sprintf("downloading %s", name)),
|
||||
progressbar.OptionShowBytes(false),
|
||||
progressbar.OptionClearOnFinish(),
|
||||
)
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
v := int(percentage * 10)
|
||||
if err := progressBar.Set(v); err != nil {
|
||||
xlog.Error("error updating progress bar", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback); err != nil {
|
||||
fmt.Printf("Failed to upgrade %s: %v\n", name, err)
|
||||
} else {
|
||||
fmt.Printf("Backend %s upgraded successfully\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
|
||||
for _, backendName := range bu.BackendArgs {
|
||||
xlog.Info("uninstalling backend", "backend", backendName)
|
||||
|
||||
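Note on `BackendsUpgrade.Run`: the command first narrows the detected upgrades to the names given on the command line (no names means "upgrade all"), then maps the download percentage onto a progress bar with 1000 steps. The sketch below isolates those two steps. `upgradeInfo`, `selectUpgrades` and `progressUnits` are hypothetical names for this illustration, the backend names are made-up sample data, and the 0 to 100 range of `percentage` is inferred from the `percentage * 10` scaling rather than stated in the diff.

```go
package main

import (
	"fmt"
	"sort"
)

// upgradeInfo is a hypothetical stand-in for gallery.UpgradeInfo.
type upgradeInfo struct {
	InstalledVersion string
	AvailableVersion string
}

// selectUpgrades narrows the detected upgrades to the requested backend names.
// With no names it returns the full set. Names without a pending upgrade are
// returned separately so the caller can report them.
func selectUpgrades(upgrades map[string]upgradeInfo, names []string) (map[string]upgradeInfo, []string) {
	if len(names) == 0 {
		return upgrades, nil
	}
	selected := make(map[string]upgradeInfo, len(names))
	var skipped []string
	for _, name := range names {
		if info, ok := upgrades[name]; ok {
			selected[name] = info
		} else {
			skipped = append(skipped, name)
		}
	}
	return selected, skipped
}

// progressUnits maps a percentage (assumed 0-100) onto a 0-1000 progress bar.
func progressUnits(percentage float64) int {
	return int(percentage * 10)
}

func main() {
	upgrades := map[string]upgradeInfo{
		"backend-a": {InstalledVersion: "1.0.0", AvailableVersion: "1.1.0"},
		"backend-b": {InstalledVersion: "0.9.0"}, // no version string: "new build"
	}
	selected, skipped := selectUpgrades(upgrades, []string{"backend-a", "backend-c"})

	names := make([]string, 0, len(selected))
	for name := range selected {
		names = append(names, name)
	}
	sort.Strings(names) // map iteration order is random, so sort for stable output
	fmt.Println(names, skipped, progressUnits(42.5))
}
```

Because Go map iteration order is unspecified, the order in which the real command upgrades several backends can differ between runs; sorting the names first, as in the sketch, is one way to make the output reproducible.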
@@ -47,6 +47,7 @@ type RunCMD struct {
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
@@ -62,6 +63,7 @@ type RunCMD struct {
|
||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
||||
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
||||
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
|
||||
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
||||
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||
@@ -295,6 +297,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.DisableWebUI)
|
||||
}
|
||||
|
||||
if r.OllamaAPIRootEndpoint {
|
||||
opts = append(opts, config.EnableOllamaAPIRootEndpoint)
|
||||
}
|
||||
|
||||
if r.DisableGalleryEndpoint {
|
||||
opts = append(opts, config.DisableGalleryEndpoint)
|
||||
}
|
||||
@@ -485,6 +491,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.EnableBackendGalleriesAutoload)
|
||||
}
|
||||
|
||||
if r.AutoUpgradeBackends {
|
||||
opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
|
||||
}
|
||||
|
||||
if r.PreloadBackendOnly {
|
||||
_, err := application.New(opts...)
|
||||
return err
|
||||
|
||||
@@ -512,11 +512,9 @@ func (s *backendSupervisor) stopBackend(backend string) {
|
||||
|
||||
 	// Network I/O outside the lock
 	client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
-	if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
-		xlog.Debug("Calling Free() before stopping backend", "backend", backend)
-		if err := freeFunc.Free(context.Background()); err != nil {
-			xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
-		}
+	xlog.Debug("Calling Free() before stopping backend", "backend", backend)
+	if err := client.Free(context.Background()); err != nil {
+		xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
 	}

 	xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
|
||||
@@ -692,13 +690,13 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
|
||||
 	// backend.delete — stop backend + delete files (request-reply)
 	s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
-		xlog.Info("Received NATS backend.delete event")
 		var req messaging.BackendDeleteRequest
 		if err := json.Unmarshal(data, &req); err != nil {
 			resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
 			replyJSON(reply, resp)
 			return
 		}
+		xlog.Info("Received NATS backend.delete event", "backend", req.Backend)

 		// Stop if running this backend
 		if s.isRunning(req.Backend) {
|
||||
@@ -774,10 +772,8 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
 		if targetAddr != "" {
 			// Best-effort gRPC Free()
 			client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
-			if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
-				if err := freeFunc.Free(context.Background()); err != nil {
-					xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
-				}
+			if err := client.Free(context.Background()); err != nil {
+				xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
 			}
 		}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ type ApplicationConfig struct {
|
||||
Federated bool
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
OpaqueErrors bool
|
||||
UseSubtleKeyComparison bool
|
||||
@@ -56,6 +57,7 @@ type ApplicationConfig struct {
|
||||
ExternalGRPCBackends map[string]string
|
||||
|
||||
AutoloadGalleries, AutoloadBackendGalleries bool
|
||||
AutoUpgradeBackends bool
|
||||
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
@@ -263,6 +265,10 @@ var DisableWebUI = func(o *ApplicationConfig) {
|
||||
o.DisableWebUI = true
|
||||
}
|
||||
|
||||
var EnableOllamaAPIRootEndpoint = func(o *ApplicationConfig) {
|
||||
o.OllamaAPIRootEndpoint = true
|
||||
}
|
||||
|
||||
var DisableRuntimeSettings = func(o *ApplicationConfig) {
|
||||
o.DisableRuntimeSettings = true
|
||||
}
|
||||
@@ -385,6 +391,10 @@ var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) {
|
||||
o.AutoloadBackendGalleries = true
|
||||
}
|
||||
|
||||
func WithAutoUpgradeBackends(v bool) AppOption {
|
||||
return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
|
||||
}
|
||||
|
||||
var EnableFederated = func(o *ApplicationConfig) {
|
||||
o.Federated = true
|
||||
}
|
||||
@@ -857,6 +867,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
backendGalleries := o.BackendGalleries
|
||||
autoloadGalleries := o.AutoloadGalleries
|
||||
autoloadBackendGalleries := o.AutoloadBackendGalleries
|
||||
autoUpgradeBackends := o.AutoUpgradeBackends
|
||||
apiKeys := o.ApiKeys
|
||||
agentJobRetentionDays := o.AgentJobRetentionDays
|
||||
|
||||
@@ -930,6 +941,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
BackendGalleries: &backendGalleries,
|
||||
AutoloadGalleries: &autoloadGalleries,
|
||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||
AutoUpgradeBackends: &autoUpgradeBackends,
|
||||
ApiKeys: &apiKeys,
|
||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
@@ -1078,6 +1090,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.AutoloadBackendGalleries != nil {
|
||||
o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
||||
}
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
o.AutoUpgradeBackends = *settings.AutoUpgradeBackends
|
||||
}
|
||||
if settings.AgentJobRetentionDays != nil {
|
||||
o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
|
||||
}
|
||||
|
||||
@@ -119,6 +119,13 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(*rs.AgentJobRetentionDays).To(Equal(30))
|
||||
})
|
||||
|
||||
It("should include auto_upgrade_backends", func() {
|
||||
appConfig := &ApplicationConfig{AutoUpgradeBackends: true}
|
||||
rs := appConfig.ToRuntimeSettings()
|
||||
Expect(rs.AutoUpgradeBackends).ToNot(BeNil())
|
||||
Expect(*rs.AutoUpgradeBackends).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use default timeouts when not set", func() {
|
||||
appConfig := &ApplicationConfig{}
|
||||
|
||||
@@ -426,6 +433,14 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(appConfig.AutoloadBackendGalleries).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should apply auto_upgrade_backends setting", func() {
|
||||
appConfig := &ApplicationConfig{}
|
||||
v := true
|
||||
rs := &RuntimeSettings{AutoUpgradeBackends: &v}
|
||||
appConfig.ApplyRuntimeSettings(rs)
|
||||
Expect(appConfig.AutoUpgradeBackends).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should apply agent settings", func() {
|
||||
appConfig := &ApplicationConfig{}
|
||||
|
||||
@@ -465,6 +480,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Federated: true,
|
||||
AutoloadGalleries: true,
|
||||
AutoloadBackendGalleries: false,
|
||||
AutoUpgradeBackends: true,
|
||||
AgentJobRetentionDays: 60,
|
||||
}
|
||||
|
||||
@@ -496,6 +512,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(target.Federated).To(Equal(original.Federated))
|
||||
Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries))
|
||||
Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries))
|
||||
Expect(target.AutoUpgradeBackends).To(Equal(original.AutoUpgradeBackends))
|
||||
Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays))
|
||||
})
|
||||
|
||||
|
||||
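The `AutoUpgradeBackends` changes above follow two patterns that recur across the config package: a functional option (`WithAutoUpgradeBackends`) sets the field at construction time, and `RuntimeSettings` carries a pointer per field so that `nil` means "leave unchanged" when settings are applied. The sketch below reproduces both with trimmed-down, hypothetical types (`appConfig`, `runtimeSettings`, `appOption`). It is an illustration of the pattern, not the real `ApplicationConfig` API.

```go
package main

import "fmt"

// appConfig and runtimeSettings are hypothetical, reduced versions of
// ApplicationConfig and RuntimeSettings with only two fields each.
type appConfig struct {
	AutoUpgradeBackends   bool
	AgentJobRetentionDays int
}

type runtimeSettings struct {
	AutoUpgradeBackends   *bool // nil = not provided
	AgentJobRetentionDays *int  // nil = not provided
}

type appOption func(*appConfig)

// withAutoUpgradeBackends mirrors the shape of WithAutoUpgradeBackends.
func withAutoUpgradeBackends(v bool) appOption {
	return func(o *appConfig) { o.AutoUpgradeBackends = v }
}

func newAppConfig(opts ...appOption) *appConfig {
	c := &appConfig{}
	for _, opt := range opts {
		opt(c)
	}
	return c
}

// toRuntimeSettings copies each value into a local before taking its address,
// so the settings do not alias the live config.
func (o *appConfig) toRuntimeSettings() runtimeSettings {
	auto := o.AutoUpgradeBackends
	days := o.AgentJobRetentionDays
	return runtimeSettings{AutoUpgradeBackends: &auto, AgentJobRetentionDays: &days}
}

// applyRuntimeSettings overrides only the fields that were provided.
func (o *appConfig) applyRuntimeSettings(s *runtimeSettings) {
	if s.AutoUpgradeBackends != nil {
		o.AutoUpgradeBackends = *s.AutoUpgradeBackends
	}
	if s.AgentJobRetentionDays != nil {
		o.AgentJobRetentionDays = *s.AgentJobRetentionDays
	}
}

func main() {
	cfg := newAppConfig(withAutoUpgradeBackends(true))
	cfg.AgentJobRetentionDays = 30

	days := 60
	cfg.applyRuntimeSettings(&runtimeSettings{AgentJobRetentionDays: &days})
	fmt.Println(cfg.AutoUpgradeBackends, cfg.AgentJobRetentionDays)
}
```

A partial update that only sets the retention days leaves `AutoUpgradeBackends` at the value chosen on startup, which is the behaviour the pointer fields are there to provide.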
@@ -14,6 +14,7 @@
|
||||
"qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95},
|
||||
"gemma-4": {"min_p":0,"presence_penalty":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
@@ -53,5 +54,5 @@
|
||||
     "grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
     "mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}
   },
-  "patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
+  "patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-4","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
 }
|
||||
|
||||
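The `patterns` list is ordered so that more specific names come before their prefixes (`qwen3.5` before `qwen3`, `gemma-3n` before `gemma-3`, `kimi-k2` before `kimi`), and `gemma-4` is inserted ahead of the other Gemma entries. That ordering would matter under a first-match lookup. How LocalAI actually consumes the list is not shown in this diff, so the sketch below is only an assumption: `matchFamily` is a hypothetical helper that returns the first pattern found as a substring of the lowercased model name, and the model names are invented.

```go
package main

import (
	"fmt"
	"strings"
)

// matchFamily returns the first pattern contained in the lowercased model
// name, or "" when nothing matches. Substring matching is an assumption.
func matchFamily(modelName string, patterns []string) string {
	lower := strings.ToLower(modelName)
	for _, p := range patterns {
		if strings.Contains(lower, p) {
			return p
		}
	}
	return ""
}

func main() {
	// A shortened excerpt of the list, kept in the same relative order.
	patterns := []string{"qwen3.5", "qwen3", "gemma-4", "gemma-3n", "gemma-3", "kimi-k2", "kimi"}

	fmt.Println(matchFamily("Gemma-4-Example-GGUF", patterns))
	fmt.Println(matchFamily("gemma-3n-example", patterns))
}
```

With `gemma-3` listed ahead of `gemma-3n`, the second lookup would resolve to the shorter pattern, which is the kind of mistake the ordering avoids under this assumed matching rule.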
132
core/config/meta/build.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package meta

import (
	"reflect"
	"sort"
	"sync"
)

var (
	cachedMetadata *ConfigMetadata
	cacheMu        sync.RWMutex
)

// BuildConfigMetadata reflects on the given struct type (ModelConfig),
// merges the enrichment registry, and returns the full ConfigMetadata.
// The result is cached in memory after the first call.
func BuildConfigMetadata(modelConfigType reflect.Type) *ConfigMetadata {
	cacheMu.RLock()
	if cachedMetadata != nil {
		cacheMu.RUnlock()
		return cachedMetadata
	}
	cacheMu.RUnlock()

	cacheMu.Lock()
	defer cacheMu.Unlock()

	if cachedMetadata != nil {
		return cachedMetadata
	}

	cachedMetadata = buildConfigMetadataUncached(modelConfigType, DefaultRegistry())
	return cachedMetadata
}

// buildConfigMetadataUncached does the actual work without caching.
func buildConfigMetadataUncached(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
	fields := WalkModelConfig(modelConfigType)

	for i := range fields {
		override, ok := registry[fields[i].Path]
		if !ok {
			continue
		}
		applyOverride(&fields[i], override)
	}

	allSections := DefaultSections()

	sectionOrder := make(map[string]int, len(allSections))
	for _, s := range allSections {
		sectionOrder[s.ID] = s.Order
	}

	sort.SliceStable(fields, func(i, j int) bool {
		si := sectionOrder[fields[i].Section]
		sj := sectionOrder[fields[j].Section]
		if si != sj {
			return si < sj
		}
		return fields[i].Order < fields[j].Order
	})

	usedSections := make(map[string]bool)
	for _, f := range fields {
		usedSections[f.Section] = true
	}

	var sections []Section
	for _, s := range allSections {
		if usedSections[s.ID] {
			sections = append(sections, s)
		}
	}

	return &ConfigMetadata{
		Sections: sections,
		Fields:   fields,
	}
}
|
||||
|
||||
// applyOverride merges non-zero override values into the field. Because zero
// values are skipped, an override cannot reset a field to "", false, or Order 0.
|
||||
func applyOverride(f *FieldMeta, o FieldMetaOverride) {
|
||||
if o.Section != "" {
|
||||
f.Section = o.Section
|
||||
}
|
||||
if o.Label != "" {
|
||||
f.Label = o.Label
|
||||
}
|
||||
if o.Description != "" {
|
||||
f.Description = o.Description
|
||||
}
|
||||
if o.Component != "" {
|
||||
f.Component = o.Component
|
||||
}
|
||||
if o.Placeholder != "" {
|
||||
f.Placeholder = o.Placeholder
|
||||
}
|
||||
if o.Default != nil {
|
||||
f.Default = o.Default
|
||||
}
|
||||
if o.Min != nil {
|
||||
f.Min = o.Min
|
||||
}
|
||||
if o.Max != nil {
|
||||
f.Max = o.Max
|
||||
}
|
||||
if o.Step != nil {
|
||||
f.Step = o.Step
|
||||
}
|
||||
if o.Options != nil {
|
||||
f.Options = o.Options
|
||||
}
|
||||
if o.AutocompleteProvider != "" {
|
||||
f.AutocompleteProvider = o.AutocompleteProvider
|
||||
}
|
||||
if o.VRAMImpact {
|
||||
f.VRAMImpact = true
|
||||
}
|
||||
if o.Advanced {
|
||||
f.Advanced = true
|
||||
}
|
||||
if o.Order != 0 {
|
||||
f.Order = o.Order
|
||||
}
|
||||
}
|
||||
|
||||
// BuildForTest builds metadata without caching, for use in tests.
|
||||
func BuildForTest(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
|
||||
return buildConfigMetadataUncached(modelConfigType, registry)
|
||||
}
|
||||
|
||||
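The merge rule in `applyOverride` (copy a value only when it is non-zero) has one consequence that is easy to miss: an override cannot set `Order` to 0 or clear a boolean flag. The following is a minimal sketch of that behavior, not the package's code. `fieldMeta`, `fieldOverride`, and `merge` are hypothetical, trimmed-down stand-ins for `FieldMeta`, `FieldMetaOverride`, and `applyOverride`.

```go
package main

import "fmt"

// fieldMeta and fieldOverride are hypothetical stand-ins that model only
// three of the fields handled by applyOverride.
type fieldMeta struct {
	Label string
	Min   *float64
	Order int
}

type fieldOverride struct {
	Label string
	Min   *float64
	Order int
}

// merge copies only non-zero override values into f.
func merge(f *fieldMeta, o fieldOverride) {
	if o.Label != "" {
		f.Label = o.Label
	}
	if o.Min != nil { // pointer fields can carry an explicit 0
		f.Min = o.Min
	}
	if o.Order != 0 { // a plain int cannot distinguish "unset" from 0
		f.Order = o.Order
	}
}

func main() {
	zero := 0.0
	f := fieldMeta{Label: "Top K", Order: 7}
	merge(&f, fieldOverride{Min: &zero, Order: 0})
	fmt.Println(f.Label, *f.Min, f.Order) // prints "Top K 0 7"
}
```

The pointer-typed `Min` can be overridden to 0 because the nil check separates "unset" from "zero", while the plain `int` order keeps its previous value of 7.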
211
core/config/meta/build_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
func TestBuildConfigMetadata(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
if len(md.Sections) == 0 {
|
||||
t.Fatal("expected sections, got 0")
|
||||
}
|
||||
if len(md.Fields) == 0 {
|
||||
t.Fatal("expected fields, got 0")
|
||||
}
|
||||
|
||||
// Verify sections are ordered
|
||||
for i := 1; i < len(md.Sections); i++ {
|
||||
if md.Sections[i].Order < md.Sections[i-1].Order {
|
||||
t.Errorf("sections not ordered: %s (order=%d) before %s (order=%d)",
|
||||
md.Sections[i-1].ID, md.Sections[i-1].Order,
|
||||
md.Sections[i].ID, md.Sections[i].Order)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryOverrides(t *testing.T) {
|
||||
registry := map[string]meta.FieldMetaOverride{
|
||||
"name": {
|
||||
Label: "My Custom Label",
|
||||
Description: "Custom description",
|
||||
Component: "textarea",
|
||||
Order: 999,
|
||||
},
|
||||
}
|
||||
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), registry)
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
f, ok := byPath["name"]
|
||||
if !ok {
|
||||
t.Fatal("field 'name' not found")
|
||||
}
|
||||
if f.Label != "My Custom Label" {
|
||||
t.Errorf("expected label 'My Custom Label', got %q", f.Label)
|
||||
}
|
||||
if f.Description != "Custom description" {
|
||||
t.Errorf("expected description 'Custom description', got %q", f.Description)
|
||||
}
|
||||
if f.Component != "textarea" {
|
||||
t.Errorf("expected component 'textarea', got %q", f.Component)
|
||||
}
|
||||
if f.Order != 999 {
|
||||
t.Errorf("expected order 999, got %d", f.Order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregisteredFieldsGetDefaults(t *testing.T) {
|
||||
// Use empty registry - all fields should still get auto-generated metadata
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), map[string]meta.FieldMetaOverride{})
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// context_size should still exist with auto-generated label
|
||||
f, ok := byPath["context_size"]
|
||||
if !ok {
|
||||
t.Fatal("field 'context_size' not found")
|
||||
}
|
||||
if f.Label == "" {
|
||||
t.Error("expected auto-generated label, got empty")
|
||||
}
|
||||
if f.UIType != "int" {
|
||||
t.Errorf("expected UIType 'int', got %q", f.UIType)
|
||||
}
|
||||
if f.Component == "" {
|
||||
t.Error("expected auto-generated component, got empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistryOverridesApply(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Verify enriched fields got their overrides
|
||||
tests := []struct {
|
||||
path string
|
||||
label string
|
||||
description string
|
||||
vramImpact bool
|
||||
}{
|
||||
{"context_size", "Context Size", "Maximum context window in tokens", true},
|
||||
{"gpu_layers", "GPU Layers", "Number of layers to offload to GPU (-1 = all)", true},
|
||||
{"backend", "Backend", "The inference backend to use (e.g. llama-cpp, vllm, diffusers)", false},
|
||||
{"parameters.temperature", "Temperature", "Sampling temperature (higher = more creative, lower = more deterministic)", false},
|
||||
{"template.chat", "Chat Template", "Go template for chat completion requests", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Label != tt.label {
|
||||
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||
}
|
||||
if f.Description != tt.description {
|
||||
t.Errorf("field %q: expected description %q, got %q", tt.path, tt.description, f.Description)
|
||||
}
|
||||
if f.VRAMImpact != tt.vramImpact {
|
||||
t.Errorf("field %q: expected vramImpact=%v, got %v", tt.path, tt.vramImpact, f.VRAMImpact)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticOptionsFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Fields with static options should have Options populated and no AutocompleteProvider
|
||||
staticFields := []string{"quantization", "cache_type_k", "cache_type_v", "diffusers.pipeline_type", "diffusers.scheduler_type"}
|
||||
for _, path := range staticFields {
|
||||
f, ok := byPath[path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", path)
|
||||
continue
|
||||
}
|
||||
if len(f.Options) == 0 {
|
||||
t.Errorf("field %q: expected Options to be populated", path)
|
||||
}
|
||||
if f.AutocompleteProvider != "" {
|
||||
t.Errorf("field %q: expected no AutocompleteProvider, got %q", path, f.AutocompleteProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicProviderFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Fields with dynamic providers should have AutocompleteProvider and no Options
|
||||
dynamicFields := map[string]string{
|
||||
"backend": meta.ProviderBackends,
|
||||
"pipeline.llm": meta.ProviderModelsChat,
|
||||
"pipeline.tts": meta.ProviderModelsTTS,
|
||||
"pipeline.transcription": meta.ProviderModelsTranscript,
|
||||
"pipeline.vad": meta.ProviderModelsVAD,
|
||||
}
|
||||
for path, expectedProvider := range dynamicFields {
|
||||
f, ok := byPath[path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", path)
|
||||
continue
|
||||
}
|
||||
if f.AutocompleteProvider != expectedProvider {
|
||||
t.Errorf("field %q: expected AutocompleteProvider %q, got %q", path, expectedProvider, f.AutocompleteProvider)
|
||||
}
|
||||
if len(f.Options) != 0 {
|
||||
t.Errorf("field %q: expected no Options, got %d", path, len(f.Options))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVRAMImpactFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
var vramFields []string
|
||||
for _, f := range md.Fields {
|
||||
if f.VRAMImpact {
|
||||
vramFields = append(vramFields, f.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vramFields) == 0 {
|
||||
t.Error("expected some VRAM impact fields, got 0")
|
||||
}
|
||||
|
||||
// context_size and gpu_layers should be marked
|
||||
expected := map[string]bool{"context_size": true, "gpu_layers": true}
|
||||
for _, path := range vramFields {
|
||||
if expected[path] {
|
||||
delete(expected, path)
|
||||
}
|
||||
}
|
||||
for path := range expected {
|
||||
t.Errorf("expected VRAM impact field %q not found", path)
|
||||
}
|
||||
}
|
||||
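The ordering that `TestBuildConfigMetadata` relies on comes from a two-key stable sort: section rank first, then per-field order. The sketch below shows that comparison in isolation. The `field` type, the `sortFields` helper, and the rank values are assumptions made for illustration; the real section ranks come from `DefaultSections()`, which is not shown in this diff.

```go
package main

import (
	"fmt"
	"sort"
)

// field is a hypothetical, reduced version of FieldMeta.
type field struct {
	Path    string
	Section string
	Order   int
}

// sortFields orders fields by section rank, then by field order.
// A section missing from rank reads as 0 and therefore sorts first.
func sortFields(fields []field, rank map[string]int) {
	sort.SliceStable(fields, func(i, j int) bool {
		ri, rj := rank[fields[i].Section], rank[fields[j].Section]
		if ri != rj {
			return ri < rj
		}
		return fields[i].Order < fields[j].Order
	})
}

func main() {
	// Assumed rank values, chosen only for this example.
	rank := map[string]int{"general": 0, "llm": 10, "parameters": 20}
	fields := []field{
		{"parameters.temperature", "parameters", 30},
		{"gpu_layers", "llm", 11},
		{"name", "general", 0},
		{"context_size", "llm", 10},
	}
	sortFields(fields, rank)
	for _, f := range fields {
		fmt.Println(f.Path)
	}
}
```

Because the sort is stable, two fields with the same section and the same order keep the relative position they had after the reflection walk.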
79
core/config/meta/constants.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package meta
|
||||
|
||||
// Dynamic autocomplete provider constants (runtime lookup required).
|
||||
const (
|
||||
ProviderBackends = "backends"
|
||||
ProviderModels = "models"
|
||||
ProviderModelsChat = "models:chat"
|
||||
ProviderModelsTTS = "models:tts"
|
||||
ProviderModelsTranscript = "models:transcript"
|
||||
ProviderModelsVAD = "models:vad"
|
||||
)
|
||||
|
||||
// Static option lists embedded directly in field metadata.
|
||||
|
||||
var QuantizationOptions = []FieldOption{
|
||||
{Value: "q4_0", Label: "Q4_0"},
|
||||
{Value: "q4_1", Label: "Q4_1"},
|
||||
{Value: "q5_0", Label: "Q5_0"},
|
||||
{Value: "q5_1", Label: "Q5_1"},
|
||||
{Value: "q8_0", Label: "Q8_0"},
|
||||
{Value: "q2_K", Label: "Q2_K"},
|
||||
{Value: "q3_K_S", Label: "Q3_K_S"},
|
||||
{Value: "q3_K_M", Label: "Q3_K_M"},
|
||||
{Value: "q3_K_L", Label: "Q3_K_L"},
|
||||
{Value: "q4_K_S", Label: "Q4_K_S"},
|
||||
{Value: "q4_K_M", Label: "Q4_K_M"},
|
||||
{Value: "q5_K_S", Label: "Q5_K_S"},
|
||||
{Value: "q5_K_M", Label: "Q5_K_M"},
|
||||
{Value: "q6_K", Label: "Q6_K"},
|
||||
}
|
||||
|
||||
var CacheTypeOptions = []FieldOption{
|
||||
{Value: "f16", Label: "F16"},
|
||||
{Value: "f32", Label: "F32"},
|
||||
{Value: "q8_0", Label: "Q8_0"},
|
||||
{Value: "q4_0", Label: "Q4_0"},
|
||||
{Value: "q4_1", Label: "Q4_1"},
|
||||
{Value: "q5_0", Label: "Q5_0"},
|
||||
{Value: "q5_1", Label: "Q5_1"},
|
||||
}
|
||||
|
||||
var DiffusersPipelineOptions = []FieldOption{
|
||||
{Value: "StableDiffusionPipeline", Label: "StableDiffusionPipeline"},
|
||||
{Value: "StableDiffusionImg2ImgPipeline", Label: "StableDiffusionImg2ImgPipeline"},
|
||||
{Value: "StableDiffusionXLPipeline", Label: "StableDiffusionXLPipeline"},
|
||||
{Value: "StableDiffusionXLImg2ImgPipeline", Label: "StableDiffusionXLImg2ImgPipeline"},
|
||||
{Value: "StableDiffusionDepth2ImgPipeline", Label: "StableDiffusionDepth2ImgPipeline"},
|
||||
{Value: "DiffusionPipeline", Label: "DiffusionPipeline"},
|
||||
{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
|
||||
}
|
||||
|
||||
var UsecaseOptions = []FieldOption{
|
||||
{Value: "chat", Label: "Chat"},
|
||||
{Value: "completion", Label: "Completion"},
|
||||
{Value: "edit", Label: "Edit"},
|
||||
{Value: "embeddings", Label: "Embeddings"},
|
||||
{Value: "rerank", Label: "Rerank"},
|
||||
{Value: "image", Label: "Image"},
|
||||
{Value: "transcript", Label: "Transcript"},
|
||||
{Value: "tts", Label: "TTS"},
|
||||
{Value: "sound_generation", Label: "Sound Generation"},
|
||||
{Value: "tokenize", Label: "Tokenize"},
|
||||
{Value: "vad", Label: "VAD"},
|
||||
{Value: "video", Label: "Video"},
|
||||
{Value: "detection", Label: "Detection"},
|
||||
}
|
||||
|
||||
var DiffusersSchedulerOptions = []FieldOption{
|
||||
{Value: "ddim", Label: "DDIM"},
|
||||
{Value: "ddpm", Label: "DDPM"},
|
||||
{Value: "pndm", Label: "PNDM"},
|
||||
{Value: "lms", Label: "LMS"},
|
||||
{Value: "euler", Label: "Euler"},
|
||||
{Value: "euler_a", Label: "Euler A"},
|
||||
{Value: "dpm_multistep", Label: "DPM Multistep"},
|
||||
{Value: "dpm_singlestep", Label: "DPM Singlestep"},
|
||||
{Value: "heun", Label: "Heun"},
|
||||
{Value: "unipc", Label: "UniPC"},
|
||||
}
|
||||
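The constants above split field choices into two kinds: static lists embedded in the metadata and provider names that need a runtime lookup. A consumer could decide between the two roughly as sketched below. This is an illustration under stated assumptions, not code from the repository: `option`, `fieldMeta`, and `optionSource` are hypothetical names, and treating "both set" as a conflict is this sketch's own choice, suggested by the tests that expect static fields to have no provider and dynamic fields to have no options.

```go
package main

import "fmt"

// option and fieldMeta are hypothetical, reduced versions of FieldOption
// and FieldMeta.
type option struct{ Value, Label string }

type fieldMeta struct {
	Path                 string
	Options              []option
	AutocompleteProvider string
}

// optionSource reports where a UI would obtain the choices for a field.
func optionSource(f fieldMeta) string {
	switch {
	case len(f.Options) > 0 && f.AutocompleteProvider != "":
		return "conflict" // assumption: the two sources are mutually exclusive
	case len(f.Options) > 0:
		return "static"
	case f.AutocompleteProvider != "":
		return "dynamic:" + f.AutocompleteProvider
	default:
		return "free-form"
	}
}

func main() {
	fields := []fieldMeta{
		{Path: "cache_type_k", Options: []option{{"f16", "F16"}, {"q8_0", "Q8_0"}}},
		{Path: "backend", AutocompleteProvider: "backends"},
		{Path: "name"},
	}
	for _, f := range fields {
		fmt.Printf("%s -> %s\n", f.Path, optionSource(f))
	}
}
```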
241
core/config/meta/reflect.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// WalkModelConfig uses reflection to discover all exported, YAML-tagged fields
|
||||
// in the given struct type (expected to be config.ModelConfig) and returns a
|
||||
// slice of FieldMeta with sensible defaults derived from the type information.
|
||||
func WalkModelConfig(t reflect.Type) []FieldMeta {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
var fields []FieldMeta
|
||||
walkStruct(t, "", &fields)
|
||||
return fields
|
||||
}
|
||||
|
||||
// walkStruct recursively walks a struct type, collecting FieldMeta entries.
|
||||
// prefix is the dot-path prefix for nested structs (e.g. "function.grammar.").
|
||||
func walkStruct(t reflect.Type, prefix string, out *[]FieldMeta) {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
for sf := range t.Fields() {
|
||||
if !sf.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
yamlTag := sf.Tag.Get("yaml")
|
||||
if yamlTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
yamlKey, opts := parseTag(yamlTag)
|
||||
|
||||
// Handle inline embedding (e.g. LLMConfig `yaml:",inline"`)
|
||||
if opts.contains("inline") {
|
||||
ft := sf.Type
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
if ft.Kind() == reflect.Struct {
|
||||
walkStruct(ft, prefix, out)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// No yaml key: treat an anonymous embedded struct as inline; skip any other untagged field
|
||||
if yamlKey == "" {
|
||||
ft := sf.Type
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
// Anonymous struct without yaml tag - treat as inline
|
||||
if sf.Anonymous && ft.Kind() == reflect.Struct {
|
||||
walkStruct(ft, prefix, out)
|
||||
continue
|
||||
}
|
||||
// Named field without yaml tag - skip
|
||||
continue
|
||||
}
|
||||
|
||||
ft := sf.Type
|
||||
isPtr := ft.Kind() == reflect.Pointer
|
||||
if isPtr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
// Named nested struct (not a special type) -> recurse with prefix
|
||||
if ft.Kind() == reflect.Struct && !isSpecialType(ft) {
|
||||
nestedPrefix := prefix + yamlKey + "."
|
||||
walkStruct(ft, nestedPrefix, out)
|
||||
continue
|
||||
}
|
||||
|
||||
// Leaf field
|
||||
path := prefix + yamlKey
|
||||
goType := sf.Type.String()
|
||||
uiType, component := inferUIType(sf.Type)
|
||||
section := inferSection(prefix)
|
||||
label := labelFromKey(yamlKey)
|
||||
|
||||
*out = append(*out, FieldMeta{
|
||||
Path: path,
|
||||
YAMLKey: yamlKey,
|
||||
GoType: goType,
|
||||
UIType: uiType,
|
||||
Pointer: isPtr,
|
||||
Section: section,
|
||||
Label: label,
|
||||
Component: component,
|
||||
Order: len(*out),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// isSpecialType returns true for struct types that should be treated as leaf
|
||||
// values rather than recursed into (e.g. custom JSON marshalers).
|
||||
func isSpecialType(t reflect.Type) bool {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
name := t.Name()
|
||||
// LogprobsValue, URI types are leaf values despite being structs
|
||||
switch name {
|
||||
case "LogprobsValue", "URI":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// inferUIType maps a Go reflect.Type to a UI type string and default component.
|
||||
func inferUIType(t reflect.Type) (uiType, component string) {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool", "toggle"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return "int", "number"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "int", "number"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "float", "number"
|
||||
case reflect.String:
|
||||
return "string", "input"
|
||||
case reflect.Slice:
|
||||
elem := t.Elem()
|
||||
if elem.Kind() == reflect.String {
|
||||
return "[]string", "string-list"
|
||||
}
|
||||
if elem.Kind() == reflect.Pointer {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if elem.Kind() == reflect.Struct {
|
||||
return "[]object", "json-editor"
|
||||
}
|
||||
return "[]any", "json-editor"
|
||||
case reflect.Map:
|
||||
return "map", "map-editor"
|
||||
case reflect.Struct:
|
||||
// Special types treated as leaves
|
||||
if isSpecialType(t) {
|
||||
return "bool", "toggle" // LogprobsValue
|
||||
}
|
||||
return "object", "json-editor"
|
||||
default:
|
||||
return "any", "input"
|
||||
}
|
||||
}
|
||||
|
||||
// inferSection determines the config section from the dot-path prefix.
|
||||
func inferSection(prefix string) string {
|
||||
if prefix == "" {
|
||||
return "general"
|
||||
}
|
||||
// Remove trailing dot
|
||||
p := strings.TrimSuffix(prefix, ".")
|
||||
|
||||
// Use the top-level prefix to determine section
|
||||
parts := strings.SplitN(p, ".", 2)
|
||||
top := parts[0]
|
||||
|
||||
switch top {
|
||||
case "parameters":
|
||||
return "parameters"
|
||||
case "template":
|
||||
return "templates"
|
||||
case "function":
|
||||
return "functions"
|
||||
case "reasoning":
|
||||
return "reasoning"
|
||||
case "diffusers":
|
||||
return "diffusers"
|
||||
case "tts":
|
||||
return "tts"
|
||||
case "pipeline":
|
||||
return "pipeline"
|
||||
case "grpc":
|
||||
return "grpc"
|
||||
case "agent":
|
||||
return "agent"
|
||||
case "mcp":
|
||||
return "mcp"
|
||||
case "feature_flags":
|
||||
return "other"
|
||||
case "limit_mm_per_prompt":
|
||||
return "llm"
|
||||
default:
|
||||
return "other"
|
||||
}
|
||||
}
|
||||
|
||||
// labelFromKey converts a yaml key like "context_size" to "Context Size".
|
||||
func labelFromKey(key string) string {
|
||||
parts := strings.Split(key, "_")
|
||||
for i, p := range parts {
|
||||
if len(p) > 0 {
|
||||
runes := []rune(p)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
parts[i] = string(runes)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// tagOptions is a set of comma-separated yaml tag options.
|
||||
type tagOptions string
|
||||
|
||||
func (o tagOptions) contains(optName string) bool {
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var name string
|
||||
if name, s, _ = strings.Cut(s, ","); name == optName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseTag splits a yaml struct tag into the key name and options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if tag == "" {
|
||||
return "", ""
|
||||
}
|
||||
before, after, found := strings.Cut(tag, ",")
|
||||
if found {
|
||||
return before, tagOptions(after)
|
||||
}
|
||||
return tag, ""
|
||||
}
|
||||
|
||||
208
core/config/meta/reflect_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
func TestWalkModelConfig(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
if len(fields) == 0 {
|
||||
t.Fatal("expected fields from ModelConfig, got 0")
|
||||
}
|
||||
|
||||
// Build a lookup by path
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Verify some top-level fields exist
|
||||
for _, path := range []string{"name", "backend", "cuda", "step"} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify inline LLMConfig fields appear at top level (no prefix)
|
||||
for _, path := range []string{"context_size", "gpu_layers", "threads", "mmap"} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected inline LLMConfig field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify nested struct fields have correct prefix
|
||||
for _, path := range []string{
|
||||
"template.chat",
|
||||
"template.completion",
|
||||
"template.use_tokenizer_template",
|
||||
"function.grammar.parallel_calls",
|
||||
"function.grammar.mixed_mode",
|
||||
"diffusers.pipeline_type",
|
||||
"diffusers.cuda",
|
||||
"pipeline.llm",
|
||||
"pipeline.tts",
|
||||
"reasoning.disable",
|
||||
"agent.max_iterations",
|
||||
"grpc.attempts",
|
||||
} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected nested field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify PredictionOptions fields have parameters. prefix
|
||||
for _, path := range []string{
|
||||
"parameters.temperature",
|
||||
"parameters.top_p",
|
||||
"parameters.top_k",
|
||||
"parameters.max_tokens",
|
||||
"parameters.seed",
|
||||
} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected parameters field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify TTSConfig fields have tts. prefix
|
||||
if _, ok := byPath["tts.voice"]; !ok {
|
||||
t.Error("expected tts.voice field not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipsYAMLDashFields(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// modelConfigFile has yaml:"-" tag, should be skipped
|
||||
for _, f := range fields {
|
||||
if f.Path == "modelConfigFile" || f.Path == "modelTemplate" {
|
||||
t.Errorf("field %q should have been skipped (yaml:\"-\")", f.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapping(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
uiType string
|
||||
pointer bool
|
||||
}{
|
||||
{"name", "string", false},
|
||||
{"cuda", "bool", false},
|
||||
{"context_size", "int", true},
|
||||
{"gpu_layers", "int", true},
|
||||
{"threads", "int", true},
|
||||
{"f16", "bool", true},
|
||||
{"mmap", "bool", true},
|
||||
{"stopwords", "[]string", false},
|
||||
{"roles", "map", false},
|
||||
{"parameters.temperature", "float", true},
|
||||
{"parameters.top_k", "int", true},
|
||||
{"function.grammar.parallel_calls", "bool", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.UIType != tt.uiType {
|
||||
t.Errorf("field %q: expected UIType %q, got %q", tt.path, tt.uiType, f.UIType)
|
||||
}
|
||||
if f.Pointer != tt.pointer {
|
||||
t.Errorf("field %q: expected Pointer=%v, got %v", tt.path, tt.pointer, f.Pointer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSectionAssignment(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
section string
|
||||
}{
|
||||
{"name", "general"},
|
||||
{"backend", "general"},
|
||||
{"context_size", "general"}, // inline LLMConfig -> no prefix -> general
|
||||
{"parameters.temperature", "parameters"},
|
||||
{"template.chat", "templates"},
|
||||
{"function.grammar.parallel_calls", "functions"},
|
||||
{"diffusers.cuda", "diffusers"},
|
||||
{"pipeline.llm", "pipeline"},
|
||||
{"reasoning.disable", "reasoning"},
|
||||
{"agent.max_iterations", "agent"},
|
||||
{"grpc.attempts", "grpc"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Section != tt.section {
|
||||
t.Errorf("field %q: expected section %q, got %q", tt.path, tt.section, f.Section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLabelGeneration(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
label string
|
||||
}{
|
||||
{"context_size", "Context Size"},
|
||||
{"gpu_layers", "Gpu Layers"},
|
||||
{"name", "Name"},
|
||||
{"cuda", "Cuda"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Label != tt.label {
|
||||
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldCount(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
// The config is large; require at least 80 discovered fields as a conservative floor
|
||||
if len(fields) < 80 {
|
||||
t.Errorf("expected at least 80 fields, got %d", len(fields))
|
||||
}
|
||||
t.Logf("Total fields discovered: %d", len(fields))
|
||||
}
|
||||
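`TestLabelGeneration` expects `gpu_layers` to become "Gpu Layers", not "GPU Layers": the auto-generated label only capitalizes the first rune of each underscore-separated word, so acronyms need a registry override. The standalone sketch below repeats the `labelFromKey` logic from `reflect.go` so it can be run on its own; the keys in `main` are sample inputs.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// labelFromKey converts a yaml key like "context_size" to "Context Size".
func labelFromKey(key string) string {
	parts := strings.Split(key, "_")
	for i, p := range parts {
		if len(p) > 0 {
			runes := []rune(p)
			runes[0] = unicode.ToUpper(runes[0])
			parts[i] = string(runes)
		}
	}
	return strings.Join(parts, " ")
}

func main() {
	for _, k := range []string{"context_size", "gpu_layers", "use_tokenizer_template"} {
		fmt.Println(labelFromKey(k))
	}
}
```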
324
core/config/meta/registry.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package meta
|
||||
|
||||
// DefaultRegistry returns enrichment overrides for the ~30 most commonly used
|
||||
// config fields. Fields not listed here still appear with auto-generated
|
||||
// labels and type-inferred components.
|
||||
func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
f64 := func(v float64) *float64 { return &v }
|
||||
|
||||
return map[string]FieldMetaOverride{
|
||||
// --- General ---
|
||||
"name": {
|
||||
Section: "general",
|
||||
Label: "Model Name",
|
||||
Description: "Unique identifier for this model configuration",
|
||||
Component: "input",
|
||||
Order: 0,
|
||||
},
|
||||
"backend": {
|
||||
Section: "general",
|
||||
Label: "Backend",
|
||||
Description: "The inference backend to use (e.g. llama-cpp, vllm, diffusers)",
|
||||
Component: "select",
|
||||
AutocompleteProvider: ProviderBackends,
|
||||
Order: 1,
|
||||
},
|
||||
"description": {
|
||||
Section: "general",
|
||||
Label: "Description",
|
||||
Description: "Human-readable description of what this model does",
|
||||
Component: "textarea",
|
||||
Order: 2,
|
||||
},
|
||||
"usage": {
|
||||
Section: "general",
|
||||
Label: "Usage",
|
||||
Description: "Usage instructions or notes",
|
||||
Component: "textarea",
|
||||
Advanced: true,
|
||||
Order: 3,
|
||||
},
|
||||
"cuda": {
|
||||
Section: "general",
|
||||
Label: "CUDA",
|
||||
Description: "Explicitly enable CUDA acceleration",
|
||||
Order: 5,
|
||||
},
|
||||
"known_usecases": {
|
||||
Section: "general",
|
||||
Label: "Known Use Cases",
|
||||
Description: "Capabilities this model supports",
|
||||
Component: "string-list",
|
||||
Options: UsecaseOptions,
|
||||
Order: 6,
|
||||
},
|
||||
|
||||
// --- LLM ---
|
||||
"context_size": {
|
||||
Section: "llm",
|
||||
Label: "Context Size",
|
||||
Description: "Maximum context window in tokens",
|
||||
Component: "number",
|
||||
VRAMImpact: true,
|
||||
Order: 10,
|
||||
},
|
||||
"gpu_layers": {
|
||||
Section: "llm",
|
||||
Label: "GPU Layers",
|
||||
Description: "Number of layers to offload to GPU (-1 = all)",
|
||||
Component: "number",
|
||||
Min: f64(-1),
|
||||
VRAMImpact: true,
|
||||
Order: 11,
|
||||
},
|
||||
"threads": {
|
||||
Section: "llm",
|
||||
Label: "Threads",
|
||||
Description: "Number of CPU threads for inference",
|
||||
Component: "number",
|
||||
Min: f64(1),
|
||||
Order: 12,
|
||||
},
|
||||
"f16": {
|
||||
Section: "llm",
|
||||
Label: "F16",
|
||||
Description: "Use 16-bit floating point for key/value cache",
|
||||
Order: 13,
|
||||
},
|
||||
"mmap": {
|
||||
Section: "llm",
|
||||
Label: "Memory Map",
|
||||
Description: "Use memory-mapped files for model loading",
|
||||
Order: 14,
|
||||
},
|
||||
"mmlock": {
|
||||
Section: "llm",
|
||||
Label: "Memory Lock",
|
||||
Description: "Lock model memory to prevent swapping",
|
||||
Advanced: true,
|
||||
Order: 15,
|
||||
},
|
||||
"low_vram": {
|
||||
Section: "llm",
|
||||
Label: "Low VRAM",
|
||||
Description: "Optimize for systems with limited GPU memory",
|
||||
VRAMImpact: true,
|
||||
Order: 16,
|
||||
},
|
||||
"embeddings": {
|
||||
Section: "llm",
|
||||
Label: "Embeddings",
|
||||
Description: "Enable embedding generation mode",
|
||||
Order: 17,
|
||||
},
|
||||
"quantization": {
|
||||
Section: "llm",
|
||||
Label: "Quantization",
|
||||
Description: "Quantization method (e.g. q4_0, q5_1, q8_0)",
|
||||
Component: "select",
|
||||
Options: QuantizationOptions,
|
||||
Advanced: true,
|
||||
Order: 20,
|
||||
},
|
||||
"flash_attention": {
|
||||
Section: "llm",
|
||||
Label: "Flash Attention",
|
||||
Description: "Enable flash attention for faster inference",
|
||||
Component: "input",
|
||||
Advanced: true,
|
||||
Order: 21,
|
||||
},
|
||||
"cache_type_k": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (K)",
|
||||
Description: "Quantization type for key cache (e.g. f16, q8_0, q4_0)",
|
||||
Component: "select",
|
||||
Options: CacheTypeOptions,
|
||||
VRAMImpact: true,
|
||||
Advanced: true,
|
||||
Order: 22,
|
||||
},
|
||||
"cache_type_v": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (V)",
|
||||
Description: "Quantization type for value cache",
|
||||
Component: "select",
|
||||
Options: CacheTypeOptions,
|
||||
VRAMImpact: true,
|
||||
Advanced: true,
|
||||
Order: 23,
|
||||
},
|
||||
|
||||
// --- Parameters ---
|
||||
"parameters.temperature": {
|
||||
Section: "parameters",
|
||||
Label: "Temperature",
|
||||
Description: "Sampling temperature (higher = more creative, lower = more deterministic)",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(2),
|
||||
Step: f64(0.05),
|
||||
Order: 30,
|
||||
},
|
||||
"parameters.top_p": {
|
||||
Section: "parameters",
|
||||
Label: "Top P",
|
||||
Description: "Nucleus sampling threshold",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.01),
|
||||
Order: 31,
|
||||
},
|
||||
"parameters.top_k": {
|
||||
Section: "parameters",
|
||||
Label: "Top K",
|
||||
Description: "Top-K sampling: consider only the K most likely tokens",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 32,
|
||||
},
|
||||
"parameters.max_tokens": {
|
||||
Section: "parameters",
|
||||
Label: "Max Tokens",
|
||||
Description: "Maximum number of tokens to generate (0 = unlimited)",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 33,
|
||||
},
|
||||
"parameters.repeat_penalty": {
|
||||
Section: "parameters",
|
||||
Label: "Repeat Penalty",
|
||||
Description: "Penalize repeated tokens (1.0 = no penalty)",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Advanced: true,
|
||||
Order: 34,
|
||||
},
|
||||
"parameters.seed": {
|
||||
Section: "parameters",
|
||||
Label: "Seed",
|
||||
Description: "Random seed (-1 = random)",
|
||||
Component: "number",
|
||||
Advanced: true,
|
||||
Order: 35,
|
||||
},
|
||||
|
||||
// --- Templates ---
|
||||
"template.chat": {
|
||||
Section: "templates",
|
||||
Label: "Chat Template",
|
||||
Description: "Go template for chat completion requests",
|
||||
Component: "code-editor",
|
||||
Order: 40,
|
||||
},
|
||||
"template.chat_message": {
|
||||
Section: "templates",
|
||||
Label: "Chat Message Template",
|
||||
Description: "Go template for individual chat messages",
|
||||
Component: "code-editor",
|
||||
Order: 41,
|
||||
},
|
||||
"template.completion": {
|
||||
Section: "templates",
|
||||
Label: "Completion Template",
|
||||
Description: "Go template for completion requests",
|
||||
Component: "code-editor",
|
||||
Order: 42,
|
||||
},
|
||||
"template.use_tokenizer_template": {
|
||||
Section: "templates",
|
||||
Label: "Use Tokenizer Template",
|
||||
Description: "Use the chat template from the model's tokenizer config",
|
||||
Order: 43,
|
||||
},
|
||||
|
||||
// --- Pipeline ---
|
||||
"pipeline.llm": {
|
||||
Section: "pipeline",
|
||||
Label: "LLM Model",
|
||||
Description: "Model to use for LLM inference in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 60,
|
||||
},
|
||||
"pipeline.tts": {
|
||||
Section: "pipeline",
|
||||
Label: "TTS Model",
|
||||
Description: "Model to use for text-to-speech in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsTTS,
|
||||
Order: 61,
|
||||
},
|
||||
"pipeline.transcription": {
|
||||
Section: "pipeline",
|
||||
Label: "Transcription Model",
|
||||
Description: "Model to use for speech-to-text in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsTranscript,
|
||||
Order: 62,
|
||||
},
|
||||
"pipeline.vad": {
|
||||
Section: "pipeline",
|
||||
Label: "VAD Model",
|
||||
Description: "Model to use for voice activity detection in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsVAD,
|
||||
Order: 63,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
Section: "functions",
|
||||
Label: "Parallel Calls",
|
||||
Description: "Allow the LLM to return multiple function calls in one response",
|
||||
Order: 70,
|
||||
},
|
||||
"function.grammar.mixed_mode": {
|
||||
Section: "functions",
|
||||
Label: "Mixed Mode",
|
||||
Description: "Allow the LLM to return both text and function calls",
|
||||
Order: 71,
|
||||
},
|
||||
"function.grammar.disable": {
|
||||
Section: "functions",
|
||||
Label: "Disable Grammar",
|
||||
Description: "Disable grammar-constrained generation for function calls",
|
||||
Advanced: true,
|
||||
Order: 72,
|
||||
},
|
||||
|
||||
// --- TTS ---
|
||||
"tts.voice": {
|
||||
Section: "tts",
|
||||
Label: "Voice",
|
||||
Description: "Default voice for TTS output",
|
||||
Component: "input",
|
||||
Order: 90,
|
||||
},
|
||||
|
||||
// --- Diffusers ---
|
||||
"diffusers.pipeline_type": {
|
||||
Section: "diffusers",
|
||||
Label: "Pipeline Type",
|
||||
Description: "Diffusers pipeline type (e.g. StableDiffusionPipeline)",
|
||||
Component: "select",
|
||||
Options: DiffusersPipelineOptions,
|
||||
Order: 80,
|
||||
},
|
||||
"diffusers.scheduler_type": {
|
||||
Section: "diffusers",
|
||||
Label: "Scheduler Type",
|
||||
Description: "Noise scheduler type",
|
||||
Component: "select",
|
||||
Options: DiffusersSchedulerOptions,
|
||||
Order: 81,
|
||||
},
|
||||
"diffusers.cuda": {
|
||||
Section: "diffusers",
|
||||
Label: "CUDA",
|
||||
Description: "Enable CUDA for diffusers",
|
||||
Order: 82,
|
||||
},
|
||||
}
|
||||
}
|
||||
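The registry entries above pass values such as `f64(0)` and `f64(2)` to `Min`, `Max`, and `Step`, which are pointer-typed (`*float64`) so that "unset" can be told apart from zero. The helper's definition is not part of this excerpt. A minimal sketch, assuming it only returns a pointer to its argument (the `slider` type is a hypothetical stand-in for the registry entry):

```go
package main

import "fmt"

// f64 returns a pointer to v. Assumption: the real helper does the same;
// its definition is not shown in this diff.
func f64(v float64) *float64 { return &v }

// slider is a hypothetical stand-in for the Min/Max/Step part of an entry.
type slider struct {
	Min, Max, Step *float64
}

// describe renders the bounds, printing "unset" for nil pointers.
func describe(s slider) string {
	show := func(p *float64) string {
		if p == nil {
			return "unset"
		}
		return fmt.Sprint(*p)
	}
	return fmt.Sprintf("min=%s max=%s step=%s", show(s.Min), show(s.Max), show(s.Step))
}

func main() {
	// A lower bound of zero is distinguishable from "no bound at all".
	fmt.Println(describe(slider{Min: f64(0), Max: f64(2)}))
	fmt.Println(describe(slider{}))
}
```

With plain `float64` fields, an entry such as Top K (`Min: f64(0)` and no `Max`) could not express "lower bound 0, no upper bound".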
83
core/config/meta/types.go
Normal file
@@ -0,0 +1,83 @@
package meta

// FieldMeta describes a single configuration field for UI rendering and agent discovery.
type FieldMeta struct {
	Path        string        `json:"path"`                  // dot-path: "context_size", "function.grammar.parallel_calls"
	YAMLKey     string        `json:"yaml_key"`              // leaf yaml key
	GoType      string        `json:"go_type"`               // "*int", "string", "[]string"
	UIType      string        `json:"ui_type"`               // "string", "int", "float", "bool", "[]string", "map", "object"
	Pointer     bool          `json:"pointer,omitempty"`     // true = nil means "not set"
	Section     string        `json:"section"`               // "general", "llm", "templates", etc.
	Label       string        `json:"label"`                 // human-readable label
	Description string        `json:"description,omitempty"` // help text
	Component   string        `json:"component"`             // "input", "number", "toggle", "select", "slider", etc.
	Placeholder string        `json:"placeholder,omitempty"`
	Default     any           `json:"default,omitempty"`
	Min         *float64      `json:"min,omitempty"`
	Max         *float64      `json:"max,omitempty"`
	Step        *float64      `json:"step,omitempty"`
	Options     []FieldOption `json:"options,omitempty"`

	AutocompleteProvider string `json:"autocomplete_provider,omitempty"` // "backends", "models:chat", etc.
	VRAMImpact           bool   `json:"vram_impact,omitempty"`
	Advanced             bool   `json:"advanced,omitempty"`
	Order                int    `json:"order"`
}

// FieldOption represents a choice in a select/enum field.
type FieldOption struct {
	Value string `json:"value"`
	Label string `json:"label"`
}

// Section groups related fields in the UI.
type Section struct {
	ID    string `json:"id"`
	Label string `json:"label"`
	Icon  string `json:"icon,omitempty"`
	Order int    `json:"order"`
}

// ConfigMetadata is the top-level response for the metadata API.
type ConfigMetadata struct {
	Sections []Section   `json:"sections"`
	Fields   []FieldMeta `json:"fields"`
}

// FieldMetaOverride holds registry overrides that are merged on top of
// the reflection-discovered defaults. Only non-zero fields override.
type FieldMetaOverride struct {
	Section              string
	Label                string
	Description          string
	Component            string
	Placeholder          string
	Default              any
	Min                  *float64
	Max                  *float64
	Step                 *float64
	Options              []FieldOption
	AutocompleteProvider string
	VRAMImpact           bool
	Advanced             bool
	Order                int
}

// DefaultSections defines the well-known config sections in display order.
func DefaultSections() []Section {
	return []Section{
		{ID: "general", Label: "General", Icon: "settings", Order: 0},
		{ID: "llm", Label: "LLM", Icon: "cpu", Order: 10},
		{ID: "parameters", Label: "Parameters", Icon: "sliders", Order: 20},
		{ID: "templates", Label: "Templates", Icon: "file-text", Order: 30},
		{ID: "functions", Label: "Functions / Tools", Icon: "tool", Order: 40},
		{ID: "reasoning", Label: "Reasoning", Icon: "brain", Order: 45},
		{ID: "diffusers", Label: "Diffusers", Icon: "image", Order: 50},
		{ID: "tts", Label: "TTS", Icon: "volume-2", Order: 55},
		{ID: "pipeline", Label: "Pipeline", Icon: "git-merge", Order: 60},
		{ID: "grpc", Label: "gRPC", Icon: "server", Order: 65},
		{ID: "agent", Label: "Agent", Icon: "bot", Order: 70},
		{ID: "mcp", Label: "MCP", Icon: "plug", Order: 75},
		{ID: "other", Label: "Other", Icon: "more-horizontal", Order: 100},
	}
}
|
||||
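The `FieldMetaOverride` comment above says overrides are merged onto the reflection-discovered defaults and that only non-zero fields override. The merge function itself is not in this excerpt, so the following is only a sketch of what that rule implies. It uses trimmed-down copies of the two structs, and `applyOverride` is a hypothetical name:

```go
package main

import "fmt"

// Trimmed-down copies of the types above, for illustration only.
type FieldMeta struct {
	Label     string
	Component string
	Min       *float64
	Advanced  bool
	Order     int
}

type FieldMetaOverride struct {
	Label     string
	Component string
	Min       *float64
	Advanced  bool
	Order     int
}

// applyOverride is a hypothetical helper: it copies each non-zero
// override field onto the default and leaves the rest untouched.
func applyOverride(base FieldMeta, o FieldMetaOverride) FieldMeta {
	if o.Label != "" {
		base.Label = o.Label
	}
	if o.Component != "" {
		base.Component = o.Component
	}
	if o.Min != nil {
		base.Min = o.Min
	}
	if o.Advanced {
		// A false override is the zero value, so it cannot clear a true default.
		base.Advanced = true
	}
	if o.Order != 0 {
		base.Order = o.Order
	}
	return base
}

func main() {
	zero := 0.0
	discovered := FieldMeta{Label: "top_k", Component: "input", Order: 999}
	merged := applyOverride(discovered, FieldMetaOverride{
		Label:     "Top K",
		Component: "number",
		Min:       &zero,
		Order:     32,
	})
	fmt.Println(merged.Label, merged.Component, *merged.Min, merged.Order)
}
```

One consequence of a non-zero rule, if the real merge follows it literally: an override cannot reset a field to its zero value (`Order: 0`, `Advanced: false`), which is likely why the numeric bounds are pointers.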
@@ -77,6 +77,8 @@ type ModelConfig struct {
|
||||
|
||||
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
||||
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
|
||||
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
|
||||
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
|
||||
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
@@ -548,6 +550,16 @@ func (c *ModelConfig) GetModelTemplate() string {
|
||||
return c.modelTemplate
|
||||
}
|
||||
|
||||
// IsDisabled returns true if the model is disabled
|
||||
func (c *ModelConfig) IsDisabled() bool {
|
||||
return c.Disabled != nil && *c.Disabled
|
||||
}
|
||||
|
||||
// IsPinned returns true if the model is pinned (excluded from idle unloading and eviction)
|
||||
func (c *ModelConfig) IsPinned() bool {
|
||||
return c.Pinned != nil && *c.Pinned
|
||||
}
|
||||
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
@@ -705,7 +717,8 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
|
||||
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
||||
if c.Backend != "rfdetr" {
|
||||
detectionBackends := []string{"rfdetr", "sam3-cpp"}
|
||||
if !slices.Contains(detectionBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ type RuntimeSettings struct {
|
||||
// Backend management
|
||||
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
AutoUpgradeBackends *bool `json:"auto_upgrade_backends,omitempty"` // Automatically upgrade backends when new versions are detected
|
||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
|
||||
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
||||
|
||||
@@ -20,12 +20,19 @@ type BackendMetadata struct {
|
||||
GalleryURL string `json:"gallery_url,omitempty"`
|
||||
// InstalledAt is the timestamp when the backend was installed
|
||||
InstalledAt string `json:"installed_at,omitempty"`
|
||||
// Version is the version of the backend at install time
|
||||
Version string `json:"version,omitempty"`
|
||||
// URI is the original URI used to install the backend
|
||||
URI string `json:"uri,omitempty"`
|
||||
// Digest is the OCI image digest at install time (for upgrade detection)
|
||||
Digest string `json:"digest,omitempty"`
|
||||
}
|
||||
|
||||
type GalleryBackend struct {
|
||||
Metadata `json:",inline" yaml:",inline"`
|
||||
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
||||
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
||||
Version string `json:"version,omitempty" yaml:"version,omitempty"`
|
||||
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
||||
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
||||
}
|
||||
@@ -71,6 +78,10 @@ func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool
|
||||
return true
|
||||
}
|
||||
|
||||
if systemState.CapabilityFilterDisabled() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Meta backends are compatible if the system capability matches one of the keys
|
||||
if m.IsMeta() {
|
||||
capability := systemState.Capability(m.CapabilitiesMap)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/oci"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
cp "github.com/otiai10/copy"
|
||||
@@ -158,6 +159,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
|
||||
Name: name,
|
||||
GalleryURL: backend.Gallery.URL,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
Version: bestBackend.Version,
|
||||
}
|
||||
|
||||
if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
|
||||
@@ -279,6 +281,18 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
Name: name,
|
||||
GalleryURL: config.Gallery.URL,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
Version: config.Version,
|
||||
URI: string(uri),
|
||||
}
|
||||
|
||||
// Record the OCI digest for upgrade detection (non-fatal on failure)
|
||||
if uri.LooksLikeOCI() {
|
||||
digest, digestErr := oci.GetImageDigest(string(uri), "", nil, nil)
|
||||
if digestErr != nil {
|
||||
xlog.Warn("Failed to get OCI image digest for backend", "uri", string(uri), "error", digestErr)
|
||||
} else {
|
||||
metadata.Digest = digest
|
||||
}
|
||||
}
|
||||
|
||||
if config.Alias != "" {
|
||||
@@ -300,14 +314,29 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
|
||||
backend, ok := backends.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||
// Not found by direct key — try matching by gallery name (metadata.Name)
|
||||
// The UI may send gallery-style names like "localai@llama-cpp" which
|
||||
// don't match the directory-based keys used in the backends map.
|
||||
for _, b := range backends {
|
||||
if b.Metadata != nil && b.Metadata.Name == name && !b.IsMeta {
|
||||
backend = b
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
if backend.IsSystem {
|
||||
return fmt.Errorf("system backend %q cannot be deleted", name)
|
||||
}
|
||||
|
||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
// Use the backend's actual Name (directory key) for path resolution,
|
||||
// not the caller-supplied name which may be a gallery-style name.
|
||||
dirName := backend.Name
|
||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, dirName)
|
||||
|
||||
// check if the backend dir exists
|
||||
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
|
||||
@@ -325,7 +354,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if metadata != nil && metadata.Alias == name {
|
||||
if metadata != nil && (metadata.Alias == name || metadata.Alias == dirName) {
|
||||
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
|
||||
foundBackend = true
|
||||
break
|
||||
@@ -358,11 +387,13 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
}
|
||||
|
||||
type SystemBackend struct {
|
||||
Name string
|
||||
RunFile string
|
||||
IsMeta bool
|
||||
IsSystem bool
|
||||
Metadata *BackendMetadata
|
||||
Name string
|
||||
RunFile string
|
||||
IsMeta bool
|
||||
IsSystem bool
|
||||
Metadata *BackendMetadata
|
||||
UpgradeAvailable bool `json:"upgrade_available,omitempty"`
|
||||
AvailableVersion string `json:"available_version,omitempty"`
|
||||
}
|
||||
|
||||
type SystemBackends map[string]SystemBackend
|
||||
|
||||
118
core/gallery/backends_version_test.go
Normal file
@@ -0,0 +1,118 @@
package gallery_test

import (
	"context"
	"encoding/json"
	"os"
	"path/filepath"

	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("Backend versioning", func() {
	var tempDir string
	var systemState *system.SystemState
	var modelLoader *model.ModelLoader

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "gallery-version-*")
		Expect(err).NotTo(HaveOccurred())

		systemState, err = system.GetSystemState(
			system.WithBackendPath(tempDir),
		)
		Expect(err).NotTo(HaveOccurred())
		modelLoader = model.NewModelLoader(systemState)
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	It("records version in metadata when installing a backend with a version", func() {
		// Create a fake backend source directory with a run.sh
		srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
		Expect(err).NotTo(HaveOccurred())
		defer os.RemoveAll(srcDir)
		err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
		Expect(err).NotTo(HaveOccurred())

		backend := &gallery.GalleryBackend{}
		backend.Name = "test-backend"
		backend.URI = srcDir
		backend.Version = "1.2.3"

		err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
		Expect(err).NotTo(HaveOccurred())

		// Read the metadata file and check version
		metadataPath := filepath.Join(tempDir, "test-backend", "metadata.json")
		data, err := os.ReadFile(metadataPath)
		Expect(err).NotTo(HaveOccurred())

		var metadata map[string]any
		err = json.Unmarshal(data, &metadata)
		Expect(err).NotTo(HaveOccurred())

		Expect(metadata["version"]).To(Equal("1.2.3"))
	})

	It("records URI in metadata", func() {
		srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
		Expect(err).NotTo(HaveOccurred())
		defer os.RemoveAll(srcDir)
		err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
		Expect(err).NotTo(HaveOccurred())

		backend := &gallery.GalleryBackend{}
		backend.Name = "test-backend-uri"
		backend.URI = srcDir
		backend.Version = "2.0.0"

		err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
		Expect(err).NotTo(HaveOccurred())

		metadataPath := filepath.Join(tempDir, "test-backend-uri", "metadata.json")
		data, err := os.ReadFile(metadataPath)
		Expect(err).NotTo(HaveOccurred())

		var metadata map[string]any
		err = json.Unmarshal(data, &metadata)
		Expect(err).NotTo(HaveOccurred())

		Expect(metadata["uri"]).To(Equal(srcDir))
	})

	It("omits version key when version is empty", func() {
		srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
		Expect(err).NotTo(HaveOccurred())
		defer os.RemoveAll(srcDir)
		err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
		Expect(err).NotTo(HaveOccurred())

		backend := &gallery.GalleryBackend{}
		backend.Name = "test-backend-noversion"
		backend.URI = srcDir
		// Version intentionally left empty

		err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
		Expect(err).NotTo(HaveOccurred())

		metadataPath := filepath.Join(tempDir, "test-backend-noversion", "metadata.json")
		data, err := os.ReadFile(metadataPath)
		Expect(err).NotTo(HaveOccurred())

		var metadata map[string]any
		err = json.Unmarshal(data, &metadata)
		Expect(err).NotTo(HaveOccurred())

		// omitempty should exclude the version key entirely
		_, hasVersion := metadata["version"]
		Expect(hasVersion).To(BeFalse())
	})
})
|
||||
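The last test above relies on `omitempty` dropping the `version` key from `metadata.json` when the version is empty, and it checks this by decoding into a map and testing key presence. The same behaviour can be seen in isolation with the standard library. The `metadata` struct here is a hypothetical two-field stand-in for `BackendMetadata`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// metadata is a hypothetical stand-in for BackendMetadata.
type metadata struct {
	Name    string `json:"name"`
	Version string `json:"version,omitempty"`
}

// hasKey marshals m, decodes it into a generic map, and reports whether
// key is present. Decoding into a struct would hide the difference
// between a missing key and an empty string.
func hasKey(m metadata, key string) bool {
	data, err := json.Marshal(m)
	if err != nil {
		panic(err)
	}
	var decoded map[string]any
	if err := json.Unmarshal(data, &decoded); err != nil {
		panic(err)
	}
	_, ok := decoded[key]
	return ok
}

func main() {
	fmt.Println(hasKey(metadata{Name: "a", Version: "1.2.3"}, "version"))
	fmt.Println(hasKey(metadata{Name: "a"}, "version"))
}
```

This matters for upgrade detection further down: a backend installed before version tracking has no `version` key at all, and it reads back as the empty string.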
237
core/gallery/upgrade.go
Normal file
@@ -0,0 +1,237 @@
package gallery

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/oci"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
	cp "github.com/otiai10/copy"
)

// UpgradeInfo holds details about an available backend upgrade.
type UpgradeInfo struct {
	BackendName      string `json:"backend_name"`
	InstalledVersion string `json:"installed_version"`
	AvailableVersion string `json:"available_version"`
	InstalledDigest  string `json:"installed_digest,omitempty"`
	AvailableDigest  string `json:"available_digest,omitempty"`
}

// CheckBackendUpgrades compares installed backends against gallery entries
// and returns a map of backend names to UpgradeInfo for those that have
// newer versions or different OCI digests available.
func CheckBackendUpgrades(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState) (map[string]UpgradeInfo, error) {
	galleryBackends, err := AvailableBackends(galleries, systemState)
	if err != nil {
		return nil, fmt.Errorf("failed to list available backends: %w", err)
	}

	installedBackends, err := ListSystemBackends(systemState)
	if err != nil {
		return nil, fmt.Errorf("failed to list installed backends: %w", err)
	}

	result := make(map[string]UpgradeInfo)

	for _, installed := range installedBackends {
		// Skip system backends — they are managed outside the gallery
		if installed.IsSystem {
			continue
		}
		if installed.Metadata == nil {
			continue
		}

		// Find matching gallery entry by metadata name
		galleryEntry := FindGalleryElement(galleryBackends, installed.Metadata.Name)
		if galleryEntry == nil {
			continue
		}

		installedVersion := installed.Metadata.Version
		galleryVersion := galleryEntry.Version

		// If both sides have versions, compare them
		if galleryVersion != "" && installedVersion != "" {
			if galleryVersion != installedVersion {
				result[installed.Metadata.Name] = UpgradeInfo{
					BackendName:      installed.Metadata.Name,
					InstalledVersion: installedVersion,
					AvailableVersion: galleryVersion,
				}
			}
			// Versions match — no upgrade needed
			continue
		}

		// Gallery has a version but installed doesn't — this happens for backends
		// installed before version tracking was added. Flag as upgradeable so
		// users can re-install to pick up version metadata.
		if galleryVersion != "" && installedVersion == "" {
			result[installed.Metadata.Name] = UpgradeInfo{
				BackendName:      installed.Metadata.Name,
				InstalledVersion: "",
				AvailableVersion: galleryVersion,
			}
			continue
		}

		// Fall back to OCI digest comparison when versions are unavailable
		if downloader.URI(galleryEntry.URI).LooksLikeOCI() {
			remoteDigest, err := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
			if err != nil {
				xlog.Warn("Failed to get remote OCI digest for upgrade check", "backend", installed.Metadata.Name, "error", err)
				continue
			}
			// If we have a stored digest, compare; otherwise any remote digest
			// means we can't confirm we're up to date — flag as upgradeable
			if installed.Metadata.Digest == "" || remoteDigest != installed.Metadata.Digest {
				result[installed.Metadata.Name] = UpgradeInfo{
					BackendName:     installed.Metadata.Name,
					InstalledDigest: installed.Metadata.Digest,
					AvailableDigest: remoteDigest,
				}
			}
		}
		// No version info and non-OCI URI — cannot determine, skip
	}

	return result, nil
}
|
||||
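The decision order in `CheckBackendUpgrades` (versions first, then the OCI digest as a fallback) can be isolated as a pure function, which makes the edge cases easier to see. This is a sketch, not code from the PR. `needsUpgrade` is a hypothetical name, and an empty `remoteDigest` stands in for both "the URI is not OCI" and "the digest lookup failed":

```go
package main

import "fmt"

// needsUpgrade follows the order used by CheckBackendUpgrades above:
// compare versions when both exist, flag a missing installed version,
// and otherwise fall back to comparing digests.
func needsUpgrade(installedVersion, galleryVersion, installedDigest, remoteDigest string) bool {
	switch {
	case galleryVersion != "" && installedVersion != "":
		// Plain inequality: any difference counts, not only "newer".
		return galleryVersion != installedVersion
	case galleryVersion != "":
		// Installed before version tracking existed.
		return true
	case remoteDigest != "":
		// No stored digest means "cannot confirm up to date".
		return installedDigest == "" || installedDigest != remoteDigest
	default:
		// No version and no digest: nothing to compare.
		return false
	}
}

func main() {
	fmt.Println(needsUpgrade("1.0.0", "2.0.0", "", ""))
	fmt.Println(needsUpgrade("2.0.0", "2.0.0", "sha256:a", "sha256:b"))
	fmt.Println(needsUpgrade("", "", "sha256:a", "sha256:a"))
}
```

Two things stand out once the logic is written this way. Versions are compared as strings for inequality, so a gallery entry that moves to a lower version is also reported as an upgrade. And when both versions are present and equal, the digest is never consulted.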
// UpgradeBackend upgrades a single backend to the latest gallery version using
// an atomic swap with backup-based rollback on failure.
func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, galleries []config.Gallery, backendName string, downloadStatus func(string, string, string, float64)) error {
	// Look up the installed backend
	installedBackends, err := ListSystemBackends(systemState)
	if err != nil {
		return fmt.Errorf("failed to list installed backends: %w", err)
	}

	installed, ok := installedBackends.Get(backendName)
	if !ok {
		return fmt.Errorf("backend %q: %w", backendName, ErrBackendNotFound)
	}

	if installed.IsSystem {
		return fmt.Errorf("system backend %q cannot be upgraded via gallery", backendName)
	}

	// If this is a meta backend, recursively upgrade the concrete backend it points to
	if installed.Metadata != nil && installed.Metadata.MetaBackendFor != "" {
		xlog.Info("Meta backend detected, upgrading concrete backend", "meta", backendName, "concrete", installed.Metadata.MetaBackendFor)
		return UpgradeBackend(ctx, systemState, modelLoader, galleries, installed.Metadata.MetaBackendFor, downloadStatus)
	}

	// Find the gallery entry
	galleryBackends, err := AvailableBackends(galleries, systemState)
	if err != nil {
		return fmt.Errorf("failed to list available backends: %w", err)
	}

	galleryEntry := FindGalleryElement(galleryBackends, backendName)
	if galleryEntry == nil {
		return fmt.Errorf("no gallery entry found for backend %q", backendName)
	}

	backendPath := filepath.Join(systemState.Backend.BackendsPath, backendName)
	tmpPath := backendPath + ".upgrade-tmp"
	backupPath := backendPath + ".backup"

	// Clean up any stale tmp/backup dirs from prior attempts
	os.RemoveAll(tmpPath)
	os.RemoveAll(backupPath)

	// Step 1: Download the new backend into the tmp directory
	if err := os.MkdirAll(tmpPath, 0750); err != nil {
		return fmt.Errorf("failed to create upgrade tmp dir: %w", err)
	}

	uri := downloader.URI(galleryEntry.URI)
	if uri.LooksLikeDir() {
		if err := cp.Copy(string(uri), tmpPath); err != nil {
			os.RemoveAll(tmpPath)
			return fmt.Errorf("failed to copy backend from directory: %w", err)
		}
	} else {
		if err := uri.DownloadFileWithContext(ctx, tmpPath, "", 1, 1, downloadStatus); err != nil {
			os.RemoveAll(tmpPath)
			return fmt.Errorf("failed to download backend: %w", err)
		}
	}

	// Step 2: Validate — check that run.sh exists in the new content
	newRunFile := filepath.Join(tmpPath, runFile)
	if _, err := os.Stat(newRunFile); os.IsNotExist(err) {
		os.RemoveAll(tmpPath)
		return fmt.Errorf("upgrade validation failed: run.sh not found in new backend")
	}

	// Step 3: Atomic swap — rename current to backup, then tmp to current
	if err := os.Rename(backendPath, backupPath); err != nil {
		os.RemoveAll(tmpPath)
		return fmt.Errorf("failed to move current backend to backup: %w", err)
	}

	if err := os.Rename(tmpPath, backendPath); err != nil {
		// Restore backup on failure
		xlog.Error("Failed to move new backend into place, restoring backup", "error", err)
		if restoreErr := os.Rename(backupPath, backendPath); restoreErr != nil {
			xlog.Error("Failed to restore backup", "error", restoreErr)
		}
		os.RemoveAll(tmpPath)
		return fmt.Errorf("failed to move new backend into place: %w", err)
	}

	// Step 4: Write updated metadata, preserving alias from old metadata
	var oldAlias string
	if installed.Metadata != nil {
		oldAlias = installed.Metadata.Alias
	}

	newMetadata := &BackendMetadata{
		Name:        backendName,
		Version:     galleryEntry.Version,
		URI:         galleryEntry.URI,
		InstalledAt: time.Now().Format(time.RFC3339),
		Alias:       oldAlias,
	}

	if galleryEntry.Gallery.URL != "" {
		newMetadata.GalleryURL = galleryEntry.Gallery.URL
	}

	// Record OCI digest if applicable (non-fatal on failure)
	if uri.LooksLikeOCI() {
		digest, digestErr := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
		if digestErr != nil {
			xlog.Warn("Failed to get OCI image digest after upgrade", "uri", galleryEntry.URI, "error", digestErr)
		} else {
			newMetadata.Digest = digest
		}
	}

	if err := writeBackendMetadata(backendPath, newMetadata); err != nil {
		// Metadata write failure is not worth rolling back the entire upgrade
		xlog.Error("Failed to write metadata after upgrade", "error", err)
	}

	// Step 5: Re-register backends so the model loader picks up any changes
	if err := RegisterBackends(systemState, modelLoader); err != nil {
		xlog.Warn("Failed to re-register backends after upgrade", "error", err)
	}

	// Step 6: Remove backup
	os.RemoveAll(backupPath)

	xlog.Info("Backend upgraded successfully", "backend", backendName, "version", galleryEntry.Version)
	return nil
}
|
||||
219
core/gallery/upgrade_test.go
Normal file
@@ -0,0 +1,219 @@
package gallery_test

import (
	"context"
	"encoding/json"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/config"
	. "github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"gopkg.in/yaml.v3"
)

var _ = Describe("Upgrade Detection and Execution", func() {
	var (
		tempDir      string
		backendsPath string
		galleryPath  string
		systemState  *system.SystemState
		galleries    []config.Gallery
	)

	// installBackendWithVersion creates a fake installed backend directory with
	// the given name, version, and optional run.sh content.
	installBackendWithVersion := func(name, version string, runContent ...string) {
		dir := filepath.Join(backendsPath, name)
		Expect(os.MkdirAll(dir, 0750)).To(Succeed())

		content := "#!/bin/sh\necho ok"
		if len(runContent) > 0 {
			content = runContent[0]
		}
		Expect(os.WriteFile(filepath.Join(dir, "run.sh"), []byte(content), 0755)).To(Succeed())

		metadata := BackendMetadata{
			Name:        name,
			Version:     version,
			InstalledAt: time.Now().Format(time.RFC3339),
		}
		data, err := json.MarshalIndent(metadata, "", " ")
		Expect(err).NotTo(HaveOccurred())
		Expect(os.WriteFile(filepath.Join(dir, "metadata.json"), data, 0644)).To(Succeed())
	}

	// writeGalleryYAML writes a gallery YAML file with the given backends.
	writeGalleryYAML := func(backends []GalleryBackend) {
		data, err := yaml.Marshal(backends)
		Expect(err).NotTo(HaveOccurred())
		Expect(os.WriteFile(galleryPath, data, 0644)).To(Succeed())
	}

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "upgrade-test-*")
		Expect(err).NotTo(HaveOccurred())

		backendsPath = tempDir

		galleryPath = filepath.Join(tempDir, "gallery.yaml")

		// Write a default empty gallery
		writeGalleryYAML([]GalleryBackend{})

		galleries = []config.Gallery{
			{
				Name: "test-gallery",
				URL:  "file://" + galleryPath,
			},
		}

		systemState, err = system.GetSystemState(
			system.WithBackendPath(backendsPath),
		)
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Describe("CheckBackendUpgrades", func() {
		It("should detect upgrade when gallery version differs from installed version", func() {
			// Install a backend at v1.0.0
			installBackendWithVersion("my-backend", "1.0.0")

			// Gallery advertises v2.0.0
			writeGalleryYAML([]GalleryBackend{
				{
					Metadata: Metadata{
						Name: "my-backend",
					},
					URI:     filepath.Join(tempDir, "some-source"),
					Version: "2.0.0",
				},
			})

			upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
			Expect(err).NotTo(HaveOccurred())
			Expect(upgrades).To(HaveKey("my-backend"))
			Expect(upgrades["my-backend"].InstalledVersion).To(Equal("1.0.0"))
			Expect(upgrades["my-backend"].AvailableVersion).To(Equal("2.0.0"))
		})

		It("should NOT flag upgrade when versions match", func() {
			installBackendWithVersion("my-backend", "2.0.0")

			writeGalleryYAML([]GalleryBackend{
				{
					Metadata: Metadata{
						Name: "my-backend",
					},
					URI:     filepath.Join(tempDir, "some-source"),
					Version: "2.0.0",
				},
			})

			upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
			Expect(err).NotTo(HaveOccurred())
			Expect(upgrades).To(BeEmpty())
		})

		It("should skip backends without version info and without OCI digest", func() {
			// Install without version
			installBackendWithVersion("my-backend", "")

			// Gallery also without version
			writeGalleryYAML([]GalleryBackend{
				{
					Metadata: Metadata{
						Name: "my-backend",
					},
					URI: filepath.Join(tempDir, "some-source"),
				},
			})

			upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
			Expect(err).NotTo(HaveOccurred())
			Expect(upgrades).To(BeEmpty())
		})
	})

	Describe("UpgradeBackend", func() {
		It("should replace backend directory and update metadata", func() {
			// Install v1
			installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")

			// Create a source directory with v2 content
			srcDir := filepath.Join(tempDir, "v2-source")
			Expect(os.MkdirAll(srcDir, 0750)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho v2"), 0755)).To(Succeed())

			// Gallery points to the v2 source dir
			writeGalleryYAML([]GalleryBackend{
				{
					Metadata: Metadata{
						Name: "my-backend",
					},
					URI:     srcDir,
					Version: "2.0.0",
				},
			})

			ml := model.NewModelLoader(systemState)
			err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
			Expect(err).NotTo(HaveOccurred())

			// Verify run.sh was updated
			content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
			Expect(err).NotTo(HaveOccurred())
			Expect(string(content)).To(Equal("#!/bin/sh\necho v2"))

			// Verify metadata was updated
			metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
			Expect(err).NotTo(HaveOccurred())
			var meta BackendMetadata
			Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
			Expect(meta.Version).To(Equal("2.0.0"))
			Expect(meta.Name).To(Equal("my-backend"))
		})

		It("should restore backup on failure", func() {
			// Install v1
			installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")

			// Gallery points to a nonexistent path (no run.sh will be found)
			nonExistentDir := filepath.Join(tempDir, "does-not-exist")
			writeGalleryYAML([]GalleryBackend{
				{
					Metadata: Metadata{
						Name: "my-backend",
					},
					URI:     nonExistentDir,
					Version: "2.0.0",
				},
			})

			ml := model.NewModelLoader(systemState)
			err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
			Expect(err).To(HaveOccurred())

			// Verify v1 is still intact
			content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
			Expect(err).NotTo(HaveOccurred())
			Expect(string(content)).To(Equal("#!/bin/sh\necho v1"))

			// Verify metadata still says v1
			metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
			Expect(err).NotTo(HaveOccurred())
			var meta BackendMetadata
			Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
			Expect(meta.Version).To(Equal("1.0.0"))
		})
	})
})
@@ -52,9 +52,42 @@ var quietPaths = []string{"/api/operations", "/api/resources", "/healthz", "/rea
|
||||
// @license.name MIT
|
||||
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
|
||||
// @BasePath /
|
||||
// @schemes http https
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
// @name Authorization
|
||||
// @tag.name inference
|
||||
// @tag.description Chat completions, text completions, edits, and responses (OpenAI-compatible)
|
||||
// @tag.name embeddings
|
||||
// @tag.description Vector embeddings (OpenAI-compatible)
|
||||
// @tag.name audio
|
||||
// @tag.description Text-to-speech, transcription, voice activity detection, sound generation
|
||||
// @tag.name images
|
||||
// @tag.description Image generation and inpainting
|
||||
// @tag.name video
|
||||
// @tag.description Video generation from prompts
|
||||
// @tag.name detection
|
||||
// @tag.description Object detection in images
|
||||
// @tag.name tokenize
|
||||
// @tag.description Tokenization and token metrics
|
||||
// @tag.name models
|
||||
// @tag.description Model gallery browsing, installation, deletion, and listing
|
||||
// @tag.name backends
|
||||
// @tag.description Backend gallery browsing, installation, deletion, and listing
|
||||
// @tag.name config
|
||||
// @tag.description Model configuration metadata, autocomplete, PATCH updates, VRAM estimation
|
||||
// @tag.name monitoring
|
||||
// @tag.description Prometheus metrics, backend status, system information
|
||||
// @tag.name mcp
|
||||
// @tag.description Model Context Protocol — tool-augmented chat with MCP servers
|
||||
// @tag.name agent-jobs
|
||||
// @tag.description Agent task and job management
|
||||
// @tag.name p2p
|
||||
// @tag.description Peer-to-peer networking nodes and tokens
|
||||
// @tag.name rerank
|
||||
// @tag.description Document reranking
|
||||
// @tag.name instructions
|
||||
// @tag.description API instruction discovery — browse instruction areas and get endpoint guides
|
||||
|
||||
func API(application *application.Application) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
@@ -358,9 +391,13 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOllamaRoutes(e, requestExtractor, application)
|
||||
if application.ApplicationConfig().OllamaAPIRootEndpoint {
|
||||
routes.RegisterOllamaRootEndpoint(e)
|
||||
}
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
|
||||
// Serve React SPA from / with SPA fallback via 404 handler
|
||||
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")
|
||||
|
||||
@@ -956,8 +956,7 @@ parameters:
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// A model called "bert" can be present in the model directory depending on the order of the tests
|
||||
Expect(len(models.Models)).To(BeNumerically(">=", 8))
|
||||
Expect(len(models.Models)).To(BeNumerically(">=", 7))
|
||||
})
|
||||
It("can generate completions via ggml", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
@@ -979,6 +978,42 @@ parameters:
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not duplicate the first content token in streaming chat completions", Label("llama-gguf", "llama-gguf-stream"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
var contentParts []string
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if len(chunk.Choices) > 0 {
|
||||
delta := chunk.Choices[0].Delta.Content
|
||||
if delta != "" {
|
||||
contentParts = append(contentParts, delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Expect(contentParts).ToNot(BeEmpty(), "Expected streaming content tokens")
|
||||
// The first content token should appear exactly once.
|
||||
// A bug in grpc-server.cpp caused the role-init array element
|
||||
// to get the same ChatDelta stamped, duplicating the first token.
|
||||
if len(contentParts) >= 2 {
|
||||
Expect(contentParts[0]).ToNot(Equal(contentParts[1]),
|
||||
"First content token was duplicated: %v", contentParts[:2])
|
||||
}
|
||||
})
|
||||
|
||||
It("returns logprobs in chat completions when requested", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test only on linux")
|
||||
|
||||
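The streaming test added in the hunk above collects every non-empty `Delta.Content` and then checks that the first two parts differ. The same check can be sketched without a live server, assuming the deltas have already been gathered into a slice. The helper names `nonEmpty` and `firstTokenDuplicated` are made up for this illustration.

```go
package main

import "fmt"

// nonEmpty keeps only the deltas that carry content, as the test loop does.
func nonEmpty(deltas []string) []string {
	var parts []string
	for _, d := range deltas {
		if d != "" {
			parts = append(parts, d)
		}
	}
	return parts
}

// firstTokenDuplicated reports whether the first two content parts are identical.
func firstTokenDuplicated(parts []string) bool {
	return len(parts) >= 2 && parts[0] == parts[1]
}

func main() {
	healthy := nonEmpty([]string{"", "Hello", " world", ""})
	buggy := nonEmpty([]string{"Hello", "Hello", " world"})
	fmt.Println(firstTokenDuplicated(healthy)) // false
	fmt.Println(firstTokenDuplicated(buggy))   // true
}
```

This is a heuristic rather than a proof: a reply that legitimately starts with the same token twice would also trip it, which is presumably acceptable for a regression test against a fixed prompt.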
@@ -3,6 +3,8 @@ package anthropic
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
// MessagesEndpoint is the Anthropic Messages API endpoint
|
||||
// https://docs.anthropic.com/claude/reference/messages_post
|
||||
// @Summary Generate a message response for the given messages and model.
|
||||
// @Tags inference
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
@@ -357,7 +360,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// Send initial content_block_start event
|
||||
contentBlockStart := schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
|
||||
}
|
||||
sendAnthropicSSE(c, contentBlockStart)
|
||||
@@ -365,7 +368,33 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// Collect tool calls for MCP execution
|
||||
var collectedToolCalls []functions.FuncCallResults
|
||||
|
||||
// SSE keepalive: send comment pings every 3s until the first token arrives.
|
||||
// This prevents clients (e.g. Claude Code) from timing out while the model loads or processes the prompt.
|
||||
firstTokenReceived := make(chan struct{})
|
||||
keepaliveDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(keepaliveDone)
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-firstTokenReceived:
|
||||
return
|
||||
case <-c.Request().Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
fmt.Fprintf(c.Response().Writer, "event: ping\ndata: {\"type\": \"ping\"}\n\n")
|
||||
c.Response().Flush()
|
||||
}
|
||||
}
|
||||
}()
|
||||
firstTokenOnce := sync.Once{}
|
||||
|
||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||
firstTokenOnce.Do(func() {
|
||||
close(firstTokenReceived)
|
||||
<-keepaliveDone // wait for keepalive goroutine to exit before writing
|
||||
})
|
||||
accumulatedContent += token
|
||||
|
||||
if shouldUseFn {
|
||||
@@ -376,7 +405,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
@@ -386,7 +415,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
tc := toolCalls[i]
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fmt.Sprintf("toolu_%s_%d", id, i),
|
||||
@@ -395,7 +424,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: tc.Arguments,
|
||||
@@ -403,7 +432,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
}
|
||||
@@ -413,10 +442,10 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
}
|
||||
|
||||
if !inToolCall {
|
||||
if !inToolCall && token != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: token,
|
||||
@@ -432,6 +461,11 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
openAIReq.Metadata = input.Metadata
|
||||
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
|
||||
// Stop the keepalive goroutine now that inference is done
|
||||
firstTokenOnce.Do(func() { close(firstTokenReceived) })
|
||||
<-keepaliveDone
|
||||
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
@@ -444,9 +478,68 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
return nil
|
||||
}
|
||||
|
||||
// Also check chat deltas for tool calls
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
collectedToolCalls = deltaToolCalls
|
||||
// Check chat deltas from C++ autoparser — when active, the raw
|
||||
// message is cleared and content/tool calls arrive via ChatDeltas.
|
||||
if len(chatDeltas) > 0 {
|
||||
deltaContent := functions.ContentFromChatDeltas(chatDeltas)
|
||||
deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas)
|
||||
|
||||
// Emit text content from ChatDeltas only when the tokenCallback
|
||||
// didn't already stream it (autoparser clears raw text, so
|
||||
// accumulatedContent will be empty in that case).
|
||||
if deltaContent != "" && !inToolCall && accumulatedContent == "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: deltaContent,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Emit tool_use blocks from ChatDeltas
|
||||
if len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
collectedToolCalls = deltaToolCalls
|
||||
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
}
|
||||
for i, tc := range deltaToolCalls {
|
||||
toolCallID := tc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("toolu_%s_%d", id, i)
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
Name: tc.Name,
|
||||
},
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: tc.Arguments,
|
||||
},
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
toolCallsEmitted++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MCP streaming tool execution: if we collected MCP tool calls, execute and loop
|
||||
@@ -516,7 +609,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// Close the text content block
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
@@ -528,7 +621,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
@@ -537,7 +630,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: fc.Arguments,
|
||||
@@ -545,7 +638,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
toolCallsEmitted++
|
||||
@@ -557,7 +650,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !inToolCall {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
Index: intPtr(0),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -598,6 +691,8 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
return tools
|
||||
}
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
||||
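The hunks above switch every `Index` assignment from a plain integer to `intPtr(...)`. A likely motivation, though the struct definition is not part of this diff, is that an `int` field tagged `omitempty` disappears from the JSON when its value is 0, while a non-nil `*int` is always emitted. The sketch below demonstrates that difference with two hypothetical structs; they are stand-ins, not `schema.AnthropicStreamEvent`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// eventWithInt and eventWithPtr are hypothetical stand-ins used only to
// compare how encoding/json treats a zero index under omitempty.
type eventWithInt struct {
	Type  string `json:"type"`
	Index int    `json:"index,omitempty"`
}

type eventWithPtr struct {
	Type  string `json:"type"`
	Index *int   `json:"index,omitempty"`
}

func intPtr(i int) *int { return &i }

// encode marshals v and returns the JSON text, or the error message.
func encode(v any) string {
	b, err := json.Marshal(v)
	if err != nil {
		return "error: " + err.Error()
	}
	return string(b)
}

func main() {
	// Zero int with omitempty: the index field is dropped.
	fmt.Println(encode(eventWithInt{Type: "content_block_start", Index: 0}))
	// Pointer to zero: the index field is kept.
	fmt.Println(encode(eventWithPtr{Type: "content_block_start", Index: intPtr(0)}))
	// Nil pointer: events that carry no index omit it entirely.
	fmt.Println(encode(eventWithPtr{Type: "message_stop"}))
}
```

With a pointer field the first content block (index 0) stays distinguishable from an event that has no index at all, which is the usual reason for this kind of change.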
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
||||
// @Summary Generates audio from the input text.
|
||||
// @Tags audio
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||
// @Summary Generates audio from the input text.
|
||||
// @Tags audio
|
||||
// @Param voice-id path string true "Account ID"
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/)
|
||||
// @Summary Reranks a list of phrases by relevance to a given text query.
|
||||
// @Tags rerank
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
|
||||
@@ -30,6 +30,15 @@ func getJobService(app *application.Application, c echo.Context) *agentpool.Agen
|
||||
return jobSvc
|
||||
}
|
||||
|
||||
// CreateTaskEndpoint creates a new agent task definition.
|
||||
// @Summary Create a new agent task
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.Task true "Task definition"
|
||||
// @Success 201 {object} map[string]string "id"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks [post]
|
||||
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var task schema.Task
|
||||
@@ -46,6 +55,17 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskEndpoint updates an existing agent task.
|
||||
// @Summary Update an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Param request body schema.Task true "Updated task definition"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [put]
|
||||
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -65,6 +85,14 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteTaskEndpoint deletes an agent task.
|
||||
// @Summary Delete an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [delete]
|
||||
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -79,6 +107,13 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListTasksEndpoint lists all agent tasks for the current user.
|
||||
// @Summary List agent tasks
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||
// @Success 200 {object} []schema.Task "tasks"
|
||||
// @Router /api/agent/tasks [get]
|
||||
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
jobSvc := getJobService(app, c)
|
||||
@@ -121,6 +156,14 @@ func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetTaskEndpoint returns a single agent task by ID.
|
||||
// @Summary Get an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} schema.Task "task"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [get]
|
||||
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -133,6 +176,15 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteJobEndpoint creates and runs a new job for a task.
|
||||
// @Summary Execute an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.JobExecutionRequest true "Job execution request"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/execute [post]
|
||||
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.JobExecutionRequest
|
||||
@@ -168,6 +220,14 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetJobEndpoint returns a single job by ID.
|
||||
// @Summary Get an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} schema.Job "job"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id} [get]
|
||||
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -180,6 +240,16 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListJobsEndpoint lists jobs, optionally filtered by task or status.
|
||||
// @Summary List agent jobs
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param task_id query string false "Filter by task ID"
|
||||
// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
|
||||
// @Param limit query integer false "Max number of jobs to return"
|
||||
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||
// @Success 200 {object} []schema.Job "jobs"
|
||||
// @Router /api/agent/jobs [get]
|
||||
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var taskID *string
|
||||
@@ -241,6 +311,15 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// CancelJobEndpoint cancels a running job.
|
||||
// @Summary Cancel an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id}/cancel [post]
|
||||
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -255,6 +334,14 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteJobEndpoint deletes a job by ID.
|
||||
// @Summary Delete an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id} [delete]
|
||||
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -269,6 +356,17 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTaskByNameEndpoint looks up a task by name and executes it.
|
||||
// @Summary Execute an agent task by name
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param name path string true "Task name"
|
||||
// @Param parameters body object false "Optional template parameters"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{name}/execute [post]
|
||||
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
name := c.Param("name")
|
||||
|
||||
489
core/http/endpoints/localai/api_instructions.go
Normal file
@@ -0,0 +1,489 @@
package localai

import (
	"encoding/json"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"sync"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/swagger"
	"github.com/mudler/xlog"
)

const swaggerDefsPrefix = "#/definitions/"

// instructionDef is a lightweight instruction definition that maps to swagger tags.
type instructionDef struct {
	Name        string   `json:"name"`
	Description string   `json:"description"`
	Tags        []string `json:"tags"`
	Intro       string   `json:"-"` // brief context not in swagger
}

var instructionDefs = []instructionDef{
	{
		Name:        "chat-inference",
		Description: "OpenAI-compatible chat completions, text completions, and embeddings",
		Tags:        []string{"inference", "embeddings"},
		Intro:       "Set \"stream\": true for SSE streaming. Supports tool/function calling when the model config has function templates configured.",
	},
	{
		Name:        "audio",
		Description: "Text-to-speech, voice activity detection, transcription, and sound generation",
		Tags:        []string{"audio"},
	},
	{
		Name:        "images",
		Description: "Image generation and inpainting",
		Tags:        []string{"images"},
	},
	{
		Name:        "model-management",
		Description: "Browse the gallery, install, delete, and manage models and backends",
		Tags:        []string{"models", "backends"},
	},
	{
		Name:        "config-management",
		Description: "Discover, read, and modify model configuration fields with VRAM estimation",
		Tags:        []string{"config"},
		Intro:       "Fields with static options include an \"options\" array in metadata. Fields with dynamic values have an \"autocomplete_provider\" for runtime lookup.",
	},
	{
		Name:        "monitoring",
		Description: "System metrics, backend status, API and backend traces, backend process logs, and system information",
		Tags:        []string{"monitoring"},
		Intro:       "Includes real-time backend log streaming via WebSocket at /ws/backend-logs/:modelId.",
	},
	{
		Name:        "mcp",
		Description: "Model Context Protocol — tool-augmented chat with MCP servers",
		Tags:        []string{"mcp"},
		Intro:       "The model's config must define MCP servers. The endpoint handles tool execution automatically.",
	},
	{
		Name:        "agents",
		Description: "Agent task and job management for CI/automation workflows",
		Tags:        []string{"agent-jobs"},
	},
	{
		Name:        "video",
		Description: "Video generation from text prompts",
		Tags:        []string{"video"},
	},
}

// swaggerState holds parsed swagger spec data, initialised once.
type swaggerState struct {
	once  sync.Once
	spec  map[string]any // full parsed swagger JSON
	ready bool
}

var swState swaggerState

func (s *swaggerState) init() {
	s.once.Do(func() {
		var spec map[string]any
		if err := json.Unmarshal(swagger.SwaggerJSON, &spec); err != nil {
			xlog.Error("failed to parse embedded swagger spec", "err", err)
			return
		}
		s.spec = spec
		s.ready = true
	})
}

// filterSwaggerByTags returns a swagger fragment containing only paths whose
// operations carry at least one of the given tags, plus the definitions they
// reference.
func filterSwaggerByTags(spec map[string]any, tags []string) map[string]any {
	tagSet := make(map[string]bool, len(tags))
	for _, t := range tags {
		tagSet[t] = true
	}

	paths, _ := spec["paths"].(map[string]any)
	allDefs, _ := spec["definitions"].(map[string]any)

	filteredPaths := make(map[string]any)
	for path, methods := range paths {
		methodMap, ok := methods.(map[string]any)
		if !ok {
			continue
|
||||
}
|
||||
filteredMethods := make(map[string]any)
|
||||
for method, opRaw := range methodMap {
|
||||
op, ok := opRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
opTags, _ := op["tags"].([]any)
|
||||
for _, t := range opTags {
|
||||
if ts, ok := t.(string); ok && tagSet[ts] {
|
||||
filteredMethods[method] = op
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(filteredMethods) > 0 {
|
||||
filteredPaths[path] = filteredMethods
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all $ref definitions used by the filtered paths.
|
||||
neededDefs := make(map[string]bool)
|
||||
collectRefs(filteredPaths, neededDefs)
|
||||
|
||||
// Resolve nested refs from definitions themselves.
|
||||
changed := true
|
||||
for changed {
|
||||
changed = false
|
||||
for name := range neededDefs {
|
||||
if def, ok := allDefs[name]; ok {
|
||||
before := len(neededDefs)
|
||||
collectRefs(def, neededDefs)
|
||||
if len(neededDefs) > before {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filteredDefs := make(map[string]any)
|
||||
for name := range neededDefs {
|
||||
if def, ok := allDefs[name]; ok {
|
||||
filteredDefs[name] = def
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"paths": filteredPaths,
|
||||
}
|
||||
if len(filteredDefs) > 0 {
|
||||
result["definitions"] = filteredDefs
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// collectRefs walks a JSON structure and collects all $ref definition names.
|
||||
func collectRefs(v any, refs map[string]bool) {
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
if ref, ok := val["$ref"].(string); ok {
|
||||
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||
refs[ref[len(swaggerDefsPrefix):]] = true
|
||||
}
|
||||
}
|
||||
for _, child := range val {
|
||||
collectRefs(child, refs)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range val {
|
||||
collectRefs(child, refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swaggerToMarkdown renders a filtered swagger fragment into concise markdown.
|
||||
func swaggerToMarkdown(skillName, intro string, fragment map[string]any) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("# ")
|
||||
b.WriteString(skillName)
|
||||
b.WriteString("\n")
|
||||
if intro != "" {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(intro)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
paths, _ := fragment["paths"].(map[string]any)
|
||||
defs, _ := fragment["definitions"].(map[string]any)
|
||||
|
||||
// Sort paths for stable output.
|
||||
sortedPaths := make([]string, 0, len(paths))
|
||||
for p := range paths {
|
||||
sortedPaths = append(sortedPaths, p)
|
||||
}
|
||||
sort.Strings(sortedPaths)
|
||||
|
||||
for _, path := range sortedPaths {
|
||||
methods, ok := paths[path].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sortedMethods := sortMethods(methods)
|
||||
for _, method := range sortedMethods {
|
||||
op, ok := methods[method].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
summary, _ := op["summary"].(string)
|
||||
fmt.Fprintf(&b, "\n## %s %s\n", strings.ToUpper(method), path)
|
||||
if summary != "" {
|
||||
b.WriteString(summary)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Parameters
|
||||
params, _ := op["parameters"].([]any)
|
||||
bodyParams, nonBodyParams := splitParams(params)
|
||||
|
||||
if len(nonBodyParams) > 0 {
|
||||
b.WriteString("\n**Parameters:**\n")
|
||||
b.WriteString("| Name | In | Type | Required | Description |\n")
|
||||
b.WriteString("|------|----|------|----------|-------------|\n")
|
||||
for _, p := range nonBodyParams {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := pm["name"].(string)
|
||||
in, _ := pm["in"].(string)
|
||||
typ, _ := pm["type"].(string)
|
||||
req, _ := pm["required"].(bool)
|
||||
desc, _ := pm["description"].(string)
|
||||
fmt.Fprintf(&b, "| %s | %s | %s | %v | %s |\n", name, in, typ, req, desc)
|
||||
}
|
||||
}
|
||||
|
||||
if len(bodyParams) > 0 {
|
||||
for _, p := range bodyParams {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
schema, _ := pm["schema"].(map[string]any)
|
||||
refName := resolveRefName(schema)
|
||||
if refName != "" {
|
||||
fmt.Fprintf(&b, "\n**Request body** (`%s`):\n", refName)
|
||||
renderSchemaFields(&b, refName, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Responses
|
||||
responses, _ := op["responses"].(map[string]any)
|
||||
if len(responses) > 0 {
|
||||
sortedCodes := make([]string, 0, len(responses))
|
||||
for code := range responses {
|
||||
sortedCodes = append(sortedCodes, code)
|
||||
}
|
||||
sort.Strings(sortedCodes)
|
||||
for _, code := range sortedCodes {
|
||||
resp, ok := responses[code].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
desc, _ := resp["description"].(string)
|
||||
respSchema, _ := resp["schema"].(map[string]any)
|
||||
refName := resolveRefName(respSchema)
|
||||
if refName != "" {
|
||||
fmt.Fprintf(&b, "\n**Response %s** (`%s`): %s\n", code, refName, desc)
|
||||
renderSchemaFields(&b, refName, defs)
|
||||
} else if desc != "" {
|
||||
fmt.Fprintf(&b, "\n**Response %s**: %s\n", code, desc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// sortMethods returns HTTP methods in a conventional order.
|
||||
func sortMethods(methods map[string]any) []string {
|
||||
order := map[string]int{"get": 0, "post": 1, "put": 2, "patch": 3, "delete": 4}
|
||||
keys := make([]string, 0, len(methods))
|
||||
for k := range methods {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
oi, oki := order[keys[i]]
|
||||
oj, okj := order[keys[j]]
|
||||
if !oki {
|
||||
oi = 99
|
||||
}
|
||||
if !okj {
|
||||
oj = 99
|
||||
}
|
||||
return oi < oj || (oi == oj && keys[i] < keys[j]) // tie-break by name so unranked methods sort deterministically
|
||||
})
|
||||
return keys
|
||||
}
|
||||
|
||||
// splitParams separates body parameters from non-body parameters.
|
||||
func splitParams(params []any) (body, nonBody []any) {
|
||||
for _, p := range params {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if in, _ := pm["in"].(string); in == "body" {
|
||||
body = append(body, p)
|
||||
} else {
|
||||
nonBody = append(nonBody, p)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// resolveRefName extracts the definition name from a $ref or returns "".
|
||||
func resolveRefName(schema map[string]any) string {
|
||||
if schema == nil {
|
||||
return ""
|
||||
}
|
||||
if ref, ok := schema["$ref"].(string); ok {
|
||||
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||
return ref[len(swaggerDefsPrefix):]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// renderSchemaFields writes a markdown field table for a definition.
|
||||
func renderSchemaFields(b *strings.Builder, defName string, defs map[string]any) {
|
||||
if defs == nil {
|
||||
return
|
||||
}
|
||||
def, ok := defs[defName].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
props, ok := def["properties"].(map[string]any)
|
||||
if !ok || len(props) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Sort fields
|
||||
fields := make([]string, 0, len(props))
|
||||
for f := range props {
|
||||
fields = append(fields, f)
|
||||
}
|
||||
sort.Strings(fields)
|
||||
|
||||
b.WriteString("| Field | Type | Description |\n")
|
||||
b.WriteString("|-------|------|-------------|\n")
|
||||
for _, field := range fields {
|
||||
prop, ok := props[field].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typ := schemaTypeString(prop)
|
||||
desc, _ := prop["description"].(string)
|
||||
fmt.Fprintf(b, "| %s | %s | %s |\n", field, typ, desc)
|
||||
}
|
||||
}
|
||||
|
||||
// schemaTypeString returns a human-readable type string for a schema property.
|
||||
func schemaTypeString(prop map[string]any) string {
|
||||
if ref := resolveRefName(prop); ref != "" {
|
||||
return ref
|
||||
}
|
||||
typ, _ := prop["type"].(string)
|
||||
if typ == "array" {
|
||||
items, _ := prop["items"].(map[string]any)
|
||||
if items != nil {
|
||||
if ref := resolveRefName(items); ref != "" {
|
||||
return "[]" + ref
|
||||
}
|
||||
it, _ := items["type"].(string)
|
||||
if it != "" {
|
||||
return "[]" + it
|
||||
}
|
||||
}
|
||||
return "[]any"
|
||||
}
|
||||
if typ != "" {
|
||||
return typ
|
||||
}
|
||||
return "object"
|
||||
}
|
||||
|
||||
// APIInstructionResponse is the JSON response for a single instruction (?format=json).
|
||||
type APIInstructionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
SwaggerFragment map[string]any `json:"swagger_fragment,omitempty"`
|
||||
}
|
||||
|
||||
// ListAPIInstructionsEndpoint returns all instructions (compact list without guides).
|
||||
// @Summary List available API instruction areas
|
||||
// @Description Returns a compact list of instruction areas with descriptions and URLs for detailed guides
|
||||
// @Tags instructions
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]any "instructions list with hint"
|
||||
// @Router /api/instructions [get]
|
||||
func ListAPIInstructionsEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
type compactInstruction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
instructions := make([]compactInstruction, len(instructionDefs))
|
||||
for i, s := range instructionDefs {
|
||||
instructions[i] = compactInstruction{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Tags: s.Tags,
|
||||
URL: "/api/instructions/" + s.Name,
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"instructions": instructions,
|
||||
"hint": "Fetch GET {url} for a markdown API guide. Add ?format=json for a raw OpenAPI fragment.",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetAPIInstructionEndpoint returns a single instruction by name.
|
||||
// @Summary Get an instruction's API guide or OpenAPI fragment
|
||||
// @Description Returns a markdown guide (default) or filtered OpenAPI fragment (format=json) for a named instruction
|
||||
// @Tags instructions
|
||||
// @Produce json
|
||||
// @Produce text/markdown
|
||||
// @Param name path string true "Instruction name (e.g. chat-inference, config-management)"
|
||||
// @Param format query string false "Response format: json for OpenAPI fragment, omit for markdown"
|
||||
// @Success 200 {object} APIInstructionResponse "instruction documentation"
|
||||
// @Failure 404 {object} map[string]string "instruction not found"
|
||||
// @Router /api/instructions/{name} [get]
|
||||
func GetAPIInstructionEndpoint() echo.HandlerFunc {
|
||||
byName := make(map[string]*instructionDef, len(instructionDefs))
|
||||
for i := range instructionDefs {
|
||||
byName[instructionDefs[i].Name] = &instructionDefs[i]
|
||||
}
|
||||
|
||||
return func(c echo.Context) error {
|
||||
name := c.Param("name")
|
||||
inst, ok := byName[name]
|
||||
if !ok {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "instruction not found: " + name})
|
||||
}
|
||||
|
||||
swState.init()
|
||||
if !swState.ready {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "swagger spec not available"})
|
||||
}
|
||||
|
||||
fragment := filterSwaggerByTags(swState.spec, inst.Tags)
|
||||
|
||||
format := c.QueryParam("format")
|
||||
if format == "json" {
|
||||
return c.JSON(http.StatusOK, APIInstructionResponse{
|
||||
Name: inst.Name,
|
||||
Description: inst.Description,
|
||||
Tags: inst.Tags,
|
||||
SwaggerFragment: fragment,
|
||||
})
|
||||
}
|
||||
|
||||
guide := swaggerToMarkdown(inst.Name, inst.Intro, fragment)
|
||||
return c.Blob(http.StatusOK, "text/markdown; charset=utf-8", []byte(guide))
|
||||
}
|
||||
}
|
||||
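The nested-reference resolution in `filterSwaggerByTags` can be illustrated in isolation. The following is a minimal sketch, not the code from this change: it computes the same transitive closure of `$ref` names with an explicit worklist instead of the repeat-until-stable loop. The `closure` helper and the `Request`, `Message`, and `Unused` definitions are hypothetical names chosen for the example.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

const defsPrefix = "#/definitions/"

// collectRefs walks decoded JSON and records every "#/definitions/X"
// reference it finds under the name X.
func collectRefs(v any, refs map[string]bool) {
	switch val := v.(type) {
	case map[string]any:
		if ref, ok := val["$ref"].(string); ok && strings.HasPrefix(ref, defsPrefix) {
			refs[strings.TrimPrefix(ref, defsPrefix)] = true
		}
		for _, child := range val {
			collectRefs(child, refs)
		}
	case []any:
		for _, child := range val {
			collectRefs(child, refs)
		}
	}
}

// closure (hypothetical helper) returns the sorted names of all definitions
// reachable from root, following references nested inside definitions.
func closure(root any, defs map[string]any) []string {
	seen := make(map[string]bool)
	collectRefs(root, seen)
	queue := make([]string, 0, len(seen))
	for name := range seen {
		queue = append(queue, name)
	}
	for len(queue) > 0 {
		name := queue[0]
		queue = queue[1:]
		found := make(map[string]bool)
		collectRefs(defs[name], found) // a missing definition yields nil and is skipped
		for ref := range found {
			if !seen[ref] {
				seen[ref] = true
				queue = append(queue, ref)
			}
		}
	}
	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

func main() {
	// Illustrative definitions only; they are not taken from the real spec.
	defs := map[string]any{
		"Request": map[string]any{"properties": map[string]any{
			"messages": map[string]any{
				"type":  "array",
				"items": map[string]any{"$ref": "#/definitions/Message"},
			},
		}},
		"Message": map[string]any{"properties": map[string]any{
			"role": map[string]any{"type": "string"},
		}},
		"Unused": map[string]any{},
	}
	op := map[string]any{"schema": map[string]any{"$ref": "#/definitions/Request"}}
	fmt.Println(closure(op, defs)) // prints [Message Request]
}
```

The `seen` set also guards against definitions that reference each other, so a cyclic schema terminates instead of looping.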
222
core/http/endpoints/localai/api_instructions_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("API Instructions Endpoints", func() {
|
||||
var app *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
app.GET("/api/instructions", ListAPIInstructionsEndpoint())
|
||||
app.GET("/api/instructions/:name", GetAPIInstructionEndpoint())
|
||||
})
|
||||
|
||||
Context("GET /api/instructions", func() {
|
||||
It("should return all instruction definitions", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(resp).To(HaveKey("hint"))
|
||||
Expect(resp).To(HaveKey("instructions"))
|
||||
|
||||
instructions, ok := resp["instructions"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(instructions).To(HaveLen(9))
|
||||
|
||||
// Verify each instruction has required fields and correct URL format
|
||||
for _, s := range instructions {
|
||||
inst, ok := s.(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(inst["name"]).NotTo(BeEmpty())
|
||||
Expect(inst["description"]).NotTo(BeEmpty())
|
||||
Expect(inst["tags"]).NotTo(BeNil())
|
||||
Expect(inst["url"]).To(HavePrefix("/api/instructions/"))
|
||||
Expect(inst["url"]).To(Equal("/api/instructions/" + inst["name"].(string)))
|
||||
}
|
||||
})
|
||||
|
||||
It("should include known instruction names", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
instructions := resp["instructions"].([]any)
|
||||
names := make([]string, len(instructions))
|
||||
for i, s := range instructions {
|
||||
names[i] = s.(map[string]any)["name"].(string)
|
||||
}
|
||||
|
||||
Expect(names).To(ContainElements(
|
||||
"chat-inference",
|
||||
"config-management",
|
||||
"model-management",
|
||||
"monitoring",
|
||||
"agents",
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
Context("GET /api/instructions/:name", func() {
|
||||
It("should return 404 for unknown instruction", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/nonexistent", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["error"]).To(ContainSubstring("instruction not found"))
|
||||
})
|
||||
|
||||
It("should return markdown by default", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Header().Get("Content-Type")).To(ContainSubstring("text/markdown"))
|
||||
|
||||
body, err := io.ReadAll(rec.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
md := string(body)
|
||||
|
||||
Expect(md).To(HavePrefix("# chat-inference"))
|
||||
// Should contain at least one endpoint heading
|
||||
Expect(md).To(MatchRegexp(`## (GET|POST|PUT|PATCH|DELETE) /`))
|
||||
})
|
||||
|
||||
It("should include intro text for instructions that have one", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
// chat-inference has an intro about streaming
|
||||
Expect(string(body)).To(ContainSubstring("stream"))
|
||||
})
|
||||
|
||||
It("should return JSON fragment when format=json", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["name"]).To(Equal("chat-inference"))
|
||||
Expect(resp["tags"]).To(ContainElements("inference", "embeddings"))
|
||||
|
||||
fragment, ok := resp["swagger_fragment"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(fragment).To(HaveKey("paths"))
|
||||
|
||||
paths, ok := fragment["paths"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(paths).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("should include referenced definitions in JSON fragment", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
fragment := resp["swagger_fragment"].(map[string]any)
|
||||
Expect(fragment).To(HaveKey("definitions"))
|
||||
|
||||
defs, ok := fragment["definitions"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(defs).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("should only include paths matching the instruction tags in JSON fragment", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/config-management?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
fragment := resp["swagger_fragment"].(map[string]any)
|
||||
paths := fragment["paths"].(map[string]any)
|
||||
Expect(paths).NotTo(BeEmpty())
|
||||
|
||||
// Every operation in every path should have the "config" tag
|
||||
for _, methods := range paths {
|
||||
methodMap := methods.(map[string]any)
|
||||
for _, opRaw := range methodMap {
|
||||
op := opRaw.(map[string]any)
|
||||
tags, _ := op["tags"].([]any)
|
||||
tagStrs := make([]string, len(tags))
|
||||
for i, t := range tags {
|
||||
tagStrs[i] = t.(string)
|
||||
}
|
||||
Expect(tagStrs).To(ContainElement("config"))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("should produce stable output across calls", func() {
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec1, req1)
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec2, req2)
|
||||
|
||||
body1, _ := io.ReadAll(rec1.Body)
|
||||
body2, _ := io.ReadAll(rec2.Body)
|
||||
Expect(string(body1)).To(Equal(string(body2)))
|
||||
})
|
||||
|
||||
It("should return markdown for every defined instruction", func() {
|
||||
// First get the list
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
listRec := httptest.NewRecorder()
|
||||
app.ServeHTTP(listRec, listReq)
|
||||
|
||||
var listResp map[string]any
|
||||
Expect(json.Unmarshal(listRec.Body.Bytes(), &listResp)).To(Succeed())
|
||||
|
||||
instructions := listResp["instructions"].([]any)
|
||||
for _, s := range instructions {
|
||||
name := s.(map[string]any)["name"].(string)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/"+name, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"instruction %q should return 200", name)
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
Expect(strings.TrimSpace(string(body))).NotTo(BeEmpty(),
|
||||
"instruction %q should return non-empty markdown", name)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -15,28 +15,37 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// UpgradeInfoProvider is an interface for querying cached backend upgrade information.
|
||||
type UpgradeInfoProvider interface {
|
||||
GetAvailableUpgrades() map[string]gallery.UpgradeInfo
|
||||
TriggerCheck()
|
||||
}
|
||||
|
||||
type BackendEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendSystemPath string
|
||||
backendApplier *galleryop.GalleryService
|
||||
upgradeChecker UpgradeInfoProvider
|
||||
}
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService {
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService, upgradeChecker UpgradeInfoProvider) BackendEndpointService {
|
||||
return BackendEndpointService{
|
||||
galleries: galleries,
|
||||
backendPath: systemState.Backend.BackendsPath,
|
||||
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
||||
backendApplier: backendApplier,
|
||||
upgradeChecker: upgradeChecker,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Tags backends
|
||||
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||
// @Router /backends/jobs/{uuid} [get]
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -51,6 +60,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
|
||||
// GetAllStatusEndpoint returns the status and progress of all jobs
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||
// @Router /backends/jobs [get]
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -61,6 +71,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||
// @Summary Install backends to LocalAI.
|
||||
// @Tags backends
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/apply [post]
|
||||
@@ -88,6 +99,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
||||
|
||||
// DeleteBackendEndpoint deletes a backend from a LocalAI instance
|
||||
// @Summary delete backends from LocalAI.
|
||||
// @Tags backends
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/delete/{name} [post]
|
||||
@@ -112,6 +124,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ListBackendsEndpoint lists the available backends configured in LocalAI
|
||||
// @Summary List all Backends
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||
@@ -126,6 +139,7 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ListBackendGalleriesEndpoint lists the available backend galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /backends/galleries [get]
|
||||
// NOTE: Unlike the endpoints above, this only lists the backend galleries that have been loaded, not their contents.
|
||||
@@ -140,8 +154,65 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpgradesEndpoint returns the cached backend upgrade information
|
||||
// @Summary Get available backend upgrades
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||
// @Router /backends/upgrades [get]
|
||||
func (mgs *BackendEndpointService) GetUpgradesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if mgs.upgradeChecker == nil {
|
||||
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||
}
|
||||
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUpgradesEndpoint forces an immediate upgrade check
|
||||
// @Summary Force backend upgrade check
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||
// @Router /backends/upgrades/check [post]
|
||||
func (mgs *BackendEndpointService) CheckUpgradesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if mgs.upgradeChecker == nil {
|
||||
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||
}
|
||||
mgs.upgradeChecker.TriggerCheck()
|
||||
// Return current cached results (the triggered check runs async)
|
||||
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||
}
|
||||
}
|
||||
|
||||
// UpgradeBackendEndpoint triggers an upgrade for a specific backend
|
||||
// @Summary Upgrade a backend
|
||||
// @Tags backends
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/upgrade/{name} [post]
|
||||
func (mgs *BackendEndpointService) UpgradeBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: backendName,
|
||||
Galleries: mgs.galleries,
|
||||
Upgrade: true,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListAvailableBackendsEndpoint lists the available backends in the galleries configured in LocalAI
|
||||
// @Summary List all available Backends
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
|
||||
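The `UpgradeInfoProvider` interface separates reading cached results from triggering a refresh, and `CheckUpgradesEndpoint` returns the cache immediately while the check runs asynchronously. The concrete checker is not part of this hunk, so the following is only a sketch of one way such a provider could behave. The `cachedChecker` type, its `store` method, and the `upgradeInfo` fields are hypothetical and stand in for `gallery.UpgradeInfo`.

```go
package main

import (
	"fmt"
	"sync"
)

// upgradeInfo is a hypothetical stand-in for gallery.UpgradeInfo.
type upgradeInfo struct {
	Installed string
	Available string
}

// cachedChecker (hypothetical) keeps the latest results behind a lock and
// accepts refresh requests on a one-slot channel.
type cachedChecker struct {
	mu      sync.RWMutex
	cache   map[string]upgradeInfo
	trigger chan struct{}
}

func newCachedChecker() *cachedChecker {
	return &cachedChecker{
		cache:   map[string]upgradeInfo{},
		trigger: make(chan struct{}, 1),
	}
}

// GetAvailableUpgrades returns a copy of the cache so callers cannot
// mutate shared state.
func (c *cachedChecker) GetAvailableUpgrades() map[string]upgradeInfo {
	c.mu.RLock()
	defer c.mu.RUnlock()
	out := make(map[string]upgradeInfo, len(c.cache))
	for name, info := range c.cache {
		out[name] = info
	}
	return out
}

// TriggerCheck requests a refresh without blocking; requests made while
// one is already pending are coalesced.
func (c *cachedChecker) TriggerCheck() {
	select {
	case c.trigger <- struct{}{}:
	default:
	}
}

// store replaces the cache; a background worker would call it after
// draining the trigger channel and querying the galleries.
func (c *cachedChecker) store(results map[string]upgradeInfo) {
	c.mu.Lock()
	c.cache = results
	c.mu.Unlock()
}

func main() {
	c := newCachedChecker()
	c.TriggerCheck()
	c.TriggerCheck() // coalesced with the pending request
	fmt.Println(len(c.trigger), len(c.GetAvailableUpgrades()))

	<-c.trigger // a worker would pick up the request here
	c.store(map[string]upgradeInfo{"demo-backend": {Installed: "1", Available: "2"}})
	fmt.Println(c.GetAvailableUpgrades())
}
```

With this shape a handler can call `TriggerCheck` and then `GetAvailableUpgrades` in the same request without waiting on network access, at the cost of returning results that may be one check behind.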
179
core/http/endpoints/localai/backend_logs.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var backendLogsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // no origin header = same-origin or non-browser
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == r.Host
|
||||
},
|
||||
}
|
||||
|
||||
// backendLogsConn wraps a websocket connection with a mutex for safe concurrent writes
|
||||
type backendLogsConn struct {
|
||||
*websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writeJSON(v any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal error: %w", err)
|
||||
}
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writePing() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
return c.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
// ListBackendLogsEndpoint returns model IDs that have log buffers
|
||||
// @Summary List models with backend logs
|
||||
// @Description Returns a sorted list of model IDs that have captured backend process output
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Success 200 {array} string "Model IDs with logs"
|
||||
// @Router /api/backend-logs [get]
|
||||
func ListBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, ml.BackendLogs().ListModels())
|
||||
}
|
||||
}
|
||||
|
||||
// GetBackendLogsEndpoint returns log lines for a specific model
|
||||
// @Summary Get backend logs for a model
|
||||
// @Description Returns all captured log lines (stdout/stderr) for the specified model's backend process
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Param modelId path string true "Model ID"
|
||||
// @Success 200 {array} model.BackendLogLine "Log lines"
|
||||
// @Router /api/backend-logs/{modelId} [get]
|
||||
func GetBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelID := c.Param("modelId")
|
||||
return c.JSON(http.StatusOK, ml.BackendLogs().GetLines(modelID))
|
||||
}
|
||||
}
|
||||
|
||||
// ClearBackendLogsEndpoint clears log lines for a specific model
|
||||
// @Summary Clear backend logs for a model
|
||||
// @Description Removes all captured log lines for the specified model's backend process
|
||||
// @Tags monitoring
|
||||
// @Param modelId path string true "Model ID"
|
||||
// @Success 204 "Logs cleared"
|
||||
// @Router /api/backend-logs/{modelId}/clear [post]
|
||||
func ClearBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ml.BackendLogs().Clear(c.Param("modelId"))
|
||||
return c.NoContent(204)
|
||||
}
|
||||
}

// BackendLogsWebSocketEndpoint streams backend logs in real-time over WebSocket
// @Summary Stream backend logs via WebSocket
// @Description Opens a WebSocket connection for real-time backend log streaming. Sends an initial batch of existing lines (type "initial"), then streams new lines as they appear (type "line"). Supports ping/pong keepalive.
// @Tags monitoring
// @Param modelId path string true "Model ID"
// @Router /ws/backend-logs/{modelId} [get]
func BackendLogsWebSocketEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
	return func(c echo.Context) error {
		modelID := c.Param("modelId")

		ws, err := backendLogsUpgrader.Upgrade(c.Response(), c.Request(), nil)
		if err != nil {
			return err
		}
		defer ws.Close()

		ws.SetReadLimit(4096)

		// Set up ping/pong for keepalive
		ws.SetReadDeadline(time.Now().Add(90 * time.Second))
		ws.SetPongHandler(func(string) error {
			ws.SetReadDeadline(time.Now().Add(90 * time.Second))
			return nil
		})

		conn := &backendLogsConn{Conn: ws}

		// Send existing lines as initial batch
		existingLines := ml.BackendLogs().GetLines(modelID)
		initialMsg := map[string]any{
			"type":  "initial",
			"lines": existingLines,
		}
		if err := conn.writeJSON(initialMsg); err != nil {
			xlog.Debug("WebSocket backend-logs initial write failed", "error", err)
			return nil
		}

		// Subscribe to new lines
		lineCh, unsubscribe := ml.BackendLogs().Subscribe(modelID)
		defer unsubscribe()

		// Handle close from client side
		closeCh := make(chan struct{})
		go func() {
			for {
				_, _, err := ws.ReadMessage()
				if err != nil {
					close(closeCh)
					return
				}
			}
		}()

		// Ping ticker for keepalive
		pingTicker := time.NewTicker(30 * time.Second)
		defer pingTicker.Stop()

		// Forward new lines to WebSocket
		for {
			select {
			case line, ok := <-lineCh:
				if !ok {
					return nil
				}
				lineMsg := map[string]any{
					"type": "line",
					"line": line,
				}
				if err := conn.writeJSON(lineMsg); err != nil {
					xlog.Debug("WebSocket backend-logs write error", "error", err)
					return nil
				}
			case <-pingTicker.C:
				if err := conn.writePing(); err != nil {
					return nil
				}
			case <-closeCh:
				return nil
			}
		}
	}
}
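The forwarding loop in the handler above multiplexes three events: a new log line, a keepalive tick, and a client disconnect. The sketch below reproduces that shape with plain channels so it can run without a socket. `forward`, `run`, and the `emit` callback are illustrative names, not part of the LocalAI code, and `emit` stands in for the websocket writes.

```go
package main

import (
	"fmt"
	"time"
)

// forward mirrors the handler's select loop. It returns a short reason string
// describing why the loop stopped.
func forward(lines <-chan string, closed <-chan struct{}, pingEvery time.Duration, emit func(string) error) string {
	ticker := time.NewTicker(pingEvery)
	defer ticker.Stop()
	for {
		select {
		case line, ok := <-lines:
			if !ok {
				// The subscription channel was closed by the producer.
				return "source closed"
			}
			if err := emit("line:" + line); err != nil {
				return "write failed"
			}
		case <-ticker.C:
			if err := emit("ping"); err != nil {
				return "write failed"
			}
		case <-closed:
			// The reader goroutine observed a read error or close frame.
			return "client closed"
		}
	}
}

// run pushes two buffered lines through forward and reports what was emitted.
func run() ([]string, string) {
	lines := make(chan string, 2)
	lines <- "a"
	lines <- "b"
	close(lines)
	var out []string
	reason := forward(lines, make(chan struct{}), time.Hour, func(s string) error {
		out = append(out, s)
		return nil
	})
	return out, reason
}

func main() {
	out, reason := run()
	fmt.Println(out, reason)
}
```

In the handler, the dedicated reader goroutine matters for more than close detection: gorilla/websocket only invokes the pong handler while a read is in progress, so the read deadline is extended only because `ReadMessage` keeps being called.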
196
core/http/endpoints/localai/backend_logs_test.go
Normal file
@@ -0,0 +1,196 @@
package localai_test

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	. "github.com/mudler/LocalAI/core/http/endpoints/localai"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("Backend Logs Endpoints", func() {
	var (
		app         *echo.Echo
		tempDir     string
		modelLoader *model.ModelLoader
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "backend-logs-test-*")
		Expect(err).NotTo(HaveOccurred())

		modelsPath := filepath.Join(tempDir, "models")
		Expect(os.MkdirAll(modelsPath, 0750)).To(Succeed())

		systemState, err := system.GetSystemState(
			system.WithModelPath(modelsPath),
		)
		Expect(err).NotTo(HaveOccurred())

		modelLoader = model.NewModelLoader(systemState)

		app = echo.New()
		app.GET("/api/backend-logs", ListBackendLogsEndpoint(modelLoader))
		app.GET("/api/backend-logs/:modelId", GetBackendLogsEndpoint(modelLoader))
		app.POST("/api/backend-logs/:modelId/clear", ClearBackendLogsEndpoint(modelLoader))
		app.GET("/ws/backend-logs/:modelId", BackendLogsWebSocketEndpoint(modelLoader))
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Context("REST endpoints", func() {
		It("should return empty list of models with logs", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var models []string
			Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
			Expect(models).To(BeEmpty())
		})

		It("should list models that have logs", func() {
			modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")

			req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var models []string
			Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
			Expect(models).To(ContainElement("my-model"))
		})

		It("should return log lines for a model", func() {
			modelLoader.BackendLogs().AppendLine("my-model", "stdout", "line one")
			modelLoader.BackendLogs().AppendLine("my-model", "stderr", "line two")

			req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var lines []model.BackendLogLine
			Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
			Expect(lines).To(HaveLen(2))
			Expect(lines[0].Text).To(Equal("line one"))
			Expect(lines[0].Stream).To(Equal("stdout"))
			Expect(lines[1].Text).To(Equal("line two"))
			Expect(lines[1].Stream).To(Equal("stderr"))
		})

		It("should return empty log lines for unknown model", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/unknown-model", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))
		})

		It("should clear logs for a model", func() {
			modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")

			req := httptest.NewRequest(http.MethodPost, "/api/backend-logs/my-model/clear", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusNoContent))

			// Verify logs are cleared
			req = httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
			rec = httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			var lines []model.BackendLogLine
			Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
			Expect(lines).To(BeEmpty())
		})
	})

	Context("WebSocket endpoint", func() {
		It("should send initial lines and stream new lines", func() {
			// Seed some existing lines before connecting
			modelLoader.BackendLogs().AppendLine("ws-model", "stdout", "existing line")

			// Start a real HTTP server for WebSocket
			srv := httptest.NewServer(app)
			defer srv.Close()

			// Dial the WebSocket
			wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/ws-model"
			dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
			conn, _, err := dialer.Dial(wsURL, nil)
			Expect(err).NotTo(HaveOccurred())
			defer conn.Close()

			// Read the initial message
			var initialMsg map[string]any
			err = conn.ReadJSON(&initialMsg)
			Expect(err).NotTo(HaveOccurred())
			Expect(initialMsg["type"]).To(Equal("initial"))

			initialLines, ok := initialMsg["lines"].([]any)
			Expect(ok).To(BeTrue())
			Expect(initialLines).To(HaveLen(1))

			firstLine := initialLines[0].(map[string]any)
			Expect(firstLine["text"]).To(Equal("existing line"))

			// Now append a new line and verify it streams through
			modelLoader.BackendLogs().AppendLine("ws-model", "stderr", "streamed line")

			var lineMsg map[string]any
			conn.SetReadDeadline(time.Now().Add(2 * time.Second))
			err = conn.ReadJSON(&lineMsg)
			Expect(err).NotTo(HaveOccurred())
			Expect(lineMsg["type"]).To(Equal("line"))

			lineData, ok := lineMsg["line"].(map[string]any)
			Expect(ok).To(BeTrue())
			Expect(lineData["text"]).To(Equal("streamed line"))
			Expect(lineData["stream"]).To(Equal("stderr"))
		})

		It("should handle connection close gracefully", func() {
			srv := httptest.NewServer(app)
			defer srv.Close()

			wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/close-model"
			dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
			conn, _, err := dialer.Dial(wsURL, nil)
			Expect(err).NotTo(HaveOccurred())

			// Read initial message
			var initialMsg map[string]any
			err = conn.ReadJSON(&initialMsg)
			Expect(err).NotTo(HaveOccurred())
			Expect(initialMsg["type"]).To(Equal("initial"))

			// Close the connection from client side
			conn.Close()

			// Give the server goroutine time to detect the close
			time.Sleep(50 * time.Millisecond)

			// No panic or hang; the test passing is the assertion
		})
	})
})
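The WebSocket tests above derive the dial address from the `httptest` server URL by swapping the scheme prefix. A small sketch of that trick, with `wsURL` as a hypothetical helper name, shows that it also maps `https` to `wss`:

```go
package main

import (
	"fmt"
	"strings"
)

// wsURL converts an http(s) base URL into the matching ws(s) URL and appends
// the given path. Trimming only "http" leaves the optional "s" in place.
func wsURL(base, path string) string {
	return "ws" + strings.TrimPrefix(base, "http") + path
}

func main() {
	fmt.Println(wsURL("http://127.0.0.1:8080", "/ws/backend-logs/ws-model"))
	// prints ws://127.0.0.1:8080/ws/backend-logs/ws-model
	fmt.Println(wsURL("https://example.test", "/ws/backend-logs/ws-model"))
	// prints wss://example.test/ws/backend-logs/ws-model
}
```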
@@ -8,6 +8,7 @@ import (

// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Tags monitoring
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
@@ -29,7 +30,8 @@ func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFu
}

// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Summary Backend shutdown endpoint
// @Tags monitoring
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {

244
core/http/endpoints/localai/config_meta.go
Normal file
@@ -0,0 +1,244 @@
package localai

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"reflect"
	"sort"
	"strings"

	"dario.cat/mergo"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/config/meta"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/services/galleryop"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
	"gopkg.in/yaml.v3"
)

// ConfigMetadataEndpoint returns field metadata for config fields.
// Without ?section, returns just the section index (lightweight).
// With ?section=<id>, returns fields for that section only.
// With ?section=all, returns all fields grouped by section.
// @Summary List model configuration field metadata
// @Description Returns config field metadata. Use ?section=<id> to filter by section, or omit for a section index.
// @Tags config
// @Produce json
// @Param section query string false "Section ID to filter (e.g. 'general', 'llm', 'parameters') or 'all' for everything"
// @Success 200 {object} map[string]any "Section index or filtered field metadata"
// @Router /api/models/config-metadata [get]
func ConfigMetadataEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		sectionParam := c.QueryParam("section")

		// No section param: return lightweight section index.
		if sectionParam == "" {
			sections := meta.DefaultSections()
			type sectionInfo struct {
				ID    string `json:"id"`
				Label string `json:"label"`
				URL   string `json:"url"`
			}
			index := make([]sectionInfo, len(sections))
			for i, s := range sections {
				index[i] = sectionInfo{
					ID:    s.ID,
					Label: s.Label,
					URL:   "/api/models/config-metadata?section=" + s.ID,
				}
			}
			return c.JSON(http.StatusOK, map[string]any{
				"hint":     "Fetch a section URL to see its fields. Use ?section=all for everything.",
				"sections": index,
			})
		}

		md := meta.BuildConfigMetadata(reflect.TypeOf(config.ModelConfig{}))

		// section=all: return everything.
		if sectionParam == "all" {
			return c.JSON(http.StatusOK, md)
		}

		// Filter to requested section.
		var filtered []meta.FieldMeta
		for _, f := range md.Fields {
			if f.Section == sectionParam {
				filtered = append(filtered, f)
			}
		}
		if len(filtered) == 0 {
			return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown section: " + sectionParam})
		}
		return c.JSON(http.StatusOK, filtered)
	}
}
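The endpoint above has three modes selected by the `section` query parameter: an index when it is absent, everything for `all`, and a filtered list otherwise. A rough standard-library sketch of that dispatch is shown below. The `field` type, the sample data, and the `/meta` path are assumptions made for illustration; the real handler uses Echo and `meta.FieldMeta`, and its index entries carry `id`, `label`, and `url` rather than bare strings.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// field is a hypothetical, trimmed-down stand-in for meta.FieldMeta.
type field struct {
	Path    string `json:"path"`
	Section string `json:"section"`
}

// allFields is made-up sample data for the sketch.
var allFields = []field{
	{Path: "name", Section: "general"},
	{Path: "backend", Section: "general"},
	{Path: "parameters.temperature", Section: "parameters"},
}

// writeJSON encodes v as the response body with the given status code.
func writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_ = json.NewEncoder(w).Encode(v)
}

// metadataHandler reproduces the three modes: index, all, single section.
func metadataHandler(w http.ResponseWriter, r *http.Request) {
	section := r.URL.Query().Get("section")
	switch section {
	case "":
		writeJSON(w, http.StatusOK, map[string]any{"sections": []string{"general", "parameters"}})
	case "all":
		writeJSON(w, http.StatusOK, map[string]any{"fields": allFields})
	default:
		var filtered []field
		for _, f := range allFields {
			if f.Section == section {
				filtered = append(filtered, f)
			}
		}
		if len(filtered) == 0 {
			writeJSON(w, http.StatusNotFound, map[string]any{"error": "unknown section: " + section})
			return
		}
		writeJSON(w, http.StatusOK, filtered)
	}
}

// status issues an in-process GET and returns the HTTP status code.
func status(target string) int {
	rec := httptest.NewRecorder()
	metadataHandler(rec, httptest.NewRequest(http.MethodGet, target, nil))
	return rec.Code
}

func main() {
	fmt.Println(status("/meta"), status("/meta?section=general"), status("/meta?section=nope"))
	// prints 200 200 404
}
```

As in the original handler, a section name is treated as unknown whenever no field carries it, so a declared section with zero fields would also produce a 404.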

// AutocompleteEndpoint handles dynamic autocomplete lookups for config fields.
// Static option lists (quantizations, cache types, diffusers pipelines/schedulers)
// are embedded directly in the field metadata Options; only truly dynamic values
// that require runtime lookup are served here.
// @Summary Get dynamic autocomplete values for a config field
// @Description Returns runtime-resolved values for dynamic providers (backends, models)
// @Tags config
// @Produce json
// @Param provider path string true "Provider name (backends, models, models:chat, models:tts, models:transcript, models:vad)"
// @Success 200 {object} map[string]any "values array"
// @Router /api/models/config-metadata/autocomplete/{provider} [get]
func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		provider := c.Param("provider")
		var values []string

		switch {
		case provider == meta.ProviderBackends:
			installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
			if err == nil {
				for name := range installedBackends {
					values = append(values, name)
				}
			}
			sort.Strings(values)

		case provider == meta.ProviderModels:
			modelConfigs := cl.GetAllModelsConfigs()
			for _, cfg := range modelConfigs {
				values = append(values, cfg.Name)
			}
			modelsWithoutConfig, _ := galleryop.ListModels(cl, ml, config.NoFilterFn, galleryop.LOOSE_ONLY)
			values = append(values, modelsWithoutConfig...)
			sort.Strings(values)

		case strings.HasPrefix(provider, "models:"):
			capability := strings.TrimPrefix(provider, "models:")
			var filterFn config.ModelConfigFilterFn
			switch capability {
			case "chat":
				filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT)
			case "tts":
				filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS)
			case "vad":
				filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD)
			case "transcript":
				filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
			default:
				filterFn = config.NoFilterFn
			}
			filteredConfigs := cl.GetModelConfigsByFilter(filterFn)
			for _, cfg := range filteredConfigs {
				values = append(values, cfg.Name)
			}
			sort.Strings(values)

		default:
			return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown provider: " + provider})
		}

		return c.JSON(http.StatusOK, map[string]any{"values": values})
	}
}
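The provider string accepted above is either a bare family name or `models:` followed by a capability. The sketch below isolates that parsing step. `parseProvider` is a hypothetical helper, and the literal values `"backends"` and `"models"` are assumed from the route documentation and tests rather than read from the `meta` constants.

```go
package main

import (
	"fmt"
	"strings"
)

// parseProvider splits a provider string into its family and optional
// capability. ok is false for names outside the accepted families.
func parseProvider(provider string) (kind, capability string, ok bool) {
	switch {
	case provider == "backends":
		return "backends", "", true
	case provider == "models":
		return "models", "", true
	case strings.HasPrefix(provider, "models:"):
		return "models", strings.TrimPrefix(provider, "models:"), true
	default:
		return "", "", false
	}
}

func main() {
	for _, p := range []string{"backends", "models", "models:tts", "unknown"} {
		kind, capability, ok := parseProvider(p)
		fmt.Printf("%q -> kind=%q capability=%q ok=%v\n", p, kind, capability, ok)
	}
}
```

Note that in the handler an unrecognized capability such as `models:foo` does not produce a 404: it falls through to `config.NoFilterFn` and returns every configured model.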

// PatchConfigEndpoint handles PATCH requests to partially update a model config
// using nested JSON merge.
// @Summary Partially update a model configuration
// @Description Deep-merges the JSON patch body into the existing model config
// @Tags config
// @Accept json
// @Produce json
// @Param name path string true "Model name"
// @Success 200 {object} map[string]any "success message"
// @Router /api/models/config-json/{name} [patch]
func PatchConfigEndpoint(cl *config.ModelConfigLoader, _ *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		modelName := c.Param("name")
		if decoded, err := url.PathUnescape(modelName); err == nil {
			modelName = decoded
		}
		if modelName == "" {
			return c.JSON(http.StatusBadRequest, map[string]any{"error": "model name is required"})
		}

		modelConfig, exists := cl.GetModelConfig(modelName)
		if !exists {
			return c.JSON(http.StatusNotFound, map[string]any{"error": "model configuration not found"})
		}

		patchBody, err := io.ReadAll(c.Request().Body)
		if err != nil || len(patchBody) == 0 {
			return c.JSON(http.StatusBadRequest, map[string]any{"error": "request body is empty or unreadable"})
		}

		var patchMap map[string]any
		if err := json.Unmarshal(patchBody, &patchMap); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid JSON: " + err.Error()})
		}

		// Read the raw YAML from disk rather than serializing the in-memory config.
		// The in-memory config has SetDefaults() applied, which would persist
		// runtime-only defaults (top_p, temperature, mirostat, etc.) to the file.
		configPath := modelConfig.GetModelConfigFile()
		if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
			return c.JSON(http.StatusForbidden, map[string]any{"error": "config path not trusted: " + err.Error()})
		}

		diskYAML, err := os.ReadFile(configPath)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to read config file: " + err.Error()})
		}

		var existingMap map[string]any
		if err := yaml.Unmarshal(diskYAML, &existingMap); err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to parse existing config: " + err.Error()})
		}
		if existingMap == nil {
			existingMap = map[string]any{}
		}

		if err := mergo.Merge(&existingMap, patchMap, mergo.WithOverride); err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to merge configs: " + err.Error()})
		}

		// Marshal once and reuse for both validation and writing
		yamlData, err := yaml.Marshal(existingMap)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal YAML"})
		}

		var updatedConfig config.ModelConfig
		if err := yaml.Unmarshal(yamlData, &updatedConfig); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]any{"error": "merged config is invalid: " + err.Error()})
		}

		if valid, err := updatedConfig.Validate(); !valid {
			errMsg := "validation failed"
			if err != nil {
				errMsg = err.Error()
			}
			return c.JSON(http.StatusBadRequest, map[string]any{"error": errMsg})
		}

		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to write config file"})
		}

		if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to reload configs: " + err.Error()})
		}

		if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
			xlog.Warn("Failed to preload after PATCH", "error", err)
		}

		return c.JSON(http.StatusOK, map[string]any{
			"success": true,
			"message": fmt.Sprintf("Model '%s' updated successfully", modelName),
		})
	}
}
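The PATCH handler relies on a deep merge so that a patch touching one nested key leaves its siblings intact. The sketch below illustrates that behaviour with a hand-written recursive merge over `map[string]any`. It is an approximation for explanation only, not a reimplementation of `mergo.Merge`, and `deepMerge` and `example` are hypothetical names.

```go
package main

import "fmt"

// deepMerge copies src into dst. When both sides hold a nested map under the
// same key, the maps are merged key by key; otherwise the value from src
// replaces the one in dst.
func deepMerge(dst, src map[string]any) {
	for k, v := range src {
		srcMap, srcIsMap := v.(map[string]any)
		dstMap, dstIsMap := dst[k].(map[string]any)
		if srcIsMap && dstIsMap {
			deepMerge(dstMap, srcMap)
			continue
		}
		dst[k] = v
	}
}

// example applies a patch that changes only pipeline.tts, similar to the
// realtime pipeline case exercised by the test suite.
func example() map[string]any {
	existing := map[string]any{
		"name":     "gpt-realtime",
		"pipeline": map[string]any{"vad": "silero-vad", "tts": "piper"},
	}
	patch := map[string]any{"pipeline": map[string]any{"tts": "vibevoice"}}
	deepMerge(existing, patch)
	return existing
}

func main() {
	fmt.Println(example())
}
```

Merging into the map parsed from disk, rather than into the in-memory struct, is what keeps runtime defaults out of the written file: only keys present in the original YAML or in the patch can appear in the result.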
292
core/http/endpoints/localai/config_meta_test.go
Normal file
@@ -0,0 +1,292 @@
package localai_test

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	. "github.com/mudler/LocalAI/core/http/endpoints/localai"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("Config Metadata Endpoints", func() {
	var (
		app          *echo.Echo
		tempDir      string
		configLoader *config.ModelConfigLoader
		modelLoader  *model.ModelLoader
		appConfig    *config.ApplicationConfig
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "config-meta-test-*")
		Expect(err).NotTo(HaveOccurred())

		systemState, err := system.GetSystemState(
			system.WithModelPath(tempDir),
		)
		Expect(err).NotTo(HaveOccurred())

		appConfig = config.NewApplicationConfig(
			config.WithSystemState(systemState),
		)
		configLoader = config.NewModelConfigLoader(tempDir)
		modelLoader = model.NewModelLoader(systemState)

		app = echo.New()
		app.GET("/api/models/config-metadata", ConfigMetadataEndpoint())
		app.GET("/api/models/config-metadata/autocomplete/:provider", AutocompleteEndpoint(configLoader, modelLoader, appConfig))
		app.PATCH("/api/models/config-json/:name", PatchConfigEndpoint(configLoader, modelLoader, appConfig))
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Context("GET /api/models/config-metadata", func() {
		It("should return section index when no section param", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var resp map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
			Expect(resp).To(HaveKey("hint"))
			Expect(resp).To(HaveKey("sections"))

			sections, ok := resp["sections"].([]any)
			Expect(ok).To(BeTrue())
			Expect(sections).NotTo(BeEmpty())

			// Verify known section IDs are present
			ids := make([]string, len(sections))
			for i, s := range sections {
				sec := s.(map[string]any)
				Expect(sec).To(HaveKey("id"))
				Expect(sec).To(HaveKey("label"))
				Expect(sec).To(HaveKey("url"))
				ids[i] = sec["id"].(string)
			}
			Expect(ids).To(ContainElements("general", "parameters"))
		})

		It("should return all fields when section=all", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=all", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var resp map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
			Expect(resp).To(HaveKey("fields"))

			fields, ok := resp["fields"].([]any)
			Expect(ok).To(BeTrue())
			Expect(len(fields)).To(BeNumerically(">=", 80))
		})

		It("should filter by section", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=general", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var fields []map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &fields)).To(Succeed())
			Expect(fields).NotTo(BeEmpty())

			for _, f := range fields {
				Expect(f["section"]).To(Equal("general"))
			}
		})

		It("should return 404 for unknown section", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=nonexistent", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusNotFound))
		})
	})

	Context("GET /api/models/config-metadata/autocomplete/:provider", func() {
		It("should return values for backends provider", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/backends", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var resp map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
			Expect(resp).To(HaveKey("values"))
		})

		It("should return model names for models provider", func() {
			// Seed a model config
			seedConfig := `name: test-model
backend: llama-cpp
`
			Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
			Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())

			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/models", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var resp map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())

			values, ok := resp["values"].([]any)
			Expect(ok).To(BeTrue())
			Expect(values).To(ContainElement("test-model"))
		})

		It("should return 404 for unknown provider", func() {
			req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/unknown", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusNotFound))
		})
	})

	Context("PATCH /api/models/config-json/:name", func() {
		It("should return 404 for nonexistent model", func() {
			body := bytes.NewBufferString(`{"backend": "bar"}`)
			req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/nonexistent", body)
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusNotFound))
		})

		It("should return 400 for empty body", func() {
			// Seed a model config
			seedConfig := `name: test-model
backend: llama-cpp
`
			Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
			Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())

			req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", nil)
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusBadRequest))
		})

		It("should return 400 for invalid JSON", func() {
			seedConfig := `name: test-model
backend: llama-cpp
`
			Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
			Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())

			body := bytes.NewBufferString(`not json`)
			req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusBadRequest))
		})

		It("should merge a field update and persist to disk", func() {
			seedConfig := `name: test-model
backend: llama-cpp
`
			configPath := filepath.Join(tempDir, "test-model.yaml")
			Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
			Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())

			body := bytes.NewBufferString(`{"backend": "vllm"}`)
			req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			var resp map[string]any
			Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
			Expect(resp["success"]).To(BeTrue())

			// Verify the reloaded config has the updated value
			updatedConfig, exists := configLoader.GetModelConfig("test-model")
			Expect(exists).To(BeTrue())
			Expect(updatedConfig.Backend).To(Equal("vllm"))

			// Verify the file on disk was updated
			data, err := os.ReadFile(configPath)
			Expect(err).NotTo(HaveOccurred())
			Expect(string(data)).To(ContainSubstring("vllm"))
		})

		It("should not persist runtime defaults (SetDefaults values) to disk", func() {
			// Create a minimal pipeline config - no sampling params
			seedConfig := `name: gpt-realtime
pipeline:
  vad: silero-vad
  transcription: whisper-base
  llm: llama3
  tts: piper
`
			configPath := filepath.Join(tempDir, "gpt-realtime.yaml")
			Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
			Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())

			// PATCH with a small change to the pipeline
			body := bytes.NewBufferString(`{"pipeline": {"tts": "vibevoice"}}`)
			req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/gpt-realtime", body)
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(http.StatusOK))

			// Read the file from disk and verify no spurious defaults leaked
			data, err := os.ReadFile(configPath)
			Expect(err).NotTo(HaveOccurred())
			fileContent := string(data)

			// The patched value should be present
			Expect(fileContent).To(ContainSubstring("vibevoice"))

			// Runtime-only defaults from SetDefaults() should NOT be in the file
			Expect(fileContent).NotTo(ContainSubstring("top_p"))
			Expect(fileContent).NotTo(ContainSubstring("top_k"))
			Expect(fileContent).NotTo(ContainSubstring("temperature"))
			Expect(fileContent).NotTo(ContainSubstring("mirostat"))
			Expect(fileContent).NotTo(ContainSubstring("mmap"))
			Expect(fileContent).NotTo(ContainSubstring("mmlock"))
			Expect(fileContent).NotTo(ContainSubstring("threads"))
			Expect(fileContent).NotTo(ContainSubstring("low_vram"))
			Expect(fileContent).NotTo(ContainSubstring("embeddings"))
			Expect(fileContent).NotTo(ContainSubstring("f16"))

			// Original fields should still be present
			Expect(fileContent).To(ContainSubstring("gpt-realtime"))
			Expect(fileContent).To(ContainSubstring("silero-vad"))
			Expect(fileContent).To(ContainSubstring("whisper-base"))
			Expect(fileContent).To(ContainSubstring("llama3"))
		})
	})
})
|
||||
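The second test above patches only `pipeline.tts` and expects the other pipeline keys to survive on disk. The handler that applies the patch is not part of this excerpt, so the following is only a minimal sketch of one way to get that behaviour, assuming a recursive merge over generic maps. The helper name `mergeMaps` is hypothetical and not taken from LocalAI.

```go
package main

import "fmt"

// mergeMaps copies every key of patch into dst. When both sides hold a
// nested map under the same key, the nested maps are merged recursively
// instead of the whole subtree being replaced.
// Hypothetical helper for illustration only.
func mergeMaps(dst, patch map[string]any) {
	for key, patchVal := range patch {
		dstChild, dstIsMap := dst[key].(map[string]any)
		patchChild, patchIsMap := patchVal.(map[string]any)
		if dstIsMap && patchIsMap {
			mergeMaps(dstChild, patchChild)
			continue
		}
		dst[key] = patchVal
	}
}

func main() {
	existing := map[string]any{
		"name": "gpt-realtime",
		"pipeline": map[string]any{
			"vad": "silero-vad",
			"tts": "piper",
		},
	}
	patch := map[string]any{
		"pipeline": map[string]any{"tts": "vibevoice"},
	}
	mergeMaps(existing, patch)
	fmt.Println(existing["pipeline"])
}
```

Merging into a map rather than a typed struct also keeps keys the struct does not know about, which matches what the assertions on the file content check for.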
@@ -1,6 +1,8 @@
 package localai
 
 import (
+	"encoding/base64"
+
 	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
@@ -13,6 +15,7 @@ import (
 
 // DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
 // @Summary Detects objects in the input image.
+// @Tags detection
 // @Param request body schema.DetectionRequest true "query params"
 // @Success 200 {object} schema.DetectionResponse "Response"
 // @Router /v1/detection [post]
@@ -36,7 +39,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
 			return err
 		}
 
-		res, err := backend.Detection(image, ml, appConfig, *cfg)
+		res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}
@@ -45,12 +48,18 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
 			Detections: make([]schema.Detection, len(res.Detections)),
 		}
 		for i, detection := range res.Detections {
+			var mask string
+			if len(detection.Mask) > 0 {
+				mask = base64.StdEncoding.EncodeToString(detection.Mask)
+			}
 			response.Detections[i] = schema.Detection{
-				X:         detection.X,
-				Y:         detection.Y,
-				Width:     detection.Width,
-				Height:    detection.Height,
-				ClassName: detection.ClassName,
+				X:          detection.X,
+				Y:          detection.Y,
+				Width:      detection.Width,
+				Height:     detection.Height,
+				ClassName:  detection.ClassName,
+				Confidence: detection.Confidence,
+				Mask:       mask,
 			}
 		}
 
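The loop in the last hunk leaves `Mask` as an empty string unless the backend returned mask bytes, and otherwise base64-encodes them for the JSON response. A standalone sketch of that step follows; `encodeMask` is an illustrative name and not a function introduced by this change.

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// encodeMask returns the standard base64 encoding of mask, or an empty
// string when no mask is present, so the response field can stay empty.
func encodeMask(mask []byte) string {
	if len(mask) == 0 {
		return ""
	}
	return base64.StdEncoding.EncodeToString(mask)
}

func main() {
	fmt.Printf("%q\n", encodeMask(nil))
	fmt.Printf("%q\n", encodeMask([]byte{0x00, 0xff, 0x10}))
}
```

A client would reverse this with `base64.StdEncoding.DecodeString` after checking that the field is non-empty.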
@@ -40,6 +40,7 @@ func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGaller
 
 // GetOpStatusEndpoint returns the job status
 // @Summary Returns the job status
+// @Tags models
 // @Success 200 {object} galleryop.OpStatus "Response"
 // @Router /models/jobs/{uuid} [get]
 func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
@@ -54,6 +55,7 @@ func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
 
 // GetAllStatusEndpoint returns all the jobs status progress
 // @Summary Returns all the jobs status progress
+// @Tags models
 // @Success 200 {object} map[string]galleryop.OpStatus "Response"
 // @Router /models/jobs [get]
 func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
@@ -64,6 +66,7 @@ func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc
 
 // ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
 // @Summary Install models to LocalAI.
+// @Tags models
 // @Param request body GalleryModel true "query params"
 // @Success 200 {object} schema.GalleryResponse "Response"
 // @Router /models/apply [post]
@@ -93,6 +96,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
 
 // DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
 // @Summary delete models to LocalAI.
+// @Tags models
 // @Param name path string true "Model name"
 // @Success 200 {object} schema.GalleryResponse "Response"
 // @Router /models/delete/{name} [post]
@@ -118,7 +122,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
 
 // ListModelFromGalleryEndpoint list the available models for installation from the active galleries
 // @Summary List installable models.
-// @Success 200 {object} []gallery.GalleryModel "Response"
+// @Tags models
+// @Success 200 {object} []gallery.Metadata "Response"
 // @Router /models/available [get]
 func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
 	return func(c echo.Context) error {
@@ -149,6 +154,7 @@ func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState
 
 // ListModelGalleriesEndpoint list the available galleries configured in LocalAI
 // @Summary List all Galleries
+// @Tags models
 // @Success 200 {object} []config.Gallery "Response"
 // @Router /models/galleries [get]
 // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
@@ -16,6 +16,7 @@ import (
 // TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
 //
 // @Summary Get TokenMetrics for Active Slot.
+// @Tags tokenize
 // @Accept json
 // @Produce audio/x-wav
 // @Success 200 {string} binary "generated audio/wav file"
@@ -119,48 +119,20 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
 			return c.JSON(http.StatusBadRequest, response)
 		}
 
-		// Check content type to determine how to parse
+		// Detect format once and reuse for both typed and map parsing
 		contentType := c.Request().Header.Get("Content-Type")
-		var modelConfig config.ModelConfig
+		trimmed := strings.TrimSpace(string(body))
+		isJSON := strings.Contains(contentType, "application/json") ||
+			(!strings.Contains(contentType, "yaml") && len(trimmed) > 0 && trimmed[0] == '{')
 
-		if strings.Contains(contentType, "application/json") {
-			// Parse JSON
+		var modelConfig config.ModelConfig
+		if isJSON {
 			if err := json.Unmarshal(body, &modelConfig); err != nil {
-				response := ModelResponse{
-					Success: false,
-					Error:   "Failed to parse JSON: " + err.Error(),
-				}
-				return c.JSON(http.StatusBadRequest, response)
-			}
-		} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
-			// Parse YAML
-			if err := yaml.Unmarshal(body, &modelConfig); err != nil {
-				response := ModelResponse{
-					Success: false,
-					Error:   "Failed to parse YAML: " + err.Error(),
-				}
-				return c.JSON(http.StatusBadRequest, response)
+				return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse JSON: " + err.Error()})
 			}
 		} else {
-			// Try to auto-detect format
-			if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
-				// Looks like JSON
-				if err := json.Unmarshal(body, &modelConfig); err != nil {
-					response := ModelResponse{
-						Success: false,
-						Error:   "Failed to parse JSON: " + err.Error(),
-					}
-					return c.JSON(http.StatusBadRequest, response)
-				}
-			} else {
-				// Assume YAML
-				if err := yaml.Unmarshal(body, &modelConfig); err != nil {
-					response := ModelResponse{
-						Success: false,
-						Error:   "Failed to parse YAML: " + err.Error(),
-					}
-					return c.JSON(http.StatusBadRequest, response)
-				}
+			if err := yaml.Unmarshal(body, &modelConfig); err != nil {
+				return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse YAML: " + err.Error()})
 			}
 		}
 
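The hunk above collapses three parsing branches into a single `isJSON` decision. As a rough illustration, the same rule can be pulled into a small pure function and exercised on its own; the name `looksLikeJSON` is made up for this sketch and does not appear in the change.

```go
package main

import (
	"fmt"
	"strings"
)

// looksLikeJSON mirrors the detection rule from the hunk: an explicit JSON
// content type wins; otherwise, unless the content type mentions YAML, a
// body whose first non-space byte is '{' is treated as JSON.
func looksLikeJSON(contentType string, body []byte) bool {
	trimmed := strings.TrimSpace(string(body))
	return strings.Contains(contentType, "application/json") ||
		(!strings.Contains(contentType, "yaml") && len(trimmed) > 0 && trimmed[0] == '{')
}

func main() {
	fmt.Println(looksLikeJSON("application/json", []byte("name: m")))
	fmt.Println(looksLikeJSON("", []byte(`  {"name": "m"}`)))
	fmt.Println(looksLikeJSON("application/x-yaml", []byte(`{"name": "m"}`)))
	fmt.Println(looksLikeJSON("", []byte("name: m")))
}
```

One consequence of this rule is that a YAML flow mapping such as `{name: m}` sent without a YAML content type takes the JSON branch, so callers sending that style would need to set the content type explicitly.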
@@ -173,10 +145,9 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
 			return c.JSON(http.StatusBadRequest, response)
 		}
 
-		// Set defaults
-		modelConfig.SetDefaults(appConfig.ToConfigLoaderOptions()...)
-
-		// Validate the configuration
+		// Validate without calling SetDefaults() — runtime defaults should not
+		// be persisted to disk. SetDefaults() is called when loading configs
+		// for inference via LoadModelConfigsFromPath().
 		if valid, _ := modelConfig.Validate(); !valid {
 			response := ModelResponse{
 				Success: false,
@@ -195,8 +166,21 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
 			return c.JSON(http.StatusBadRequest, response)
 		}
 
-		// Marshal to YAML for storage
-		yamlData, err := yaml.Marshal(&modelConfig)
+		// Write only the user-provided fields to disk by parsing the original
+		// body into a map (not the typed struct, which includes Go zero values).
+		var bodyMap map[string]any
+		if isJSON {
+			_ = json.Unmarshal(body, &bodyMap)
+		} else {
+			_ = yaml.Unmarshal(body, &bodyMap)
+		}
+
+		var yamlData []byte
+		if bodyMap != nil {
+			yamlData, err = yaml.Marshal(bodyMap)
+		} else {
+			yamlData, err = yaml.Marshal(&modelConfig)
+		}
 		if err != nil {
 			response := ModelResponse{
 				Success: false,
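The comment in the hunk above says the typed struct "includes Go zero values" when marshalled. The sketch below shows that effect in isolation. It swaps in `encoding/json` for the YAML library so it runs with the standard library alone, and `modelConfig` here is a made-up, trimmed-down stand-in, not LocalAI's `config.ModelConfig`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// modelConfig is a hypothetical stand-in for a typed config struct whose
// fields carry no omitempty tag.
type modelConfig struct {
	Name        string  `json:"name"`
	Backend     string  `json:"backend"`
	Temperature float64 `json:"temperature"`
	Threads     int     `json:"threads"`
}

// viaStruct round-trips body through the typed struct, so every field is
// written back, including ones the caller never sent.
func viaStruct(body []byte) (string, error) {
	var cfg modelConfig
	if err := json.Unmarshal(body, &cfg); err != nil {
		return "", err
	}
	out, err := json.Marshal(cfg)
	return string(out), err
}

// viaMap round-trips body through a generic map, so only the keys present
// in the input are written back.
func viaMap(body []byte) (string, error) {
	var m map[string]any
	if err := json.Unmarshal(body, &m); err != nil {
		return "", err
	}
	out, err := json.Marshal(m)
	return string(out), err
}

func main() {
	body := []byte(`{"name":"m","backend":"vllm"}`)
	typed, _ := viaStruct(body)
	generic, _ := viaMap(body)
	fmt.Println(typed)   // includes temperature and threads with zero values
	fmt.Println(generic) // contains only name and backend
}
```

The trade-off is that the map path skips the struct's type checking, which is presumably why the endpoint still parses into the typed struct first for validation and only uses the map for what gets written.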
@@ -53,6 +53,7 @@ type MCPErrorEvent struct {
 // which handles MCP tool injection and server-side execution.
 // Both streaming and non-streaming modes use standard OpenAI response format.
 // @Summary MCP chat completions with automatic tool execution
+// @Tags mcp
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/mcp/chat/completions [post]
@@ -10,7 +10,9 @@ import (
 
 // LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
 // @Summary Prometheus metrics endpoint
-// @Param request body config.Gallery true "Gallery details"
+// @Tags monitoring
+// @Produce text/plain
+// @Success 200 {string} string "Prometheus metrics"
 // @Router /metrics [get]
 func LocalAIMetricsEndpoint() echo.HandlerFunc {
 	return echo.WrapHandler(promhttp.Handler())
@@ -9,6 +9,7 @@ import (
 
 // ShowP2PNodes returns the P2P Nodes
 // @Summary Returns available P2P nodes
+// @Tags p2p
 // @Success 200 {object} []schema.P2PNodesResponse "Response"
 // @Router /api/p2p [get]
 func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
@@ -24,6 +25,7 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
 
 // ShowP2PToken returns the P2P token
 // @Summary Show the P2P token
+// @Tags p2p
 // @Success 200 {string} string "Response"
 // @Router /api/p2p/token [get]
 func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
144
core/http/endpoints/localai/pin_model.go
Normal file
@@ -0,0 +1,144 @@
+package localai
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+	"os"
+
+	"github.com/labstack/echo/v4"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/pkg/utils"
+
+	"gopkg.in/yaml.v3"
+)
+
+// TogglePinnedModelEndpoint handles pinning or unpinning a model.
+// Pinned models are excluded from idle unloading, LRU eviction, and memory-pressure eviction.
+//
+// @Summary Toggle model pinned status
+// @Description Pin or unpin a model. Pinned models stay loaded and are excluded from automatic eviction.
+// @Tags config
+// @Param name path string true "Model name"
+// @Param action path string true "Action: 'pin' or 'unpin'"
+// @Success 200 {object} ModelResponse
+// @Failure 400 {object} ModelResponse
+// @Failure 404 {object} ModelResponse
+// @Failure 500 {object} ModelResponse
+// @Router /api/models/toggle-pinned/{name}/{action} [put]
+func TogglePinnedModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, syncPinnedFn func()) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		modelName := c.Param("name")
+		if decoded, err := url.PathUnescape(modelName); err == nil {
+			modelName = decoded
+		}
+		if modelName == "" {
+			return c.JSON(http.StatusBadRequest, ModelResponse{
+				Success: false,
+				Error:   "Model name is required",
+			})
+		}
+
+		action := c.Param("action")
+		if action != "pin" && action != "unpin" {
+			return c.JSON(http.StatusBadRequest, ModelResponse{
+				Success: false,
+				Error:   "Action must be 'pin' or 'unpin'",
+			})
+		}
+
+		// Get existing model config
+		modelConfig, exists := cl.GetModelConfig(modelName)
+		if !exists {
+			return c.JSON(http.StatusNotFound, ModelResponse{
+				Success: false,
+				Error:   "Model configuration not found",
+			})
+		}
+
+		// Get the config file path
+		configPath := modelConfig.GetModelConfigFile()
+		if configPath == "" {
+			return c.JSON(http.StatusNotFound, ModelResponse{
+				Success: false,
+				Error:   "Model configuration file not found",
+			})
+		}
+
+		// Verify the path is trusted
+		if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
+			return c.JSON(http.StatusForbidden, ModelResponse{
+				Success: false,
+				Error:   "Model configuration not trusted: " + err.Error(),
+			})
+		}
+
+		// Read the existing config file
+		configData, err := os.ReadFile(configPath)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to read configuration file: " + err.Error(),
+			})
+		}
+
+		// Parse the YAML config as a generic map to preserve all fields
+		var configMap map[string]interface{}
+		if err := yaml.Unmarshal(configData, &configMap); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to parse configuration file: " + err.Error(),
+			})
+		}
+
+		// Update the pinned field
+		pinned := action == "pin"
+		if pinned {
+			configMap["pinned"] = true
+		} else {
+			// Remove the pinned key entirely when unpinning (clean YAML)
+			delete(configMap, "pinned")
+		}
+
+		// Marshal back to YAML
+		updatedData, err := yaml.Marshal(configMap)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to serialize configuration: " + err.Error(),
+			})
+		}
+
+		// Write updated config back to file
+		if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to write configuration file: " + err.Error(),
+			})
+		}
+
+		// Reload model configurations from disk
+		if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to reload configurations: " + err.Error(),
+			})
+		}
+
+		// Sync pinned models to the watchdog
+		if syncPinnedFn != nil {
+			syncPinnedFn()
+		}
+
+		msg := fmt.Sprintf("Model '%s' has been %sned successfully.", modelName, action)
+		if pinned {
+			msg += " The model will be excluded from automatic eviction."
+		}
+
+		return c.JSON(http.StatusOK, ModelResponse{
+			Success:  true,
+			Message:  msg,
+			Filename: configPath,
+		})
+	}
+}
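The handler above edits the config as a generic map and deletes the `pinned` key on unpin rather than writing `pinned: false`. A minimal sketch of that set-or-delete step, independent of YAML and of the file system, could look like this; `setFlag` is a hypothetical helper name, not something defined in the new file.

```go
package main

import "fmt"

// setFlag sets key to true when on is set and removes the key otherwise,
// so a cleared flag leaves no `key: false` entry behind when the map is
// serialized again. All other keys are left untouched.
func setFlag(cfg map[string]any, key string, on bool) {
	if on {
		cfg[key] = true
		return
	}
	delete(cfg, key)
}

func main() {
	cfg := map[string]any{"name": "my-model", "backend": "llama-cpp"}

	setFlag(cfg, "pinned", true)
	fmt.Println(cfg)

	setFlag(cfg, "pinned", false)
	fmt.Println(cfg)
}
```

The same shape applies to any boolean flag handled this way, since `delete` on a missing key is a no-op in Go.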
@@ -9,6 +9,7 @@ import (
 
 // SystemInformations returns the system informations
 // @Summary Show the LocalAI instance information
+// @Tags monitoring
 // @Success 200 {object} schema.SystemInformationResponse "Response"
 // @Router /system [get]
 func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
148
core/http/endpoints/localai/toggle_model.go
Normal file
@@ -0,0 +1,148 @@
+package localai
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+	"os"
+
+	"github.com/labstack/echo/v4"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/utils"
+
+	"gopkg.in/yaml.v3"
+)
+
+// ToggleModelEndpoint handles enabling or disabling a model from being loaded on demand.
+// When disabled, the model remains in the collection but will not be loaded when requested.
+//
+// @Summary Toggle model enabled/disabled status
+// @Description Enable or disable a model from being loaded on demand. Disabled models remain installed but cannot be loaded.
+// @Tags config
+// @Param name path string true "Model name"
+// @Param action path string true "Action: 'enable' or 'disable'"
+// @Success 200 {object} ModelResponse
+// @Failure 400 {object} ModelResponse
+// @Failure 404 {object} ModelResponse
+// @Failure 500 {object} ModelResponse
+// @Router /api/models/{name}/{action} [put]
+func ToggleStateModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		modelName := c.Param("name")
+		if decoded, err := url.PathUnescape(modelName); err == nil {
+			modelName = decoded
+		}
+		if modelName == "" {
+			return c.JSON(http.StatusBadRequest, ModelResponse{
+				Success: false,
+				Error:   "Model name is required",
+			})
+		}
+
+		action := c.Param("action")
+		if action != "enable" && action != "disable" {
+			return c.JSON(http.StatusBadRequest, ModelResponse{
+				Success: false,
+				Error:   "Action must be 'enable' or 'disable'",
+			})
+		}
+
+		// Get existing model config
+		modelConfig, exists := cl.GetModelConfig(modelName)
+		if !exists {
+			return c.JSON(http.StatusNotFound, ModelResponse{
+				Success: false,
+				Error:   "Model configuration not found",
+			})
+		}
+
+		// Get the config file path
+		configPath := modelConfig.GetModelConfigFile()
+		if configPath == "" {
+			return c.JSON(http.StatusNotFound, ModelResponse{
+				Success: false,
+				Error:   "Model configuration file not found",
+			})
+		}
+
+		// Verify the path is trusted
+		if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
+			return c.JSON(http.StatusForbidden, ModelResponse{
+				Success: false,
+				Error:   "Model configuration not trusted: " + err.Error(),
+			})
+		}
+
+		// Read the existing config file
+		configData, err := os.ReadFile(configPath)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to read configuration file: " + err.Error(),
+			})
+		}
+
+		// Parse the YAML config as a generic map to preserve all fields
+		var configMap map[string]interface{}
+		if err := yaml.Unmarshal(configData, &configMap); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to parse configuration file: " + err.Error(),
+			})
+		}
+
+		// Update the disabled field
+		disabled := action == "disable"
+		if disabled {
+			configMap["disabled"] = true
+		} else {
+			// Remove the disabled key entirely when enabling (clean YAML)
+			delete(configMap, "disabled")
+		}
+
+		// Marshal back to YAML
+		updatedData, err := yaml.Marshal(configMap)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to serialize configuration: " + err.Error(),
+			})
+		}
+
+		// Write updated config back to file
+		if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to write configuration file: " + err.Error(),
+			})
+		}
+
+		// Reload model configurations from disk
+		if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
+			return c.JSON(http.StatusInternalServerError, ModelResponse{
+				Success: false,
+				Error:   "Failed to reload configurations: " + err.Error(),
+			})
+		}
+
+		// If disabling, also shutdown the model if it's currently running
+		if disabled {
+			if err := ml.ShutdownModel(modelName); err != nil {
+				// Log but don't fail - the config was saved successfully
+				fmt.Printf("Warning: Failed to shutdown model '%s' during disable: %v\n", modelName, err)
+			}
+		}
+
+		msg := fmt.Sprintf("Model '%s' has been %sd successfully.", modelName, action)
+		if disabled {
+			msg += " The model will not be loaded on demand until re-enabled."
+		}
+
+		return c.JSON(http.StatusOK, ModelResponse{
+			Success:  true,
+			Message:  msg,
+			Filename: configPath,
+		})
+	}
+}
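Both new handlers start by percent-decoding the `name` path parameter and quietly keep the raw value if decoding fails. A small sketch of that fallback, using only `net/url` from the standard library, is shown below; `decodeName` is an illustrative name and does not exist in either file.

```go
package main

import (
	"fmt"
	"net/url"
)

// decodeName percent-decodes a path parameter and falls back to the raw
// value when the input is not valid percent-encoding.
func decodeName(raw string) string {
	if decoded, err := url.PathUnescape(raw); err == nil {
		return decoded
	}
	return raw
}

func main() {
	fmt.Println(decodeName("org%2Fmodel")) // encoded slash is decoded
	fmt.Println(decodeName("100%"))        // malformed escape: raw value is kept
}
```

`url.PathUnescape` is used rather than `url.QueryUnescape` because it leaves `+` alone, which matters for model names that contain a literal plus sign.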
Some files were not shown because too many files have changed in this diff.