mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 00:26:34 -04:00
Compare commits
79 Commits
v4.1.0
...
feat/backe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fe87cb0d5 | ||
|
|
6dd37a95c4 | ||
|
|
ee00a10836 | ||
|
|
948f3bfaa4 | ||
|
|
1e083cd870 | ||
|
|
b19e60d03a | ||
|
|
4d463e9f0d | ||
|
|
ae4ae5f425 | ||
|
|
7c1865b307 | ||
|
|
62a674ce12 | ||
|
|
c39213443b | ||
|
|
606f462da4 | ||
|
|
5c35e85fe2 | ||
|
|
062e0d0d00 | ||
|
|
d4cd6c284f | ||
|
|
3bb8b65d31 | ||
|
|
9748a1cbc6 | ||
|
|
6bc76dda6d | ||
|
|
e1a6010874 | ||
|
|
706cf5d43c | ||
|
|
13a6ed709c | ||
|
|
85be4ff03c | ||
|
|
b0d9ce4905 | ||
|
|
7081b54c09 | ||
|
|
2b05420f95 | ||
|
|
b64347b6aa | ||
|
|
e00ce981f0 | ||
|
|
285f7d4340 | ||
|
|
ea6e850809 | ||
|
|
b7247fc148 | ||
|
|
39c6b3ed66 | ||
|
|
0e9d1a6588 | ||
|
|
510d6759fe | ||
|
|
154fa000d3 | ||
|
|
0526e60f8d | ||
|
|
db600fb5b2 | ||
|
|
9ac1bdc587 | ||
|
|
fdc9f7bf35 | ||
|
|
8e59346091 | ||
|
|
e6e4e19633 | ||
|
|
505c417fa7 | ||
|
|
17215f6fbc | ||
|
|
bccaba1f66 | ||
|
|
0f9d516a6c | ||
|
|
33b124c6f1 | ||
|
|
6b8007e88e | ||
|
|
b3837c2078 | ||
|
|
92f99b1ec3 | ||
|
|
ad232fdb1a | ||
|
|
11637b5a1b | ||
|
|
0dda4fe6f0 | ||
|
|
773489eeb1 | ||
|
|
06fbe48b3f | ||
|
|
232e324a68 | ||
|
|
39c954764c | ||
|
|
9b7d5513fc | ||
|
|
84cd8c0e7f | ||
|
|
d990f2790c | ||
|
|
53deeb1107 | ||
|
|
c5a840f6af | ||
|
|
6d9d77d590 | ||
|
|
6f304d1201 | ||
|
|
557d0f0f04 | ||
|
|
b7e3589875 | ||
|
|
716ddd697b | ||
|
|
223deb908d | ||
|
|
9f8821bba8 | ||
|
|
84e51b68ef | ||
|
|
7962dd16f7 | ||
|
|
a1466b305a | ||
|
|
57c0026715 | ||
|
|
1ed6b9e5ed | ||
|
|
e4ee74354f | ||
|
|
8577bdcebc | ||
|
|
0d489c7a0d | ||
|
|
11dc54bda9 | ||
|
|
7e0b73deaa | ||
|
|
c0a023d13d | ||
|
|
0d3ae1c295 |
111
.agents/adding-gallery-models.md
Normal file
111
.agents/adding-gallery-models.md
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# Adding GGUF Models from HuggingFace to the Gallery
|
||||||
|
|
||||||
|
When adding a GGUF model from HuggingFace to the LocalAI model gallery, follow this guide.
|
||||||
|
|
||||||
|
## Gallery file
|
||||||
|
|
||||||
|
All models are defined in `gallery/index.yaml`. Find the appropriate section (embedding models near other embeddings, chat models near similar chat models) and add a new entry.
|
||||||
|
|
||||||
|
## Getting the SHA256
|
||||||
|
|
||||||
|
GGUF files on HuggingFace expose their SHA256 via the `x-linked-etag` HTTP header. Fetch it with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -sI "https://huggingface.co/<org>/<repo>/resolve/main/<filename>.gguf" | grep -i x-linked-etag
|
||||||
|
```
|
||||||
|
|
||||||
|
The value (without quotes) is the SHA256 hash. Example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -sI "https://huggingface.co/ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/resolve/main/embeddinggemma-300m-qat-Q8_0.gguf" | grep -i x-linked-etag
|
||||||
|
# x-linked-etag: "6fa0c02a9c302be6f977521d399b4de3a46310a4f2621ee0063747881b673f67"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important**: Pay attention to exact filename casing — HuggingFace filenames are case-sensitive (e.g., `Q8_0` vs `q8_0`). Check the repo's file listing to get the exact name.
|
||||||
|
|
||||||
|
## Entry format — Embedding models
|
||||||
|
|
||||||
|
Embedding models use `gallery/virtual.yaml` as the base config and set `embeddings: true`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- name: "model-name"
|
||||||
|
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/<original-model-org>/<original-model-name>
|
||||||
|
- https://huggingface.co/<gguf-org>/<gguf-repo-name>
|
||||||
|
description: |
|
||||||
|
Short description of the model, its size, and capabilities.
|
||||||
|
tags:
|
||||||
|
- embeddings
|
||||||
|
overrides:
|
||||||
|
backend: llama-cpp
|
||||||
|
embeddings: true
|
||||||
|
parameters:
|
||||||
|
model: <filename>.gguf
|
||||||
|
files:
|
||||||
|
- filename: <filename>.gguf
|
||||||
|
uri: huggingface://<gguf-org>/<gguf-repo-name>/<filename>.gguf
|
||||||
|
sha256: <sha256-hash>
|
||||||
|
```
|
||||||
|
|
||||||
|
## Entry format — Chat/LLM models
|
||||||
|
|
||||||
|
Chat models typically reference a template config (e.g., `gallery/gemma.yaml`, `gallery/chatml.yaml`) that defines the prompt format. Use YAML anchors (`&name` / `*name`) if adding multiple quantization variants of the same model:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- &model-anchor
|
||||||
|
url: "github:mudler/LocalAI/gallery/<template>.yaml@master"
|
||||||
|
name: "model-name"
|
||||||
|
icon: https://example.com/icon.png
|
||||||
|
license: <license>
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/<org>/<model>
|
||||||
|
- https://huggingface.co/<gguf-org>/<gguf-repo>
|
||||||
|
description: |
|
||||||
|
Model description.
|
||||||
|
tags:
|
||||||
|
- llm
|
||||||
|
- gguf
|
||||||
|
- gpu
|
||||||
|
- cpu
|
||||||
|
overrides:
|
||||||
|
parameters:
|
||||||
|
model: <filename>-Q4_K_M.gguf
|
||||||
|
files:
|
||||||
|
- filename: <filename>-Q4_K_M.gguf
|
||||||
|
sha256: <sha256>
|
||||||
|
uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q4_K_M.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
To add a variant (e.g., different quantization), use YAML merge:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- !!merge <<: *model-anchor
|
||||||
|
name: "model-name-q8"
|
||||||
|
overrides:
|
||||||
|
parameters:
|
||||||
|
model: <filename>-Q8_0.gguf
|
||||||
|
files:
|
||||||
|
- filename: <filename>-Q8_0.gguf
|
||||||
|
sha256: <sha256>
|
||||||
|
uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q8_0.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available template configs
|
||||||
|
|
||||||
|
Look at existing `.yaml` files in `gallery/` to find the right prompt template for your model architecture:
|
||||||
|
|
||||||
|
- `gemma.yaml` — Gemma-family models (gemma, embeddinggemma, etc.)
|
||||||
|
- `chatml.yaml` — ChatML format (many Mistral/OpenHermes models)
|
||||||
|
- `deepseek.yaml` — DeepSeek models
|
||||||
|
- `virtual.yaml` — Minimal base (good for embedding models that don't need chat templates)
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
1. **Find the GGUF file** on HuggingFace — note exact filename (case-sensitive)
|
||||||
|
2. **Get the SHA256** using the `curl -sI` + `x-linked-etag` method above
|
||||||
|
3. **Choose the right template** config from `gallery/` based on model architecture
|
||||||
|
4. **Add the entry** to `gallery/index.yaml` near similar models
|
||||||
|
5. **Set `embeddings: true`** if it's an embedding model
|
||||||
|
6. **Include both URLs** — the original model page and the GGUF repo
|
||||||
|
7. **Write a description** — mention model size, capabilities, and quantization type
|
||||||
170
.github/workflows/backend.yml
vendored
170
.github/workflows/backend.yml
vendored
@@ -105,6 +105,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-faster-whisper'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "faster-whisper"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
cuda-minor-version: ""
|
cuda-minor-version: ""
|
||||||
@@ -561,6 +574,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.golang"
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "8"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-12-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "12"
|
cuda-major-version: "12"
|
||||||
cuda-minor-version: "8"
|
cuda-minor-version: "8"
|
||||||
@@ -965,6 +991,32 @@ jobs:
|
|||||||
backend: "mlx-distributed"
|
backend: "mlx-distributed"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisperx'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
backend: "whisperx"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-whisper'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
backend: "faster-whisper"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "13"
|
cuda-major-version: "13"
|
||||||
cuda-minor-version: "0"
|
cuda-minor-version: "0"
|
||||||
@@ -1108,6 +1160,32 @@ jobs:
|
|||||||
backend: "stablediffusion-ggml"
|
backend: "stablediffusion-ggml"
|
||||||
dockerfile: "./backend/Dockerfile.golang"
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
context: "./"
|
context: "./"
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-13-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
skip-drivers: 'false'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-cuda-13-arm64-sam3-cpp'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "13"
|
cuda-major-version: "13"
|
||||||
cuda-minor-version: "0"
|
cuda-minor-version: "0"
|
||||||
@@ -1644,6 +1722,32 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2204'
|
ubuntu-version: '2204'
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-whisperx'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "whisperx"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2204'
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-faster-whisper'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "faster-whisper"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2204'
|
||||||
# SYCL additional backends
|
# SYCL additional backends
|
||||||
- build-type: 'intel'
|
- build-type: 'intel'
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
@@ -1842,6 +1946,59 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.golang"
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
# sam3-cpp
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'sycl_f32'
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-intel-sycl-f32-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'sycl_f16'
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-intel-sycl-f16-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'vulkan'
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64,linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-vulkan-sam3-cpp'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'sycl_f32'
|
- build-type: 'sycl_f32'
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
cuda-minor-version: ""
|
cuda-minor-version: ""
|
||||||
@@ -1894,6 +2051,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.golang"
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2204'
|
ubuntu-version: '2204'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
skip-drivers: 'false'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-arm64-sam3-cpp'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
backend: "sam3-cpp"
|
||||||
|
dockerfile: "./backend/Dockerfile.golang"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2204'
|
||||||
# whisper
|
# whisper
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
|
|||||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -34,6 +34,10 @@ jobs:
|
|||||||
variable: "ACESTEP_CPP_VERSION"
|
variable: "ACESTEP_CPP_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
file: "backend/go/acestep-cpp/Makefile"
|
file: "backend/go/acestep-cpp/Makefile"
|
||||||
|
- repository: "PABannier/sam3.cpp"
|
||||||
|
variable: "SAM3_VERSION"
|
||||||
|
branch: "main"
|
||||||
|
file: "backend/go/sam3-cpp/Makefile"
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
|
|||||||
23
.github/workflows/test-extra.yml
vendored
23
.github/workflows/test-extra.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
|||||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||||
|
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
@@ -528,3 +529,25 @@ jobs:
|
|||||||
- name: Test voxtral
|
- name: Test voxtral
|
||||||
run: |
|
run: |
|
||||||
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
||||||
|
tests-kokoros:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.kokoros == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y build-essential cmake pkg-config protobuf-compiler clang libclang-dev
|
||||||
|
sudo apt-get install -y espeak-ng libespeak-ng-dev libsonic-dev libpcaudio-dev libopus-dev libssl-dev
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
|
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||||
|
- name: Build kokoros
|
||||||
|
run: |
|
||||||
|
make -C backend/rust/kokoros kokoros-grpc
|
||||||
|
- name: Test kokoros
|
||||||
|
run: |
|
||||||
|
make -C backend/rust/kokoros test
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
|||||||
[submodule "docs/themes/hugo-theme-relearn"]
|
[submodule "docs/themes/hugo-theme-relearn"]
|
||||||
path = docs/themes/hugo-theme-relearn
|
path = docs/themes/hugo-theme-relearn
|
||||||
url = https://github.com/McShelby/hugo-theme-relearn.git
|
url = https://github.com/McShelby/hugo-theme-relearn.git
|
||||||
|
[submodule "backend/rust/kokoros/sources/Kokoros"]
|
||||||
|
path = backend/rust/kokoros/sources/Kokoros
|
||||||
|
url = https://github.com/lucasjinreal/Kokoros
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
|||||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||||
|
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||||
|
|
||||||
## Quick Reference
|
## Quick Reference
|
||||||
|
|
||||||
|
|||||||
17
Makefile
17
Makefile
@@ -1,5 +1,5 @@
|
|||||||
# Disable parallel execution for backend builds
|
# Disable parallel execution for backend builds
|
||||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp
|
||||||
|
|
||||||
GOCMD=go
|
GOCMD=go
|
||||||
GOTEST=$(GOCMD) test
|
GOTEST=$(GOCMD) test
|
||||||
@@ -148,7 +148,6 @@ test-models/testmodel.ggml:
|
|||||||
mkdir -p test-dir
|
mkdir -p test-dir
|
||||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
|
||||||
wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
||||||
cp tests/models_fixtures/* test-models
|
cp tests/models_fixtures/* test-models
|
||||||
|
|
||||||
@@ -429,9 +428,11 @@ prepare-test-extra: protogen-python
|
|||||||
$(MAKE) -C backend/python/qwen-asr
|
$(MAKE) -C backend/python/qwen-asr
|
||||||
$(MAKE) -C backend/python/nemo
|
$(MAKE) -C backend/python/nemo
|
||||||
$(MAKE) -C backend/python/voxcpm
|
$(MAKE) -C backend/python/voxcpm
|
||||||
|
$(MAKE) -C backend/python/faster-whisper
|
||||||
$(MAKE) -C backend/python/whisperx
|
$(MAKE) -C backend/python/whisperx
|
||||||
$(MAKE) -C backend/python/ace-step
|
$(MAKE) -C backend/python/ace-step
|
||||||
$(MAKE) -C backend/python/trl
|
$(MAKE) -C backend/python/trl
|
||||||
|
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||||
|
|
||||||
test-extra: prepare-test-extra
|
test-extra: prepare-test-extra
|
||||||
$(MAKE) -C backend/python/transformers test
|
$(MAKE) -C backend/python/transformers test
|
||||||
@@ -449,9 +450,11 @@ test-extra: prepare-test-extra
|
|||||||
$(MAKE) -C backend/python/qwen-asr test
|
$(MAKE) -C backend/python/qwen-asr test
|
||||||
$(MAKE) -C backend/python/nemo test
|
$(MAKE) -C backend/python/nemo test
|
||||||
$(MAKE) -C backend/python/voxcpm test
|
$(MAKE) -C backend/python/voxcpm test
|
||||||
|
$(MAKE) -C backend/python/faster-whisper test
|
||||||
$(MAKE) -C backend/python/whisperx test
|
$(MAKE) -C backend/python/whisperx test
|
||||||
$(MAKE) -C backend/python/ace-step test
|
$(MAKE) -C backend/python/ace-step test
|
||||||
$(MAKE) -C backend/python/trl test
|
$(MAKE) -C backend/python/trl test
|
||||||
|
$(MAKE) -C backend/rust/kokoros test
|
||||||
|
|
||||||
DOCKER_IMAGE?=local-ai
|
DOCKER_IMAGE?=local-ai
|
||||||
IMAGE_TYPE?=core
|
IMAGE_TYPE?=core
|
||||||
@@ -587,6 +590,12 @@ BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
|||||||
BACKEND_TRL = trl|python|.|false|true
|
BACKEND_TRL = trl|python|.|false|true
|
||||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||||
|
|
||||||
|
# Rust backends
|
||||||
|
BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||||
|
|
||||||
|
# C++ backends (Go wrapper with purego)
|
||||||
|
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||||
|
|
||||||
# Helper function to build docker image for a backend
|
# Helper function to build docker image for a backend
|
||||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||||
define docker-build-backend
|
define docker-build-backend
|
||||||
@@ -645,12 +654,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||||
|
|
||||||
# Pattern rule for docker-save targets
|
# Pattern rule for docker-save targets
|
||||||
docker-save-%: backend-images
|
docker-save-%: backend-images
|
||||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||||
|
|
||||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
### Mock Backend for E2E Tests
|
### Mock Backend for E2E Tests
|
||||||
|
|||||||
31
README.md
31
README.md
@@ -42,16 +42,38 @@ Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
|||||||
|
|
||||||
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
||||||
|
|
||||||
## Screenshots
|
## Guided tour
|
||||||
|
|
||||||
### Chat, Model gallery
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||||
|
|
||||||
### Agents
|
<details>
|
||||||
|
|
||||||
|
<summary>
|
||||||
|
Click to see more!
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
#### User and auth
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/228fa9ad-81a3-4d43-bfb9-31557e14a36c
|
||||||
|
|
||||||
|
#### Agents
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||||
|
|
||||||
|
#### Usage metrics per user
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/cbb03379-23b4-4e3d-bd26-d152f057007f
|
||||||
|
|
||||||
|
#### Fine-tuning and Quantization
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/5ba4ace9-d3df-4795-b7d4-b0b404ea71ee
|
||||||
|
|
||||||
|
#### WebRTC
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/ed88e34c-fed3-4b83-8a67-4716a9feeb7b
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
@@ -174,6 +196,7 @@ See the full [Backend & Model Compatibility Table](https://localai.io/model-comp
|
|||||||
- [Build from source](https://localai.io/basics/build/)
|
- [Build from source](https://localai.io/basics/build/)
|
||||||
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||||
|
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||||
|
|
||||||
|
|||||||
39
backend/Dockerfile.rust
Normal file
39
backend/Dockerfile.rust
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
ARG BASE_IMAGE=ubuntu:24.04
|
||||||
|
|
||||||
|
FROM ${BASE_IMAGE} AS builder
|
||||||
|
ARG BACKEND=kokoros
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ARG TARGETARCH
|
||||||
|
ARG TARGETVARIANT
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
git ccache \
|
||||||
|
ca-certificates \
|
||||||
|
make cmake wget \
|
||||||
|
curl unzip \
|
||||||
|
clang \
|
||||||
|
pkg-config \
|
||||||
|
libssl-dev \
|
||||||
|
espeak-ng libespeak-ng-dev \
|
||||||
|
libsonic-dev libpcaudio-dev \
|
||||||
|
libopus-dev \
|
||||||
|
protobuf-compiler && \
|
||||||
|
apt-get clean && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install Rust
|
||||||
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
|
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||||
|
|
||||||
|
COPY . /LocalAI
|
||||||
|
|
||||||
|
RUN git config --global --add safe.directory /LocalAI
|
||||||
|
|
||||||
|
RUN make -C /LocalAI/backend/rust/${BACKEND} build
|
||||||
|
|
||||||
|
FROM scratch
|
||||||
|
ARG BACKEND=kokoros
|
||||||
|
|
||||||
|
COPY --from=builder /LocalAI/backend/rust/${BACKEND}/package/. ./
|
||||||
@@ -444,6 +444,10 @@ message Message {
|
|||||||
|
|
||||||
message DetectOptions {
|
message DetectOptions {
|
||||||
string src = 1;
|
string src = 1;
|
||||||
|
string prompt = 2; // Text prompt (for SAM 3 PCS mode)
|
||||||
|
repeated float points = 3; // Point coordinates as [x1, y1, label1, x2, y2, label2, ...] (label: 1=pos, 0=neg)
|
||||||
|
repeated float boxes = 4; // Box coordinates as [x1, y1, x2, y2, ...]
|
||||||
|
float threshold = 5; // Detection confidence threshold
|
||||||
}
|
}
|
||||||
|
|
||||||
message Detection {
|
message Detection {
|
||||||
@@ -453,6 +457,7 @@ message Detection {
|
|||||||
float height = 4;
|
float height = 4;
|
||||||
float confidence = 5;
|
float confidence = 5;
|
||||||
string class_name = 6;
|
string class_name = 6;
|
||||||
|
bytes mask = 7; // PNG-encoded binary segmentation mask
|
||||||
}
|
}
|
||||||
|
|
||||||
message DetectResponse {
|
message DetectResponse {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=95a6ebabb277c4cc18247e7bc2a5502133caca63
|
LLAMA_VERSION?=e62fa13c2497b2cd1958cb496e9489e86bbd5182
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -40,45 +40,41 @@ using grpc::ServerBuilder;
|
|||||||
using grpc::ServerContext;
|
using grpc::ServerContext;
|
||||||
using grpc::Status;
|
using grpc::Status;
|
||||||
|
|
||||||
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
// gRPC bearer token auth for distributed mode.
|
||||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||||
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
|
||||||
public:
|
|
||||||
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
|
||||||
|
|
||||||
bool IsBlocking() const override { return false; }
|
// Cached auth token — empty means auth is disabled.
|
||||||
|
static std::string g_grpc_auth_token;
|
||||||
|
|
||||||
grpc::Status Process(const InputMetadata& auth_metadata,
|
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||||
grpc::AuthContext* /*context*/,
|
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||||
OutputMetadata* /*consumed_auth_metadata*/,
|
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||||
OutputMetadata* /*response_metadata*/) override {
|
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||||
auto it = auth_metadata.find("authorization");
|
unsigned char result = 0;
|
||||||
if (it != auth_metadata.end()) {
|
for (size_t i = 0; i < n; i++) {
|
||||||
std::string expected = "Bearer " + token_;
|
result |= pa[i] ^ pb[i];
|
||||||
std::string got(it->second.data(), it->second.size());
|
|
||||||
// Constant-time comparison
|
|
||||||
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
|
||||||
return grpc::Status::OK;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
|
||||||
}
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
// Returns OK when auth is disabled or the token matches.
|
||||||
std::string token_;
|
static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||||
|
if (g_grpc_auth_token.empty()) {
|
||||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
return grpc::Status::OK;
|
||||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
|
||||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
|
||||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
|
||||||
unsigned char result = 0;
|
|
||||||
for (size_t i = 0; i < n; i++) {
|
|
||||||
result |= pa[i] ^ pb[i];
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
};
|
auto metadata = context->client_metadata();
|
||||||
|
auto it = metadata.find("authorization");
|
||||||
|
if (it != metadata.end()) {
|
||||||
|
std::string expected = "Bearer " + g_grpc_auth_token;
|
||||||
|
std::string got(it->second.data(), it->second.size());
|
||||||
|
if (expected.size() == got.size() &&
|
||||||
|
ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||||
|
}
|
||||||
|
|
||||||
// END LocalAI
|
// END LocalAI
|
||||||
|
|
||||||
@@ -288,6 +284,12 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
|||||||
data["ignore_eos"] = predict->ignoreeos();
|
data["ignore_eos"] = predict->ignoreeos();
|
||||||
data["embeddings"] = predict->embeddings();
|
data["embeddings"] = predict->embeddings();
|
||||||
|
|
||||||
|
// Speculative decoding per-request overrides
|
||||||
|
// NDraft maps to speculative.n_max (maximum draft tokens per speculation step)
|
||||||
|
if (predict->ndraft() > 0) {
|
||||||
|
data["speculative.n_max"] = predict->ndraft();
|
||||||
|
}
|
||||||
|
|
||||||
// Add the correlationid to json data
|
// Add the correlationid to json data
|
||||||
data["correlation_id"] = predict->correlationid();
|
data["correlation_id"] = predict->correlationid();
|
||||||
|
|
||||||
@@ -406,6 +408,16 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
if (!request->mmproj().empty()) {
|
if (!request->mmproj().empty()) {
|
||||||
params.mmproj.path = request->mmproj();
|
params.mmproj.path = request->mmproj();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Draft model for speculative decoding
|
||||||
|
if (!request->draftmodel().empty()) {
|
||||||
|
params.speculative.mparams_dft.path = request->draftmodel();
|
||||||
|
// Default to draft type if a draft model is set but no explicit type
|
||||||
|
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||||
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// params.model_alias ??
|
// params.model_alias ??
|
||||||
params.model_alias.insert(request->modelfile());
|
params.model_alias.insert(request->modelfile());
|
||||||
if (!request->cachetypekey().empty()) {
|
if (!request->cachetypekey().empty()) {
|
||||||
@@ -613,6 +625,48 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// If conversion fails, keep default value (8)
|
// If conversion fails, keep default value (8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Speculative decoding options
|
||||||
|
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||||
|
auto type = common_speculative_type_from_name(optval_str);
|
||||||
|
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||||
|
params.speculative.type = type;
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.n_max = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_n_min") || !strcmp(optname, "draft_min")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.n_min = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_p_min") || !strcmp(optname, "draft_p_min")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.p_min = std::stof(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_p_split")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.p_split = std::stof(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_ngram_size_n") || !strcmp(optname, "ngram_size_n")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.ngram_size_n = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_ngram_size_m") || !strcmp(optname, "ngram_size_m")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.ngram_size_m = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "spec_ngram_min_hits") || !strcmp(optname, "ngram_min_hits")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.ngram_min_hits = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "draft_gpu_layers")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.n_gpu_layers = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "draft_ctx_size")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.speculative.n_ctx = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -757,13 +811,17 @@ private:
|
|||||||
public:
|
public:
|
||||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||||
|
|
||||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
// Implement Health RPC
|
// Implement Health RPC
|
||||||
reply->set_message("OK");
|
reply->set_message("OK");
|
||||||
return Status::OK;
|
return Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
// Implement LoadModel RPC
|
// Implement LoadModel RPC
|
||||||
common_params params;
|
common_params params;
|
||||||
params_parse(ctx_server, request, params);
|
params_parse(ctx_server, request, params);
|
||||||
@@ -962,6 +1020,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
if (params_base.model.path.empty()) {
|
if (params_base.model.path.empty()) {
|
||||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
}
|
}
|
||||||
@@ -1249,6 +1309,7 @@ public:
|
|||||||
|
|
||||||
body_json["messages"] = messages_json;
|
body_json["messages"] = messages_json;
|
||||||
body_json["stream"] = true; // PredictStream is always streaming
|
body_json["stream"] = true; // PredictStream is always streaming
|
||||||
|
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||||
|
|
||||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||||
@@ -1553,11 +1614,15 @@ public:
|
|||||||
ctx_server.impl->vocab,
|
ctx_server.impl->vocab,
|
||||||
params_base,
|
params_base,
|
||||||
ctx_server.get_meta().slot_n_ctx,
|
ctx_server.get_meta().slot_n_ctx,
|
||||||
|
ctx_server.get_meta().logit_bias_eog,
|
||||||
data);
|
data);
|
||||||
task.id_slot = json_value(data, "id_slot", -1);
|
task.id_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||||
|
// Without this, the PEG parser never produces diffs and the Go side
|
||||||
|
// cannot detect tool calls or separate reasoning from content.
|
||||||
|
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by params_from_json_cmpl
|
||||||
|
|
||||||
@@ -1582,19 +1647,47 @@ public:
|
|||||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||||
|
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||||
|
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||||
backend::Reply reply;
|
backend::Reply reply;
|
||||||
std::string completion_text = res_json.value("content", "");
|
std::string completion_text;
|
||||||
reply.set_message(completion_text);
|
|
||||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
|
||||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
|
||||||
|
|
||||||
|
if (res_json.contains("choices")) {
|
||||||
|
// OAI chat format — extract content from choices[0].delta
|
||||||
|
const auto & choices = res_json.at("choices");
|
||||||
|
if (!choices.empty()) {
|
||||||
|
const auto & delta = choices[0].value("delta", json::object());
|
||||||
|
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||||
|
completion_text = delta.at("content").get<std::string>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Native llama.cpp format
|
||||||
|
completion_text = res_json.value("content", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
reply.set_message(completion_text);
|
||||||
|
|
||||||
|
// Token counts: native format has top-level fields,
|
||||||
|
// OAI format has them in "usage" (final chunk only)
|
||||||
|
if (res_json.contains("usage")) {
|
||||||
|
const auto & usage = res_json.at("usage");
|
||||||
|
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||||
|
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||||
|
} else {
|
||||||
|
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||||
|
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timings: present as top-level "timings" in both formats
|
||||||
if (res_json.contains("timings")) {
|
if (res_json.contains("timings")) {
|
||||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logprobs: extract_logprobs_from_json handles both formats
|
||||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||||
reply.set_logprobs(logprobs_json.dump());
|
reply.set_logprobs(logprobs_json.dump());
|
||||||
@@ -1603,6 +1696,12 @@ public:
|
|||||||
return reply;
|
return reply;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Attach chat deltas from the autoparser to a Reply.
|
||||||
|
// When diffs are available, populate ChatDeltas on the reply.
|
||||||
|
// The raw message is always preserved so the Go side can use it
|
||||||
|
// for reasoning extraction and tool call parsing as a fallback
|
||||||
|
// (important in distributed mode where ChatDeltas may not be
|
||||||
|
// the primary parsing path).
|
||||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||||
// Try streaming partial result first
|
// Try streaming partial result first
|
||||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||||
@@ -1617,12 +1716,23 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Process first result
|
// Process first result.
|
||||||
|
// When TASK_RESPONSE_TYPE_OAI_CHAT is used, the first token may
|
||||||
|
// produce a JSON array with a role-init element followed by the
|
||||||
|
// actual content element. We must only attach chat deltas to the
|
||||||
|
// content element — attaching to both would duplicate the first
|
||||||
|
// token since oaicompat_msg_diffs is the same for both.
|
||||||
json first_res_json = first_result->to_json();
|
json first_res_json = first_result->to_json();
|
||||||
if (first_res_json.is_array()) {
|
if (first_res_json.is_array()) {
|
||||||
for (const auto & res : first_res_json) {
|
for (const auto & res : first_res_json) {
|
||||||
auto reply = build_reply_from_json(res, first_result.get());
|
auto reply = build_reply_from_json(res, first_result.get());
|
||||||
attach_chat_deltas(reply, first_result.get());
|
// Skip chat deltas for role-init elements (have "role" in
|
||||||
|
// delta but no content/reasoning diffs of their own).
|
||||||
|
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||||
|
res["choices"][0].value("delta", json::object()).contains("role");
|
||||||
|
if (!is_role_init) {
|
||||||
|
attach_chat_deltas(reply, first_result.get());
|
||||||
|
}
|
||||||
writer->Write(reply);
|
writer->Write(reply);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1646,7 +1756,11 @@ public:
|
|||||||
if (res_json.is_array()) {
|
if (res_json.is_array()) {
|
||||||
for (const auto & res : res_json) {
|
for (const auto & res : res_json) {
|
||||||
auto reply = build_reply_from_json(res, result.get());
|
auto reply = build_reply_from_json(res, result.get());
|
||||||
attach_chat_deltas(reply, result.get());
|
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||||
|
res["choices"][0].value("delta", json::object()).contains("role");
|
||||||
|
if (!is_role_init) {
|
||||||
|
attach_chat_deltas(reply, result.get());
|
||||||
|
}
|
||||||
writer->Write(reply);
|
writer->Write(reply);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1665,6 +1779,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
if (params_base.model.path.empty()) {
|
if (params_base.model.path.empty()) {
|
||||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
}
|
}
|
||||||
@@ -2282,11 +2398,13 @@ public:
|
|||||||
ctx_server.impl->vocab,
|
ctx_server.impl->vocab,
|
||||||
params_base,
|
params_base,
|
||||||
ctx_server.get_meta().slot_n_ctx,
|
ctx_server.get_meta().slot_n_ctx,
|
||||||
|
ctx_server.get_meta().logit_bias_eog,
|
||||||
data);
|
data);
|
||||||
task.id_slot = json_value(data, "id_slot", -1);
|
task.id_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||||
|
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by params_from_json_cmpl
|
||||||
|
|
||||||
@@ -2317,25 +2435,48 @@ public:
|
|||||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||||
GGML_ASSERT(final_res != nullptr);
|
GGML_ASSERT(final_res != nullptr);
|
||||||
json result_json = all_results.results[0]->to_json();
|
json result_json = all_results.results[0]->to_json();
|
||||||
reply->set_message(result_json.value("content", ""));
|
|
||||||
|
|
||||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||||
|
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||||
|
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||||
|
std::string completion_text;
|
||||||
|
int32_t tokens_predicted = 0;
|
||||||
|
int32_t tokens_evaluated = 0;
|
||||||
|
|
||||||
|
if (result_json.contains("choices")) {
|
||||||
|
// OAI chat format
|
||||||
|
const auto & choices = result_json.at("choices");
|
||||||
|
if (!choices.empty()) {
|
||||||
|
const auto & msg = choices[0].value("message", json::object());
|
||||||
|
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||||
|
completion_text = msg.at("content").get<std::string>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (result_json.contains("usage")) {
|
||||||
|
const auto & usage = result_json.at("usage");
|
||||||
|
tokens_predicted = usage.value("completion_tokens", 0);
|
||||||
|
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Native llama.cpp format
|
||||||
|
completion_text = result_json.value("content", "");
|
||||||
|
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||||
|
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||||
|
}
|
||||||
|
reply->set_message(completion_text);
|
||||||
reply->set_tokens(tokens_predicted);
|
reply->set_tokens(tokens_predicted);
|
||||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
|
||||||
reply->set_prompt_tokens(tokens_evaluated);
|
reply->set_prompt_tokens(tokens_evaluated);
|
||||||
|
|
||||||
|
// Timings: present in both formats as a top-level "timings" object
|
||||||
if (result_json.contains("timings")) {
|
if (result_json.contains("timings")) {
|
||||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
|
||||||
reply->set_timing_token_generation(timing_token_generation);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set logprobs if present
|
// Logprobs: extract_logprobs_from_json handles both formats
|
||||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||||
std::string logprobs_str = logprobs_json.dump();
|
reply->set_logprobs(logprobs_json.dump());
|
||||||
reply->set_logprobs(logprobs_str);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populate chat deltas from the autoparser's final parsed message
|
// Populate chat deltas from the autoparser's final parsed message
|
||||||
@@ -2351,7 +2492,20 @@ public:
|
|||||||
for (auto & res : all_results.results) {
|
for (auto & res : all_results.results) {
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||||
json res_json = res->to_json();
|
json res_json = res->to_json();
|
||||||
arr.push_back(res_json.value("content", ""));
|
// Handle both native and OAI chat formats
|
||||||
|
std::string result_content;
|
||||||
|
if (res_json.contains("choices")) {
|
||||||
|
const auto & choices = res_json.at("choices");
|
||||||
|
if (!choices.empty()) {
|
||||||
|
const auto & msg = choices[0].value("message", json::object());
|
||||||
|
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||||
|
result_content = msg.at("content").get<std::string>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result_content = res_json.value("content", "");
|
||||||
|
}
|
||||||
|
arr.push_back(result_content);
|
||||||
|
|
||||||
// Extract logprobs for each result
|
// Extract logprobs for each result
|
||||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||||
@@ -2383,6 +2537,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
if (params_base.model.path.empty()) {
|
if (params_base.model.path.empty()) {
|
||||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
}
|
}
|
||||||
@@ -2563,7 +2719,9 @@ public:
|
|||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||||
|
auto auth = checkAuth(context);
|
||||||
|
if (!auth.ok()) return auth;
|
||||||
if (params_base.model.path.empty()) {
|
if (params_base.model.path.empty()) {
|
||||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
}
|
}
|
||||||
@@ -2803,19 +2961,14 @@ int main(int argc, char** argv) {
|
|||||||
BackendServiceImpl service(ctx_server);
|
BackendServiceImpl service(ctx_server);
|
||||||
|
|
||||||
ServerBuilder builder;
|
ServerBuilder builder;
|
||||||
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
|
||||||
std::shared_ptr<grpc::ServerCredentials> creds;
|
|
||||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
|
||||||
creds = grpc::InsecureServerCredentials();
|
|
||||||
creds->SetAuthMetadataProcessor(
|
|
||||||
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
|
||||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
|
||||||
} else {
|
|
||||||
creds = grpc::InsecureServerCredentials();
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.AddListeningPort(server_address, creds);
|
// Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||||
|
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||||
|
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||||
|
g_grpc_auth_token = auth_token;
|
||||||
|
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||||
|
}
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# acestep.cpp version
|
# acestep.cpp version
|
||||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||||
ACESTEP_CPP_VERSION?=6f35c874ee11e86d511b860019b84976f5b52d3a
|
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||||
SO_TARGET?=libgoacestepcpp.so
|
SO_TARGET?=libgoacestepcpp.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
sources/
|
||||||
|
build*/
|
||||||
|
package/
|
||||||
|
libgosam3*.so
|
||||||
|
sam3-cpp
|
||||||
|
test-models/
|
||||||
|
test-data/
|
||||||
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14)
|
||||||
|
project(gosam3 LANGUAGES C CXX)
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
# Build ggml as static libraries to avoid runtime .so dependencies
|
||||||
|
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||||
|
|
||||||
|
set(SAM3_BUILD_EXAMPLES OFF CACHE BOOL "Disable sam3.cpp examples" FORCE)
|
||||||
|
set(SAM3_BUILD_TESTS OFF CACHE BOOL "Disable sam3.cpp tests" FORCE)
|
||||||
|
|
||||||
|
add_subdirectory(./sources/sam3.cpp)
|
||||||
|
|
||||||
|
add_library(gosam3 MODULE gosam3.cpp)
|
||||||
|
target_link_libraries(gosam3 PRIVATE sam3 ggml)
|
||||||
|
|
||||||
|
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||||
|
target_link_libraries(gosam3 PRIVATE stdc++fs)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
target_include_directories(gosam3 PUBLIC
|
||||||
|
sources/sam3.cpp
|
||||||
|
sources/sam3.cpp/ggml/include
|
||||||
|
)
|
||||||
|
|
||||||
|
set_property(TARGET gosam3 PROPERTY CXX_STANDARD 14)
|
||||||
|
set_target_properties(gosam3 PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||||
122
backend/go/sam3-cpp/Makefile
Normal file
122
backend/go/sam3-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
CMAKE_ARGS?=
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc --ignore=1)
|
||||||
|
|
||||||
|
# sam3.cpp
|
||||||
|
SAM3_REPO?=https://github.com/PABannier/sam3.cpp
|
||||||
|
SAM3_VERSION?=01832ef85fcc8eb6488f1d01cd247f07e96ff5a9
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
ROCM_HOME ?= /opt/rocm
|
||||||
|
ROCM_PATH ?= /opt/rocm
|
||||||
|
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||||
|
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||||
|
AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||||
|
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||||
|
else ifeq ($(OS),Darwin)
|
||||||
|
ifneq ($(BUILD_TYPE),metal)
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||||
|
else
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx \
|
||||||
|
-DGGML_SYCL_F16=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx
|
||||||
|
endif
|
||||||
|
|
||||||
|
sources/sam3.cpp:
|
||||||
|
git clone --recursive $(SAM3_REPO) sources/sam3.cpp && \
|
||||||
|
cd sources/sam3.cpp && \
|
||||||
|
git checkout $(SAM3_VERSION) && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
UNAME_S := $(shell uname -s)
|
||||||
|
|
||||||
|
# Only build CPU variants on Linux
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
VARIANT_TARGETS = libgosam3-avx.so libgosam3-avx2.so libgosam3-avx512.so libgosam3-fallback.so
|
||||||
|
else
|
||||||
|
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||||
|
VARIANT_TARGETS = libgosam3-fallback.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
sam3-cpp: main.go gosam3.go $(VARIANT_TARGETS)
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o sam3-cpp ./
|
||||||
|
|
||||||
|
package: sam3-cpp
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf libgosam3*.so sam3-cpp package sources
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf build*
|
||||||
|
|
||||||
|
# Build all variants (Linux only)
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
libgosam3-avx.so: sources/sam3.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I sam3-cpp build info:avx${RESET})
|
||||||
|
SO_TARGET=libgosam3-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgosam3-avx2.so: sources/sam3.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I sam3-cpp build info:avx2${RESET})
|
||||||
|
SO_TARGET=libgosam3-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgosam3-avx512.so: sources/sam3.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I sam3-cpp build info:avx512${RESET})
|
||||||
|
SO_TARGET=libgosam3-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||||
|
rm -rfv build*
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Build fallback variant (all platforms)
|
||||||
|
libgosam3-fallback.so: sources/sam3.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I sam3-cpp build info:fallback${RESET})
|
||||||
|
SO_TARGET=libgosam3-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgosam3-custom: CMakeLists.txt gosam3.cpp gosam3.h
|
||||||
|
mkdir -p build-$(SO_TARGET) && \
|
||||||
|
cd build-$(SO_TARGET) && \
|
||||||
|
cmake .. $(CMAKE_ARGS) && \
|
||||||
|
cmake --build . --config Release -j$(JOBS) && \
|
||||||
|
cd .. && \
|
||||||
|
mv build-$(SO_TARGET)/libgosam3.so ./$(SO_TARGET)
|
||||||
|
|
||||||
|
all: sam3-cpp package
|
||||||
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
#include "sam3.h"
|
||||||
|
#include "gosam3.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||||
|
#define STB_IMAGE_WRITE_STATIC
|
||||||
|
#include "stb_image_write.h"
|
||||||
|
|
||||||
|
// Static state
|
||||||
|
static std::shared_ptr<sam3_model> g_model;
|
||||||
|
static sam3_state_ptr g_state;
|
||||||
|
static sam3_result g_result;
|
||||||
|
static std::vector<std::vector<unsigned char>> g_mask_pngs;
|
||||||
|
|
||||||
|
// Callback for stbi_write_png_to_mem via stbi_write_png_to_func
|
||||||
|
static void png_write_callback(void *context, void *data, int size) {
|
||||||
|
auto *buf = static_cast<std::vector<unsigned char>*>(context);
|
||||||
|
auto *bytes = static_cast<unsigned char*>(data);
|
||||||
|
buf->insert(buf->end(), bytes, bytes + size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode all masks as PNGs after segmentation
|
||||||
|
static void encode_masks_as_png() {
|
||||||
|
g_mask_pngs.clear();
|
||||||
|
g_mask_pngs.resize(g_result.detections.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < g_result.detections.size(); i++) {
|
||||||
|
const auto &mask = g_result.detections[i].mask;
|
||||||
|
if (mask.width > 0 && mask.height > 0 && !mask.data.empty()) {
|
||||||
|
stbi_write_png_to_func(png_write_callback, &g_mask_pngs[i],
|
||||||
|
mask.width, mask.height, 1,
|
||||||
|
mask.data.data(), mask.width);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
int sam3_cpp_load_model(const char *model_path, int threads) {
|
||||||
|
sam3_params params;
|
||||||
|
params.model_path = model_path;
|
||||||
|
params.n_threads = threads;
|
||||||
|
params.use_gpu = true;
|
||||||
|
|
||||||
|
g_model = sam3_load_model(params);
|
||||||
|
if (!g_model) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] Failed to load model: %s\n", model_path);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_state = sam3_create_state(*g_model, params);
|
||||||
|
if (!g_state) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] Failed to create state\n");
|
||||||
|
g_model.reset();
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "[sam3-cpp] Model loaded: %s (threads=%d)\n", model_path, threads);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sam3_cpp_encode_image(const char *image_path) {
|
||||||
|
if (!g_model || !g_state) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] Model not loaded\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
sam3_image img = sam3_load_image(image_path);
|
||||||
|
if (img.data.empty()) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] Failed to load image: %s\n", image_path);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sam3_encode_image(*g_state, *g_model, img)) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] Failed to encode image\n");
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||||
|
float *boxes, int n_box_quads,
|
||||||
|
float threshold) {
|
||||||
|
if (!g_model || !g_state) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
sam3_pvs_params pvs_params;
|
||||||
|
|
||||||
|
// Parse points: each triple is [x, y, label]
|
||||||
|
for (int i = 0; i < n_point_triples; i++) {
|
||||||
|
float x = points[i * 3];
|
||||||
|
float y = points[i * 3 + 1];
|
||||||
|
float label = points[i * 3 + 2];
|
||||||
|
sam3_point pt = {x, y};
|
||||||
|
if (label > 0.5f) {
|
||||||
|
pvs_params.pos_points.push_back(pt);
|
||||||
|
} else {
|
||||||
|
pvs_params.neg_points.push_back(pt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse boxes: each quad is [x1, y1, x2, y2], use only first box
|
||||||
|
if (n_box_quads > 0) {
|
||||||
|
pvs_params.box = {boxes[0], boxes[1], boxes[2], boxes[3]};
|
||||||
|
pvs_params.use_box = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_result = sam3_segment_pvs(*g_state, *g_model, pvs_params);
|
||||||
|
encode_masks_as_png();
|
||||||
|
|
||||||
|
return static_cast<int>(g_result.detections.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold) {
|
||||||
|
if (!g_model || !g_state) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// PCS mode requires SAM 3 (full model with text encoder)
|
||||||
|
if (sam3_is_visual_only(*g_model) ||
|
||||||
|
sam3_get_model_type(*g_model) != SAM3_MODEL_SAM3) {
|
||||||
|
fprintf(stderr, "[sam3-cpp] PCS mode requires full SAM 3 model\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
sam3_pcs_params pcs_params;
|
||||||
|
pcs_params.text_prompt = text_prompt;
|
||||||
|
pcs_params.score_threshold = threshold > 0 ? threshold : 0.5f;
|
||||||
|
|
||||||
|
g_result = sam3_segment_pcs(*g_state, *g_model, pcs_params);
|
||||||
|
encode_masks_as_png();
|
||||||
|
|
||||||
|
return static_cast<int>(g_result.detections.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
int sam3_cpp_get_n_detections(void) {
|
||||||
|
return static_cast<int>(g_result.detections.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
float sam3_cpp_get_detection_x(int i) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||||
|
return g_result.detections[i].box.x0;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sam3_cpp_get_detection_y(int i) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||||
|
return g_result.detections[i].box.y0;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sam3_cpp_get_detection_w(int i) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||||
|
const auto &box = g_result.detections[i].box;
|
||||||
|
return box.x1 - box.x0;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sam3_cpp_get_detection_h(int i) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||||
|
const auto &box = g_result.detections[i].box;
|
||||||
|
return box.y1 - box.y0;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sam3_cpp_get_detection_score(int i) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||||
|
return g_result.detections[i].score;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size) {
|
||||||
|
if (i < 0 || i >= static_cast<int>(g_mask_pngs.size())) return 0;
|
||||||
|
|
||||||
|
const auto &png = g_mask_pngs[i];
|
||||||
|
int size = static_cast<int>(png.size());
|
||||||
|
|
||||||
|
if (buf == nullptr) {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
int to_copy = size < buf_size ? size : buf_size;
|
||||||
|
memcpy(buf, png.data(), to_copy);
|
||||||
|
return to_copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sam3_cpp_free_results(void) {
|
||||||
|
g_result.detections.clear();
|
||||||
|
g_mask_pngs.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
||||||
143
backend/go/sam3-cpp/gosam3.go
Normal file
143
backend/go/sam3-cpp/gosam3.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SAM3 struct {
|
||||||
|
base.SingleThread
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
CppLoadModel func(modelPath string, threads int) int
|
||||||
|
CppEncodeImage func(imagePath string) int
|
||||||
|
CppSegmentPVS func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int
|
||||||
|
CppSegmentPCS func(textPrompt string, threshold float32) int
|
||||||
|
CppGetNDetections func() int
|
||||||
|
CppGetDetectionX func(i int) float32
|
||||||
|
CppGetDetectionY func(i int) float32
|
||||||
|
CppGetDetectionW func(i int) float32
|
||||||
|
CppGetDetectionH func(i int) float32
|
||||||
|
CppGetDetectionScore func(i int) float32
|
||||||
|
CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int
|
||||||
|
CppFreeResults func()
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *SAM3) Load(opts *pb.ModelOptions) error {
|
||||||
|
modelFile := opts.ModelFile
|
||||||
|
if modelFile == "" {
|
||||||
|
modelFile = opts.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelPath string
|
||||||
|
if filepath.IsAbs(modelFile) {
|
||||||
|
modelPath = modelFile
|
||||||
|
} else {
|
||||||
|
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
threads := int(opts.Threads)
|
||||||
|
if threads <= 0 {
|
||||||
|
threads = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := CppLoadModel(modelPath, threads)
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("failed to load SAM3 model (error %d): %s", ret, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||||
|
// Decode base64 image and write to temp file
|
||||||
|
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpFile, err := os.CreateTemp("", "sam3-*.png")
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
if _, err := tmpFile.Write(imgData); err != nil {
|
||||||
|
tmpFile.Close()
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", err)
|
||||||
|
}
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
// Encode image
|
||||||
|
ret := CppEncodeImage(tmpFile.Name())
|
||||||
|
if ret != 0 {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold := opts.Threshold
|
||||||
|
if threshold <= 0 {
|
||||||
|
threshold = 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine segmentation mode
|
||||||
|
var nDetections int
|
||||||
|
if opts.Prompt != "" {
|
||||||
|
// Text-prompted segmentation (PCS mode, SAM 3 only)
|
||||||
|
nDetections = CppSegmentPCS(opts.Prompt, threshold)
|
||||||
|
} else {
|
||||||
|
// Point/box-prompted segmentation (PVS mode)
|
||||||
|
var pointsPtr uintptr
|
||||||
|
var boxesPtr uintptr
|
||||||
|
nPointTriples := len(opts.Points) / 3
|
||||||
|
nBoxQuads := len(opts.Boxes) / 4
|
||||||
|
|
||||||
|
if nPointTriples > 0 {
|
||||||
|
pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
|
||||||
|
}
|
||||||
|
if nBoxQuads > 0 {
|
||||||
|
boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nDetections < 0 {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("segmentation failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer CppFreeResults()
|
||||||
|
|
||||||
|
// Build response
|
||||||
|
detections := make([]*pb.Detection, nDetections)
|
||||||
|
for i := 0; i < nDetections; i++ {
|
||||||
|
det := &pb.Detection{
|
||||||
|
X: CppGetDetectionX(i),
|
||||||
|
Y: CppGetDetectionY(i),
|
||||||
|
Width: CppGetDetectionW(i),
|
||||||
|
Height: CppGetDetectionH(i),
|
||||||
|
Confidence: CppGetDetectionScore(i),
|
||||||
|
ClassName: "segment",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get mask PNG
|
||||||
|
maskSize := CppGetDetectionMaskPNG(i, 0, 0)
|
||||||
|
if maskSize > 0 {
|
||||||
|
maskBuf := make([]byte, maskSize)
|
||||||
|
CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
|
||||||
|
det.Mask = maskBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
detections[i] = det
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.DetectResponse{
|
||||||
|
Detections: detections,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
51
backend/go/sam3-cpp/gosam3.h
Normal file
51
backend/go/sam3-cpp/gosam3.h
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#ifndef GOSAM3_H
|
||||||
|
#define GOSAM3_H
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Load model from file. Returns 0 on success, non-zero on failure.
|
||||||
|
int sam3_cpp_load_model(const char *model_path, int threads);
|
||||||
|
|
||||||
|
// Encode an image from file path. Must be called before segmentation.
|
||||||
|
// Returns 0 on success.
|
||||||
|
int sam3_cpp_encode_image(const char *image_path);
|
||||||
|
|
||||||
|
// Segment with point/box prompts (PVS mode).
|
||||||
|
// points: flat array of [x, y, label] triples (label: 1=positive, 0=negative)
|
||||||
|
// boxes: flat array of [x1, y1, x2, y2] quads
|
||||||
|
// Returns number of detections, or -1 on error.
|
||||||
|
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||||
|
float *boxes, int n_box_quads,
|
||||||
|
float threshold);
|
||||||
|
|
||||||
|
// Segment with text prompt (PCS mode, SAM 3 only).
|
||||||
|
// Returns number of detections, or -1 on error.
|
||||||
|
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold);
|
||||||
|
|
||||||
|
// Access detection results (valid after a segment call).
|
||||||
|
int sam3_cpp_get_n_detections(void);
|
||||||
|
|
||||||
|
// Get bounding box for detection i (as x, y, width, height).
|
||||||
|
float sam3_cpp_get_detection_x(int i);
|
||||||
|
float sam3_cpp_get_detection_y(int i);
|
||||||
|
float sam3_cpp_get_detection_w(int i);
|
||||||
|
float sam3_cpp_get_detection_h(int i);
|
||||||
|
|
||||||
|
// Get confidence score for detection i.
|
||||||
|
float sam3_cpp_get_detection_score(int i);
|
||||||
|
|
||||||
|
// Get mask as PNG-encoded bytes.
|
||||||
|
// If buf is NULL, returns the required buffer size.
|
||||||
|
// Otherwise writes up to buf_size bytes and returns bytes written.
|
||||||
|
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size);
|
||||||
|
|
||||||
|
// Free current detection results.
|
||||||
|
void sam3_cpp_free_results(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // GOSAM3_H
|
||||||
56
backend/go/sam3-cpp/main.go
Normal file
56
backend/go/sam3-cpp/main.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Get library name from environment variable, default to fallback
|
||||||
|
libName := os.Getenv("SAM3_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./libgosam3-fallback.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
gosamLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
libFuncs := []LibFuncs{
|
||||||
|
{&CppLoadModel, "sam3_cpp_load_model"},
|
||||||
|
{&CppEncodeImage, "sam3_cpp_encode_image"},
|
||||||
|
{&CppSegmentPVS, "sam3_cpp_segment_pvs"},
|
||||||
|
{&CppSegmentPCS, "sam3_cpp_segment_pcs"},
|
||||||
|
{&CppGetNDetections, "sam3_cpp_get_n_detections"},
|
||||||
|
{&CppGetDetectionX, "sam3_cpp_get_detection_x"},
|
||||||
|
{&CppGetDetectionY, "sam3_cpp_get_detection_y"},
|
||||||
|
{&CppGetDetectionW, "sam3_cpp_get_detection_w"},
|
||||||
|
{&CppGetDetectionH, "sam3_cpp_get_detection_h"},
|
||||||
|
{&CppGetDetectionScore, "sam3_cpp_get_detection_score"},
|
||||||
|
{&CppGetDetectionMaskPNG, "sam3_cpp_get_detection_mask_png"},
|
||||||
|
{&CppFreeResults, "sam3_cpp_free_results"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, gosamLib, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &SAM3{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
59
backend/go/sam3-cpp/package.sh
Executable file
59
backend/go/sam3-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Script to copy the appropriate libraries based on architecture
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
# Create lib directory
|
||||||
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
|
cp -avf $CURDIR/libgosam3-*.so $CURDIR/package/
|
||||||
|
cp -avf $CURDIR/sam3-cpp $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
|
||||||
|
# Detect architecture and copy appropriate libraries
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
# x86_64 architecture
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
# ARM64 architecture
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ $(uname -s) = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Package GPU libraries based on BUILD_TYPE
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah $CURDIR/package/
|
||||||
|
ls -liah $CURDIR/package/lib/
|
||||||
52
backend/go/sam3-cpp/run.sh
Executable file
52
backend/go/sam3-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Get the absolute current dir where the script is located
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
cd /
|
||||||
|
|
||||||
|
echo "CPU info:"
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||||
|
grep -e "flags" /proc/cpuinfo | head -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LIBRARY="$CURDIR/libgosam3-fallback.so"
|
||||||
|
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX found OK"
|
||||||
|
if [ -e $CURDIR/libgosam3-avx.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgosam3-avx.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX2 found OK"
|
||||||
|
if [ -e $CURDIR/libgosam3-avx2.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgosam3-avx2.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check avx 512
|
||||||
|
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX512F found OK"
|
||||||
|
if [ -e $CURDIR/libgosam3-avx512.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgosam3-avx512.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
|
export SAM3_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
|
# If there is a lib/ld.so, use it
|
||||||
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
|
echo "Using lib/ld.so"
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/lib/ld.so $CURDIR/sam3-cpp "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/sam3-cpp "$@"
|
||||||
50
backend/go/sam3-cpp/test.sh
Executable file
50
backend/go/sam3-cpp/test.sh
Executable file
@@ -0,0 +1,50 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
echo "Running sam3-cpp backend tests..."
|
||||||
|
|
||||||
|
# The test requires a SAM model in GGML format.
|
||||||
|
# Uses EdgeTAM Q4_0 (~15MB) for fast CI testing.
|
||||||
|
SAM3_MODEL_DIR="${SAM3_MODEL_DIR:-$CURDIR/test-models}"
|
||||||
|
SAM3_MODEL_FILE="${SAM3_MODEL_FILE:-edgetam_q4_0.ggml}"
|
||||||
|
SAM3_MODEL_URL="${SAM3_MODEL_URL:-https://huggingface.co/PABannier/sam3.cpp/resolve/main/edgetam_q4_0.ggml}"
|
||||||
|
|
||||||
|
# Download model if not present
|
||||||
|
if [ ! -f "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" ]; then
|
||||||
|
echo "Downloading EdgeTAM Q4_0 model for testing..."
|
||||||
|
mkdir -p "$SAM3_MODEL_DIR"
|
||||||
|
curl -L -o "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" "$SAM3_MODEL_URL" --progress-bar
|
||||||
|
echo "Model downloaded."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create a test image (4x4 red pixel PNG) using base64
|
||||||
|
# This is a minimal valid PNG for testing the pipeline
|
||||||
|
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||||
|
mkdir -p "$TEST_IMAGE_DIR"
|
||||||
|
|
||||||
|
# Generate a simple test image using Python if available, otherwise use a pre-encoded one
|
||||||
|
if command -v python3 &> /dev/null; then
|
||||||
|
python3 -c "
|
||||||
|
import struct, zlib, base64
|
||||||
|
def create_png(width, height, r, g, b):
|
||||||
|
raw = b''
|
||||||
|
for y in range(height):
|
||||||
|
raw += b'\x00' # filter byte
|
||||||
|
for x in range(width):
|
||||||
|
raw += bytes([r, g, b])
|
||||||
|
def chunk(ctype, data):
|
||||||
|
c = ctype + data
|
||||||
|
return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)
|
||||||
|
ihdr = struct.pack('>IIBBBBB', width, height, 8, 2, 0, 0, 0)
|
||||||
|
return b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', zlib.compress(raw)) + chunk(b'IEND', b'')
|
||||||
|
with open('$TEST_IMAGE_DIR/test.png', 'wb') as f:
|
||||||
|
f.write(create_png(64, 64, 255, 0, 0))
|
||||||
|
"
|
||||||
|
echo "Test image created."
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "sam3-cpp test setup complete."
|
||||||
|
echo "Model: $SAM3_MODEL_DIR/$SAM3_MODEL_FILE"
|
||||||
|
echo "Note: Full integration tests run via the LocalAI test-extra target."
|
||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# stablediffusion.cpp (ggml)
|
# stablediffusion.cpp (ggml)
|
||||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||||
STABLEDIFFUSION_GGML_VERSION?=87ecb95cbc65dc8e58e3d88f4f4a59a0939796f5
|
STABLEDIFFUSION_GGML_VERSION?=e8323cabb0e4511ba18a50b1cb34cf1f87fc71ef
|
||||||
|
|
||||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,31 @@
|
|||||||
nvidia-cuda-13: "cuda13-rfdetr"
|
nvidia-cuda-13: "cuda13-rfdetr"
|
||||||
nvidia-cuda-12: "cuda12-rfdetr"
|
nvidia-cuda-12: "cuda12-rfdetr"
|
||||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
|
||||||
|
- &sam3cpp
|
||||||
|
name: "sam3-cpp"
|
||||||
|
alias: "sam3-cpp"
|
||||||
|
license: mit
|
||||||
|
description: |
|
||||||
|
Segment Anything Model (SAM 3/2/EdgeTAM) in C/C++ using GGML.
|
||||||
|
Supports text-prompted and point/box-prompted image segmentation.
|
||||||
|
urls:
|
||||||
|
- https://github.com/PABannier/sam3.cpp
|
||||||
|
tags:
|
||||||
|
- image-segmentation
|
||||||
|
- object-detection
|
||||||
|
- sam3
|
||||||
|
- gpu
|
||||||
|
- cpu
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-sam3-cpp"
|
||||||
|
nvidia: "cuda12-sam3-cpp"
|
||||||
|
nvidia-cuda-12: "cuda12-sam3-cpp"
|
||||||
|
nvidia-cuda-13: "cuda13-sam3-cpp"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
intel: "intel-sycl-f32-sam3-cpp"
|
||||||
|
vulkan: "vulkan-sam3-cpp"
|
||||||
- &vllm
|
- &vllm
|
||||||
name: "vllm"
|
name: "vllm"
|
||||||
license: apache-2.0
|
license: apache-2.0
|
||||||
@@ -400,12 +425,15 @@
|
|||||||
license: MIT
|
license: MIT
|
||||||
name: "faster-whisper"
|
name: "faster-whisper"
|
||||||
capabilities:
|
capabilities:
|
||||||
|
default: "cpu-faster-whisper"
|
||||||
nvidia: "cuda12-faster-whisper"
|
nvidia: "cuda12-faster-whisper"
|
||||||
intel: "intel-faster-whisper"
|
intel: "intel-faster-whisper"
|
||||||
amd: "rocm-faster-whisper"
|
amd: "rocm-faster-whisper"
|
||||||
metal: "metal-faster-whisper"
|
metal: "metal-faster-whisper"
|
||||||
nvidia-cuda-13: "cuda13-faster-whisper"
|
nvidia-cuda-13: "cuda13-faster-whisper"
|
||||||
nvidia-cuda-12: "cuda12-faster-whisper"
|
nvidia-cuda-12: "cuda12-faster-whisper"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper"
|
||||||
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-faster-whisper"
|
||||||
- &moonshine
|
- &moonshine
|
||||||
description: |
|
description: |
|
||||||
Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
|
Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
|
||||||
@@ -438,6 +466,7 @@
|
|||||||
- whisperx
|
- whisperx
|
||||||
license: BSD-4-Clause
|
license: BSD-4-Clause
|
||||||
name: "whisperx"
|
name: "whisperx"
|
||||||
|
alias: "whisperx"
|
||||||
capabilities:
|
capabilities:
|
||||||
nvidia: "cuda12-whisperx"
|
nvidia: "cuda12-whisperx"
|
||||||
amd: "rocm-whisperx"
|
amd: "rocm-whisperx"
|
||||||
@@ -445,6 +474,8 @@
|
|||||||
default: "cpu-whisperx"
|
default: "cpu-whisperx"
|
||||||
nvidia-cuda-13: "cuda13-whisperx"
|
nvidia-cuda-13: "cuda13-whisperx"
|
||||||
nvidia-cuda-12: "cuda12-whisperx"
|
nvidia-cuda-12: "cuda12-whisperx"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-whisperx"
|
||||||
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisperx"
|
||||||
- &kokoro
|
- &kokoro
|
||||||
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||||
description: |
|
description: |
|
||||||
@@ -468,6 +499,26 @@
|
|||||||
nvidia-cuda-13: "cuda13-kokoro"
|
nvidia-cuda-13: "cuda13-kokoro"
|
||||||
nvidia-cuda-12: "cuda12-kokoro"
|
nvidia-cuda-12: "cuda12-kokoro"
|
||||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
|
||||||
|
- &kokoros
|
||||||
|
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||||
|
description: |
|
||||||
|
Kokoros is a pure Rust TTS backend using the Kokoro ONNX model (82M parameters).
|
||||||
|
It provides fast, high-quality text-to-speech with streaming support, built on
|
||||||
|
ONNX Runtime for efficient CPU inference. Supports English, Japanese, Mandarin
|
||||||
|
Chinese, and German.
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/hexgrad/Kokoro-82M
|
||||||
|
- https://github.com/lucasjinreal/Kokoros
|
||||||
|
tags:
|
||||||
|
- text-to-speech
|
||||||
|
- TTS
|
||||||
|
- Rust
|
||||||
|
- ONNX
|
||||||
|
license: apache-2.0
|
||||||
|
alias: "kokoros"
|
||||||
|
name: "kokoros"
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-kokoros"
|
||||||
- &coqui
|
- &coqui
|
||||||
urls:
|
urls:
|
||||||
- https://github.com/idiap/coqui-ai-TTS
|
- https://github.com/idiap/coqui-ai-TTS
|
||||||
@@ -1602,6 +1653,89 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-rfdetr
|
- localai/localai-backends:master-metal-darwin-arm64-rfdetr
|
||||||
|
## sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "sam3-cpp-development"
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-sam3-cpp-development"
|
||||||
|
nvidia: "cuda12-sam3-cpp-development"
|
||||||
|
nvidia-cuda-12: "cuda12-sam3-cpp-development"
|
||||||
|
nvidia-cuda-13: "cuda13-sam3-cpp-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||||
|
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||||
|
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||||
|
intel: "intel-sycl-f32-sam3-cpp-development"
|
||||||
|
vulkan: "vulkan-sam3-cpp-development"
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cpu-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cpu-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda12-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda12-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-nvidia-cuda-12-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda13-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda13-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-nvidia-cuda-13-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-nvidia-l4t-arm64-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-nvidia-l4t-arm64-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "intel-sycl-f32-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-intel-sycl-f32-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "intel-sycl-f32-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-intel-sycl-f32-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "vulkan-sam3-cpp"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-vulkan-sam3-cpp
|
||||||
|
- !!merge <<: *sam3cpp
|
||||||
|
name: "vulkan-sam3-cpp-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-vulkan-sam3-cpp
|
||||||
## Rerankers
|
## Rerankers
|
||||||
- !!merge <<: *rerankers
|
- !!merge <<: *rerankers
|
||||||
name: "rerankers-development"
|
name: "rerankers-development"
|
||||||
@@ -2042,15 +2176,32 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-kokoro
|
- localai/localai-backends:master-metal-darwin-arm64-kokoro
|
||||||
|
## kokoros (Rust)
|
||||||
|
- !!merge <<: *kokoros
|
||||||
|
name: "kokoros-development"
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-kokoros-development"
|
||||||
|
- !!merge <<: *kokoros
|
||||||
|
name: "cpu-kokoros"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-kokoros"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-kokoros
|
||||||
|
- !!merge <<: *kokoros
|
||||||
|
name: "cpu-kokoros-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-kokoros"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-kokoros
|
||||||
## faster-whisper
|
## faster-whisper
|
||||||
- !!merge <<: *faster-whisper
|
- !!merge <<: *faster-whisper
|
||||||
name: "faster-whisper-development"
|
name: "faster-whisper-development"
|
||||||
capabilities:
|
capabilities:
|
||||||
|
default: "cpu-faster-whisper-development"
|
||||||
nvidia: "cuda12-faster-whisper-development"
|
nvidia: "cuda12-faster-whisper-development"
|
||||||
intel: "intel-faster-whisper-development"
|
intel: "intel-faster-whisper-development"
|
||||||
amd: "rocm-faster-whisper-development"
|
amd: "rocm-faster-whisper-development"
|
||||||
metal: "metal-faster-whisper-development"
|
metal: "metal-faster-whisper-development"
|
||||||
nvidia-cuda-13: "cuda13-faster-whisper-development"
|
nvidia-cuda-13: "cuda13-faster-whisper-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper-development"
|
||||||
- !!merge <<: *faster-whisper
|
- !!merge <<: *faster-whisper
|
||||||
name: "cuda12-faster-whisper-development"
|
name: "cuda12-faster-whisper-development"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
|
||||||
@@ -2091,6 +2242,36 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-faster-whisper
|
- localai/localai-backends:master-metal-darwin-arm64-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "cuda12-faster-whisper"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "rocm-faster-whisper"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-rocm-hipblas-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "cpu-faster-whisper"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "cpu-faster-whisper-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "nvidia-l4t-arm64-faster-whisper"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-nvidia-l4t-faster-whisper
|
||||||
|
- !!merge <<: *faster-whisper
|
||||||
|
name: "nvidia-l4t-arm64-faster-whisper-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-whisper"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-nvidia-l4t-faster-whisper
|
||||||
## moonshine
|
## moonshine
|
||||||
- !!merge <<: *moonshine
|
- !!merge <<: *moonshine
|
||||||
name: "moonshine-development"
|
name: "moonshine-development"
|
||||||
@@ -2149,6 +2330,7 @@
|
|||||||
default: "cpu-whisperx-development"
|
default: "cpu-whisperx-development"
|
||||||
nvidia-cuda-13: "cuda13-whisperx-development"
|
nvidia-cuda-13: "cuda13-whisperx-development"
|
||||||
nvidia-cuda-12: "cuda12-whisperx-development"
|
nvidia-cuda-12: "cuda12-whisperx-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-whisperx-development"
|
||||||
- !!merge <<: *whisperx
|
- !!merge <<: *whisperx
|
||||||
name: "cpu-whisperx"
|
name: "cpu-whisperx"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx"
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx"
|
||||||
@@ -2199,6 +2381,16 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-whisperx
|
- localai/localai-backends:master-metal-darwin-arm64-whisperx
|
||||||
|
- !!merge <<: *whisperx
|
||||||
|
name: "nvidia-l4t-arm64-whisperx"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-whisperx"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-nvidia-l4t-whisperx
|
||||||
|
- !!merge <<: *whisperx
|
||||||
|
name: "nvidia-l4t-arm64-whisperx-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-whisperx"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-nvidia-l4t-whisperx
|
||||||
## coqui
|
## coqui
|
||||||
|
|
||||||
- !!merge <<: *coqui
|
- !!merge <<: *coqui
|
||||||
|
|||||||
@@ -16,4 +16,14 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
|||||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||||
|
PYTHON_VERSION="3.12"
|
||||||
|
PYTHON_PATCH="12"
|
||||||
|
PY_STANDALONE_TAG="20251120"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||||
|
USE_PIP=true
|
||||||
|
fi
|
||||||
|
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|||||||
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||||
|
torch
|
||||||
|
faster-whisper
|
||||||
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||||
|
torch
|
||||||
|
faster-whisper
|
||||||
@@ -147,7 +147,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.language and request.language.strip():
|
if request.language and request.language.strip():
|
||||||
language = request.language.strip()
|
language = request.language.strip()
|
||||||
|
|
||||||
results = self.model.transcribe(audio=audio_path, language=language)
|
context = ""
|
||||||
|
if request.prompt and request.prompt.strip():
|
||||||
|
context = request.prompt.strip()
|
||||||
|
|
||||||
|
results = self.model.transcribe(audio=audio_path, language=language, context=context)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||||
|
|||||||
@@ -8,8 +8,21 @@ else
|
|||||||
source $backend_dir/../common/libbackend.sh
|
source $backend_dir/../common/libbackend.sh
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
|
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
|
PYTHON_VERSION="3.12"
|
||||||
|
PYTHON_PATCH="12"
|
||||||
|
PY_STANDALONE_TAG="20251120"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||||
|
USE_PIP=true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# --index-strategy is a uv-only flag; skip it when using pip
|
||||||
|
if [ "x${USE_PIP}" != "xtrue" ]; then
|
||||||
|
if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|||||||
3
backend/python/whisperx/requirements-l4t12.txt
Normal file
3
backend/python/whisperx/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||||
|
torch
|
||||||
|
whisperx @ git+https://github.com/m-bain/whisperX.git
|
||||||
3
backend/python/whisperx/requirements-l4t13.txt
Normal file
3
backend/python/whisperx/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||||
|
torch
|
||||||
|
whisperx @ git+https://github.com/m-bain/whisperX.git
|
||||||
3
backend/rust/kokoros/.gitignore
vendored
Normal file
3
backend/rust/kokoros/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
/target/
|
||||||
|
/proto/
|
||||||
|
/package/
|
||||||
3074
backend/rust/kokoros/Cargo.lock
generated
Normal file
3074
backend/rust/kokoros/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
backend/rust/kokoros/Cargo.toml
Normal file
26
backend/rust/kokoros/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
[package]
|
||||||
|
name = "kokoros-grpc"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "kokoros-grpc"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
kokoros = { path = "sources/Kokoros/kokoros" }
|
||||||
|
|
||||||
|
tonic = "0.13"
|
||||||
|
prost = "0.13"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tonic-build = "0.13"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["cpu"]
|
||||||
|
cpu = ["kokoros/cpu"]
|
||||||
25
backend/rust/kokoros/Makefile
Normal file
25
backend/rust/kokoros/Makefile
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
CURRENT_DIR=$(abspath ./)
|
||||||
|
|
||||||
|
.PHONY: kokoros-grpc
|
||||||
|
kokoros-grpc:
|
||||||
|
mkdir -p $(CURRENT_DIR)/proto
|
||||||
|
cp $(CURRENT_DIR)/../../backend.proto $(CURRENT_DIR)/proto/backend.proto
|
||||||
|
cd $(CURRENT_DIR) && \
|
||||||
|
BACKEND_PROTO_PATH=$(CURRENT_DIR)/proto/backend.proto \
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
.PHONY: package
|
||||||
|
package:
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: kokoros-grpc
|
||||||
|
cd $(CURRENT_DIR) && cargo test
|
||||||
|
|
||||||
|
.PHONY: build
|
||||||
|
build: kokoros-grpc package
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean:
|
||||||
|
cargo clean
|
||||||
|
rm -rf package proto
|
||||||
15
backend/rust/kokoros/build.rs
Normal file
15
backend/rust/kokoros/build.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let proto_path = std::env::var("BACKEND_PROTO_PATH")
|
||||||
|
.unwrap_or_else(|_| "proto/backend.proto".to_string());
|
||||||
|
|
||||||
|
let proto_dir = std::path::Path::new(&proto_path)
|
||||||
|
.parent()
|
||||||
|
.unwrap_or(std::path::Path::new("."));
|
||||||
|
|
||||||
|
tonic_build::configure()
|
||||||
|
.build_server(true)
|
||||||
|
.build_client(false)
|
||||||
|
.compile_protos(&[&proto_path], &[proto_dir])?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
42
backend/rust/kokoros/package.sh
Normal file
42
backend/rust/kokoros/package.sh
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
|
# Copy the binary and run script
|
||||||
|
cp -avf $CURDIR/target/release/kokoros-grpc $CURDIR/package/
|
||||||
|
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
chmod +x $CURDIR/package/run.sh
|
||||||
|
|
||||||
|
# Copy espeak-ng data
|
||||||
|
if [ -d "/usr/share/espeak-ng-data" ]; then
|
||||||
|
cp -rf /usr/share/espeak-ng-data $CURDIR/package/
|
||||||
|
elif [ -d "/usr/lib/x86_64-linux-gnu/espeak-ng-data" ]; then
|
||||||
|
cp -rf /usr/lib/x86_64-linux-gnu/espeak-ng-data $CURDIR/package/
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Bundle all dynamic library dependencies
|
||||||
|
echo "Bundling dynamic library dependencies..."
|
||||||
|
ldd $CURDIR/target/release/kokoros-grpc | grep "=>" | awk '{print $3}' | while read lib; do
|
||||||
|
if [ -n "$lib" ] && [ -f "$lib" ]; then
|
||||||
|
cp -avfL "$lib" $CURDIR/package/lib/
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Copy CA certificates for HTTPS (needed for model auto-download)
|
||||||
|
if [ -d "/etc/ssl/certs" ]; then
|
||||||
|
mkdir -p $CURDIR/package/etc/ssl
|
||||||
|
cp -rf /etc/ssl/certs $CURDIR/package/etc/ssl/
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Copy the dynamic linker
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah $CURDIR/package/
|
||||||
|
ls -liah $CURDIR/package/lib/
|
||||||
23
backend/rust/kokoros/run.sh
Executable file
23
backend/rust/kokoros/run.sh
Executable file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$CURDIR/lib:${LD_LIBRARY_PATH:-}
|
||||||
|
|
||||||
|
# SSL certificates for model auto-download
|
||||||
|
if [ -d "$CURDIR/etc/ssl/certs" ]; then
|
||||||
|
export SSL_CERT_DIR=$CURDIR/etc/ssl/certs
|
||||||
|
fi
|
||||||
|
|
||||||
|
# espeak-ng data directory
|
||||||
|
if [ -d "$CURDIR/espeak-ng-data" ]; then
|
||||||
|
export ESPEAK_NG_DATA=$CURDIR/espeak-ng-data
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use bundled ld.so if present (portability)
|
||||||
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
|
exec $CURDIR/lib/ld.so $CURDIR/kokoros-grpc "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exec $CURDIR/kokoros-grpc "$@"
|
||||||
1
backend/rust/kokoros/sources/Kokoros
Submodule
1
backend/rust/kokoros/sources/Kokoros
Submodule
Submodule backend/rust/kokoros/sources/Kokoros added at 7089168f0c
26
backend/rust/kokoros/src/auth.rs
Normal file
26
backend/rust/kokoros/src/auth.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
use tonic::{Request, Status};
|
||||||
|
|
||||||
|
/// Returns an interceptor function if LOCALAI_GRPC_AUTH_TOKEN is set.
|
||||||
|
pub fn make_auth_interceptor(
|
||||||
|
) -> Option<impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone> {
|
||||||
|
let token = std::env::var("LOCALAI_GRPC_AUTH_TOKEN").ok()?;
|
||||||
|
if token.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let expected = format!("Bearer {}", token);
|
||||||
|
Some(
|
||||||
|
move |req: Request<()>| -> Result<Request<()>, Status> {
|
||||||
|
let meta = req.metadata();
|
||||||
|
match meta.get("authorization") {
|
||||||
|
Some(val) => {
|
||||||
|
if val.as_bytes() == expected.as_bytes() {
|
||||||
|
Ok(req)
|
||||||
|
} else {
|
||||||
|
Err(Status::unauthenticated("invalid token"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Err(Status::unauthenticated("missing authorization")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
53
backend/rust/kokoros/src/main.rs
Normal file
53
backend/rust/kokoros/src/main.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use tonic::transport::Server;
|
||||||
|
|
||||||
|
mod auth;
|
||||||
|
mod service;
|
||||||
|
|
||||||
|
pub mod backend {
|
||||||
|
tonic::include_proto!("backend");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "kokoros-grpc")]
|
||||||
|
struct Cli {
|
||||||
|
/// gRPC listen address (host:port)
|
||||||
|
#[arg(long, default_value = "localhost:50051")]
|
||||||
|
addr: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_ansi(false)
|
||||||
|
.without_time()
|
||||||
|
.with_env_filter(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let cli = Cli::parse();
|
||||||
|
let addr = cli.addr.parse()?;
|
||||||
|
|
||||||
|
tracing::info!("Starting kokoros gRPC server on {}", addr);
|
||||||
|
|
||||||
|
let mut builder = Server::builder();
|
||||||
|
|
||||||
|
if let Some(interceptor) = auth::make_auth_interceptor() {
|
||||||
|
tracing::info!("Bearer token authentication enabled");
|
||||||
|
let svc = backend::backend_server::BackendServer::with_interceptor(
|
||||||
|
service::KokorosService::default(),
|
||||||
|
interceptor,
|
||||||
|
);
|
||||||
|
builder.add_service(svc).serve(addr).await?;
|
||||||
|
} else {
|
||||||
|
let svc = backend::backend_server::BackendServer::new(service::KokorosService::default())
|
||||||
|
.max_decoding_message_size(50 * 1024 * 1024)
|
||||||
|
.max_encoding_message_size(50 * 1024 * 1024);
|
||||||
|
builder.add_service(svc).serve(addr).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
652
backend/rust/kokoros/src/service.rs
Normal file
652
backend/rust/kokoros/src/service.rs
Normal file
@@ -0,0 +1,652 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tonic::{Request, Response, Status};
|
||||||
|
|
||||||
|
use kokoros::tts::koko::TTSKoko;
|
||||||
|
|
||||||
|
use crate::backend;
|
||||||
|
use crate::backend::backend_server::Backend;
|
||||||
|
|
||||||
|
/// Write f32 samples as a standard 44-byte PCM 16-bit WAV file.
|
||||||
|
/// LocalAI's audio pipeline assumes this exact header layout.
|
||||||
|
fn write_pcm16_wav(
|
||||||
|
path: &str,
|
||||||
|
samples: &[f32],
|
||||||
|
sample_rate: u32,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
let num_samples = samples.len() as u32;
|
||||||
|
let data_size = num_samples * 2; // 16-bit = 2 bytes per sample
|
||||||
|
let file_size = 36 + data_size;
|
||||||
|
|
||||||
|
let mut f = File::create(path)?;
|
||||||
|
|
||||||
|
// RIFF header
|
||||||
|
f.write_all(b"RIFF")?;
|
||||||
|
f.write_all(&file_size.to_le_bytes())?;
|
||||||
|
f.write_all(b"WAVE")?;
|
||||||
|
|
||||||
|
// fmt chunk — standard 16-byte PCM format
|
||||||
|
f.write_all(b"fmt ")?;
|
||||||
|
f.write_all(&16u32.to_le_bytes())?; // chunk size
|
||||||
|
f.write_all(&1u16.to_le_bytes())?; // audio format = PCM
|
||||||
|
f.write_all(&1u16.to_le_bytes())?; // channels = mono
|
||||||
|
f.write_all(&sample_rate.to_le_bytes())?;
|
||||||
|
f.write_all(&(sample_rate * 2).to_le_bytes())?; // byte rate
|
||||||
|
f.write_all(&2u16.to_le_bytes())?; // block align
|
||||||
|
f.write_all(&16u16.to_le_bytes())?; // bits per sample
|
||||||
|
|
||||||
|
// data chunk
|
||||||
|
f.write_all(b"data")?;
|
||||||
|
f.write_all(&data_size.to_le_bytes())?;
|
||||||
|
|
||||||
|
for &s in samples {
|
||||||
|
let clamped = s.clamp(-1.0, 1.0);
|
||||||
|
let pcm = (clamped * 32767.0) as i16;
|
||||||
|
f.write_all(&pcm.to_le_bytes())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KokorosService {
|
||||||
|
tts: Arc<TokioMutex<Option<TTSKoko>>>,
|
||||||
|
language: Arc<Mutex<String>>,
|
||||||
|
speed: Arc<Mutex<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for KokorosService {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
tts: Arc::new(TokioMutex::new(None)),
|
||||||
|
language: Arc::new(Mutex::new("en-us".to_string())),
|
||||||
|
speed: Arc::new(Mutex::new(1.0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tonic::async_trait]
|
||||||
|
impl Backend for KokorosService {
|
||||||
|
async fn health(
|
||||||
|
&self,
|
||||||
|
_req: Request<backend::HealthMessage>,
|
||||||
|
) -> Result<Response<backend::Reply>, Status> {
|
||||||
|
Ok(Response::new(backend::Reply {
|
||||||
|
message: b"OK".to_vec(),
|
||||||
|
..Default::default()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_model(
|
||||||
|
&self,
|
||||||
|
req: Request<backend::ModelOptions>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
let opts = req.into_inner();
|
||||||
|
|
||||||
|
// Model path: join ModelPath + Model, or just Model
|
||||||
|
let model_path = if !opts.model_path.is_empty() && !opts.model.is_empty() {
|
||||||
|
format!("{}/{}", opts.model_path, opts.model)
|
||||||
|
} else if !opts.model.is_empty() {
|
||||||
|
opts.model.clone()
|
||||||
|
} else {
|
||||||
|
"checkpoints/kokoro-v1.0.onnx".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Voices data path from AudioPath, or derive from model dir
|
||||||
|
let voices_path = if !opts.audio_path.is_empty() {
|
||||||
|
opts.audio_path.clone()
|
||||||
|
} else {
|
||||||
|
let model_dir = std::path::Path::new(&model_path)
|
||||||
|
.parent()
|
||||||
|
.map(|p| p.to_string_lossy().to_string())
|
||||||
|
.unwrap_or_else(|| ".".to_string());
|
||||||
|
format!("{}/voices-v1.0.bin", model_dir)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse options (key:value pairs)
|
||||||
|
for opt in &opts.options {
|
||||||
|
if let Some((key, value)) = opt.split_once(':') {
|
||||||
|
match key {
|
||||||
|
"lang_code" => *self.language.lock().unwrap() = value.to_string(),
|
||||||
|
"speed" => {
|
||||||
|
if let Ok(s) = value.parse::<f32>() {
|
||||||
|
*self.speed.lock().unwrap() = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("Loading Kokoros model from: {}", model_path);
|
||||||
|
tracing::info!("Loading voices from: {}", voices_path);
|
||||||
|
tracing::info!("Language: {}", self.language.lock().unwrap());
|
||||||
|
|
||||||
|
let tts = TTSKoko::new(&model_path, &voices_path).await;
|
||||||
|
*self.tts.lock().await = Some(tts);
|
||||||
|
|
||||||
|
tracing::info!("Kokoros TTS model loaded successfully");
|
||||||
|
Ok(Response::new(backend::Result {
|
||||||
|
success: true,
|
||||||
|
message: "Kokoros TTS model loaded".into(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tts(
|
||||||
|
&self,
|
||||||
|
req: Request<backend::TtsRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
let req = req.into_inner();
|
||||||
|
let tts_guard = self.tts.lock().await;
|
||||||
|
let tts = tts_guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?;
|
||||||
|
|
||||||
|
let voice = if req.voice.is_empty() {
|
||||||
|
"af_heart"
|
||||||
|
} else {
|
||||||
|
&req.voice
|
||||||
|
};
|
||||||
|
let lang = req
|
||||||
|
.language
|
||||||
|
.filter(|l| !l.is_empty())
|
||||||
|
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||||
|
let speed = *self.speed.lock().unwrap();
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
text = req.text,
|
||||||
|
voice = voice,
|
||||||
|
lang = lang.as_str(),
|
||||||
|
dst = req.dst,
|
||||||
|
"TTS request received"
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
match tts.tts_raw_audio(&req.text, &lang, voice, speed, None, None, None, None) {
|
||||||
|
Ok(samples) => {
|
||||||
|
let duration_secs = samples.len() as f64 / 24000.0;
|
||||||
|
tracing::info!(
|
||||||
|
num_samples = samples.len(),
|
||||||
|
audio_duration = format!("{:.2}s", duration_secs),
|
||||||
|
inference_time = format!("{:.2}s", start.elapsed().as_secs_f64()),
|
||||||
|
dst = req.dst,
|
||||||
|
"TTS inference complete"
|
||||||
|
);
|
||||||
|
if let Err(e) = write_pcm16_wav(&req.dst, &samples, 24000) {
|
||||||
|
tracing::error!("Failed to write WAV to {}: {}", req.dst, e);
|
||||||
|
return Ok(Response::new(backend::Result {
|
||||||
|
success: false,
|
||||||
|
message: format!("Failed to write WAV: {}", e),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Ok(Response::new(backend::Result {
|
||||||
|
success: true,
|
||||||
|
message: String::new(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("TTS error: {}", e);
|
||||||
|
Ok(Response::new(backend::Result {
|
||||||
|
success: false,
|
||||||
|
message: format!("TTS error: {}", e),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TTSStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||||
|
|
||||||
|
async fn tts_stream(
|
||||||
|
&self,
|
||||||
|
req: Request<backend::TtsRequest>,
|
||||||
|
) -> Result<Response<Self::TTSStreamStream>, Status> {
|
||||||
|
let req = req.into_inner();
|
||||||
|
let tts_guard = self.tts.lock().await;
|
||||||
|
let tts = tts_guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let voice = if req.voice.is_empty() {
|
||||||
|
"af_heart".to_string()
|
||||||
|
} else {
|
||||||
|
req.voice
|
||||||
|
};
|
||||||
|
let lang = req
|
||||||
|
.language
|
||||||
|
.filter(|l| !l.is_empty())
|
||||||
|
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||||
|
let speed = *self.speed.lock().unwrap();
|
||||||
|
let text = req.text;
|
||||||
|
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel(32);
|
||||||
|
|
||||||
|
// Send sample rate info as first message
|
||||||
|
let tx_clone = tx.clone();
|
||||||
|
let _ = tx_clone
|
||||||
|
.send(Ok(backend::Reply {
|
||||||
|
message: br#"{"sample_rate":24000}"#.to_vec(),
|
||||||
|
..Default::default()
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let result = tts.tts_raw_audio_streaming(
|
||||||
|
&text,
|
||||||
|
&lang,
|
||||||
|
&voice,
|
||||||
|
speed,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
|audio_chunk: Vec<f32>| -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Convert f32 PCM to 16-bit PCM bytes (what LocalAI expects for streaming)
|
||||||
|
let bytes: Vec<u8> = audio_chunk
|
||||||
|
.iter()
|
||||||
|
.flat_map(|&s| {
|
||||||
|
let clamped = s.clamp(-1.0, 1.0);
|
||||||
|
let i16_val = (clamped * 32767.0) as i16;
|
||||||
|
i16_val.to_le_bytes()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
tx.blocking_send(Ok(backend::Reply {
|
||||||
|
audio: bytes,
|
||||||
|
..Default::default()
|
||||||
|
}))
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if let Err(e) = result {
|
||||||
|
tracing::error!("TTSStream error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Response::new(ReceiverStream::new(rx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn status(
|
||||||
|
&self,
|
||||||
|
_req: Request<backend::HealthMessage>,
|
||||||
|
) -> Result<Response<backend::StatusResponse>, Status> {
|
||||||
|
let tts = self.tts.lock().await;
|
||||||
|
let state = if tts.is_some() {
|
||||||
|
backend::status_response::State::Ready as i32
|
||||||
|
} else {
|
||||||
|
backend::status_response::State::Uninitialized as i32
|
||||||
|
};
|
||||||
|
Ok(Response::new(backend::StatusResponse {
|
||||||
|
state,
|
||||||
|
memory: None,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn free(
|
||||||
|
&self,
|
||||||
|
_req: Request<backend::HealthMessage>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
*self.tts.lock().await = None;
|
||||||
|
Ok(Response::new(backend::Result {
|
||||||
|
success: true,
|
||||||
|
message: "Model freed".into(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Unimplemented RPCs ---
|
||||||
|
|
||||||
|
async fn predict(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::PredictOptions>,
|
||||||
|
) -> Result<Response<backend::Reply>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type PredictStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||||
|
|
||||||
|
async fn predict_stream(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::PredictOptions>,
|
||||||
|
) -> Result<Response<Self::PredictStreamStream>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embedding(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::PredictOptions>,
|
||||||
|
) -> Result<Response<backend::EmbeddingResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_image(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::GenerateImageRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_video(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::GenerateVideoRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn audio_transcription(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::TranscriptRequest>,
|
||||||
|
) -> Result<Response<backend::TranscriptResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sound_generation(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::SoundGenerationRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tokenize_string(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::PredictOptions>,
|
||||||
|
) -> Result<Response<backend::TokenizationResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn detect(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::DetectOptions>,
|
||||||
|
) -> Result<Response<backend::DetectResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stores_set(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::StoresSetOptions>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stores_delete(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::StoresDeleteOptions>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stores_get(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::StoresGetOptions>,
|
||||||
|
) -> Result<Response<backend::StoresGetResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stores_find(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::StoresFindOptions>,
|
||||||
|
) -> Result<Response<backend::StoresFindResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn rerank(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::RerankRequest>,
|
||||||
|
) -> Result<Response<backend::RerankResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_metrics(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::MetricsRequest>,
|
||||||
|
) -> Result<Response<backend::MetricsResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn vad(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::VadRequest>,
|
||||||
|
) -> Result<Response<backend::VadResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn audio_encode(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::AudioEncodeRequest>,
|
||||||
|
) -> Result<Response<backend::AudioEncodeResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn audio_decode(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::AudioDecodeRequest>,
|
||||||
|
) -> Result<Response<backend::AudioDecodeResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_metadata(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::ModelOptions>,
|
||||||
|
) -> Result<Response<backend::ModelMetadataResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_fine_tune(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::FineTuneRequest>,
|
||||||
|
) -> Result<Response<backend::FineTuneJobResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type FineTuneProgressStream = ReceiverStream<Result<backend::FineTuneProgressUpdate, Status>>;
|
||||||
|
|
||||||
|
async fn fine_tune_progress(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::FineTuneProgressRequest>,
|
||||||
|
) -> Result<Response<Self::FineTuneProgressStream>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop_fine_tune(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::FineTuneStopRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_checkpoints(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::ListCheckpointsRequest>,
|
||||||
|
) -> Result<Response<backend::ListCheckpointsResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn export_model(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::ExportModelRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_quantization(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::QuantizationRequest>,
|
||||||
|
) -> Result<Response<backend::QuantizationJobResult>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type QuantizationProgressStream =
|
||||||
|
ReceiverStream<Result<backend::QuantizationProgressUpdate, Status>>;
|
||||||
|
|
||||||
|
async fn quantization_progress(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::QuantizationProgressRequest>,
|
||||||
|
) -> Result<Response<Self::QuantizationProgressStream>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop_quantization(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::QuantizationStopRequest>,
|
||||||
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wav_header_is_standard_pcm16() {
|
||||||
|
let samples = vec![0.0f32, 0.5, -0.5, 1.0, -1.0];
|
||||||
|
let path = std::env::temp_dir().join("kokoros_test.wav");
|
||||||
|
let path_str = path.to_str().unwrap();
|
||||||
|
|
||||||
|
write_pcm16_wav(path_str, &samples, 24000).unwrap();
|
||||||
|
|
||||||
|
let data = std::fs::read(&path).unwrap();
|
||||||
|
std::fs::remove_file(&path).unwrap();
|
||||||
|
|
||||||
|
// Must be exactly 44-byte header + data
|
||||||
|
assert_eq!(data.len(), 44 + samples.len() * 2);
|
||||||
|
|
||||||
|
// RIFF header
|
||||||
|
assert_eq!(&data[0..4], b"RIFF");
|
||||||
|
assert_eq!(&data[8..12], b"WAVE");
|
||||||
|
|
||||||
|
// fmt chunk: 16 bytes, format=1 (PCM), channels=1, 16-bit
|
||||||
|
assert_eq!(&data[12..16], b"fmt ");
|
||||||
|
assert_eq!(u32::from_le_bytes(data[16..20].try_into().unwrap()), 16); // chunk size
|
||||||
|
assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM format
|
||||||
|
assert_eq!(u16::from_le_bytes(data[22..24].try_into().unwrap()), 1); // mono
|
||||||
|
assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
|
||||||
|
assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // bits per sample
|
||||||
|
|
||||||
|
// data chunk
|
||||||
|
assert_eq!(&data[36..40], b"data");
|
||||||
|
assert_eq!(
|
||||||
|
u32::from_le_bytes(data[40..44].try_into().unwrap()),
|
||||||
|
(samples.len() * 2) as u32
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify sample values: 0.5 -> 16383, -0.5 -> -16383, 1.0 -> 32767, -1.0 -> -32767
|
||||||
|
let s1 = i16::from_le_bytes(data[46..48].try_into().unwrap());
|
||||||
|
assert_eq!(s1, 16383); // 0.5 * 32767
|
||||||
|
let s3 = i16::from_le_bytes(data[50..52].try_into().unwrap());
|
||||||
|
assert_eq!(s3, 32767); // 1.0 clamped
|
||||||
|
let s4 = i16::from_le_bytes(data[52..54].try_into().unwrap());
|
||||||
|
assert_eq!(s4, -32767); // -1.0 clamped
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Integration test: runs actual TTS inference and validates the output audio.
|
||||||
|
/// Skipped unless KOKOROS_MODEL_PATH is set to a directory containing
|
||||||
|
/// kokoro-v1.0.onnx and voices-v1.0.bin.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tts_produces_valid_speech() {
|
||||||
|
let model_dir = match std::env::var("KOKOROS_MODEL_PATH") {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => {
|
||||||
|
eprintln!("KOKOROS_MODEL_PATH not set, skipping integration test");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_path = format!("{}/kokoro-v1.0.onnx", model_dir);
|
||||||
|
let voices_path = format!("{}/voices-v1.0.bin", model_dir);
|
||||||
|
|
||||||
|
if !std::path::Path::new(&model_path).exists() {
|
||||||
|
eprintln!("Model file not found at {}, skipping", model_path);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tts = TTSKoko::new(&model_path, &voices_path).await;
|
||||||
|
|
||||||
|
let input_text = "Hello world, this is a test of speech synthesis.";
|
||||||
|
let out_path = std::env::temp_dir().join("kokoros_integration_test.wav");
|
||||||
|
let out_str = out_path.to_str().unwrap();
|
||||||
|
|
||||||
|
let samples = tts
|
||||||
|
.tts_raw_audio(input_text, "en-us", "af_heart", 1.0, None, None, None, None)
|
||||||
|
.expect("tts_raw_audio failed");
|
||||||
|
|
||||||
|
write_pcm16_wav(out_str, &samples, 24000).unwrap();
|
||||||
|
|
||||||
|
let data = std::fs::read(&out_path).unwrap();
|
||||||
|
std::fs::remove_file(&out_path).unwrap();
|
||||||
|
|
||||||
|
// --- WAV header sanity ---
|
||||||
|
assert_eq!(&data[0..4], b"RIFF");
|
||||||
|
assert_eq!(&data[8..12], b"WAVE");
|
||||||
|
assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM
|
||||||
|
assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
|
||||||
|
assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // 16-bit
|
||||||
|
|
||||||
|
let num_samples = samples.len();
|
||||||
|
let duration_secs = num_samples as f64 / 24000.0;
|
||||||
|
|
||||||
|
// --- Duration check ---
|
||||||
|
// ~10 words should produce roughly 2-8 seconds of speech
|
||||||
|
assert!(
|
||||||
|
duration_secs > 1.0,
|
||||||
|
"Audio too short: {:.2}s for {} words",
|
||||||
|
duration_secs,
|
||||||
|
input_text.split_whitespace().count()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
duration_secs < 15.0,
|
||||||
|
"Audio too long: {:.2}s for {} words",
|
||||||
|
duration_secs,
|
||||||
|
input_text.split_whitespace().count()
|
||||||
|
);
|
||||||
|
|
||||||
|
// --- Energy check: not silence ---
|
||||||
|
let rms = (samples.iter().map(|s| s * s).sum::<f32>() / num_samples as f32).sqrt();
|
||||||
|
assert!(
|
||||||
|
rms > 0.01,
|
||||||
|
"Audio is near-silence: RMS = {:.6}",
|
||||||
|
rms
|
||||||
|
);
|
||||||
|
|
||||||
|
// --- Not clipped/saturated: should have dynamic range ---
|
||||||
|
let max_abs = samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
|
||||||
|
assert!(
|
||||||
|
max_abs < 1.0,
|
||||||
|
"Audio is fully saturated (max |sample| = {:.4})",
|
||||||
|
max_abs
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
max_abs > 0.05,
|
||||||
|
"Audio has very low amplitude (max |sample| = {:.4})",
|
||||||
|
max_abs
|
||||||
|
);
|
||||||
|
|
||||||
|
// --- Speech-like spectral check ---
|
||||||
|
// Speech should have significant energy variation (not white noise or DC).
|
||||||
|
// Check that the signal has zero-crossings in a speech-like range (roughly
|
||||||
|
// 50-400 crossings per 24000 samples = 100-8000 Hz fundamental range).
|
||||||
|
let zero_crossings: usize = samples
|
||||||
|
.windows(2)
|
||||||
|
.filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
|
||||||
|
.count();
|
||||||
|
let crossings_per_sec = zero_crossings as f64 / duration_secs;
|
||||||
|
// White noise at 24kHz would have ~12000 crossings/sec.
|
||||||
|
// Speech is typically 100-4000 crossings/sec.
|
||||||
|
assert!(
|
||||||
|
crossings_per_sec < 10000.0,
|
||||||
|
"Too many zero crossings ({:.0}/s) — likely noise, not speech",
|
||||||
|
crossings_per_sec
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
crossings_per_sec > 50.0,
|
||||||
|
"Too few zero crossings ({:.0}/s) — likely DC or silence, not speech",
|
||||||
|
crossings_per_sec
|
||||||
|
);
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
"Integration test passed: duration={:.2}s, rms={:.4}, max={:.4}, zero_crossings={:.0}/s",
|
||||||
|
duration_secs, rms, max_abs, crossings_per_sec
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,9 @@ type Application struct {
|
|||||||
|
|
||||||
// Distributed mode services (nil when not in distributed mode)
|
// Distributed mode services (nil when not in distributed mode)
|
||||||
distributed *DistributedServices
|
distributed *DistributedServices
|
||||||
|
|
||||||
|
// Upgrade checker (background service for detecting backend upgrades)
|
||||||
|
upgradeChecker *UpgradeChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||||
@@ -79,6 +82,19 @@ func (a *Application) AgentJobService() *agentpool.AgentJobService {
|
|||||||
return a.agentJobService
|
return a.agentJobService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Application) UpgradeChecker() *UpgradeChecker {
|
||||||
|
return a.upgradeChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
// distributedDB returns the PostgreSQL database for distributed coordination,
|
||||||
|
// or nil in standalone mode.
|
||||||
|
func (a *Application) distributedDB() *gorm.DB {
|
||||||
|
if a.distributed != nil {
|
||||||
|
return a.authDB
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||||
return a.agentPoolService.Load()
|
return a.agentPoolService.Load()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -335,6 +335,9 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
|||||||
if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
|
if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
|
||||||
appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
||||||
}
|
}
|
||||||
|
if settings.AutoUpgradeBackends != nil {
|
||||||
|
appConfig.AutoUpgradeBackends = *settings.AutoUpgradeBackends
|
||||||
|
}
|
||||||
if settings.ApiKeys != nil {
|
if settings.ApiKeys != nil {
|
||||||
// API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys
|
// API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys
|
||||||
// If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys
|
// If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys
|
||||||
|
|||||||
@@ -231,6 +231,15 @@ func New(opts ...config.AppOption) (*Application, error) {
|
|||||||
xlog.Error("error registering external backends", "error", err)
|
xlog.Error("error registering external backends", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start background upgrade checker for backends.
|
||||||
|
// In distributed mode, uses PostgreSQL advisory lock so only one frontend
|
||||||
|
// instance runs periodic checks (avoids duplicate upgrades across replicas).
|
||||||
|
if len(options.BackendGalleries) > 0 {
|
||||||
|
uc := NewUpgradeChecker(options, application.ModelLoader(), application.distributedDB())
|
||||||
|
application.upgradeChecker = uc
|
||||||
|
go uc.Run(options.Context)
|
||||||
|
}
|
||||||
|
|
||||||
if options.ConfigFile != "" {
|
if options.ConfigFile != "" {
|
||||||
if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||||
xlog.Error("error loading config file", "error", err)
|
xlog.Error("error loading config file", "error", err)
|
||||||
|
|||||||
198
core/application/upgrade_checker.go
Normal file
198
core/application/upgrade_checker.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package application
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpgradeChecker periodically checks for backend upgrades and optionally
|
||||||
|
// auto-upgrades them. It caches the last check results for API queries.
|
||||||
|
//
|
||||||
|
// In standalone mode it runs a simple ticker loop.
|
||||||
|
// In distributed mode it uses a PostgreSQL advisory lock so that only one
|
||||||
|
// frontend instance performs periodic checks and auto-upgrades at a time.
|
||||||
|
type UpgradeChecker struct {
|
||||||
|
appConfig *config.ApplicationConfig
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
galleries []config.Gallery
|
||||||
|
systemState *system.SystemState
|
||||||
|
db *gorm.DB // non-nil in distributed mode
|
||||||
|
|
||||||
|
checkInterval time.Duration
|
||||||
|
stop chan struct{}
|
||||||
|
done chan struct{}
|
||||||
|
triggerCh chan struct{}
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
lastUpgrades map[string]gallery.UpgradeInfo
|
||||||
|
lastCheckTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUpgradeChecker creates a new UpgradeChecker service.
|
||||||
|
// Pass db=nil for standalone mode, or a *gorm.DB for distributed mode
|
||||||
|
// (uses advisory locks so only one instance runs periodic checks).
|
||||||
|
func NewUpgradeChecker(appConfig *config.ApplicationConfig, ml *model.ModelLoader, db *gorm.DB) *UpgradeChecker {
|
||||||
|
return &UpgradeChecker{
|
||||||
|
appConfig: appConfig,
|
||||||
|
modelLoader: ml,
|
||||||
|
galleries: appConfig.BackendGalleries,
|
||||||
|
systemState: appConfig.SystemState,
|
||||||
|
db: db,
|
||||||
|
checkInterval: 6 * time.Hour,
|
||||||
|
stop: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
triggerCh: make(chan struct{}, 1),
|
||||||
|
lastUpgrades: make(map[string]gallery.UpgradeInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the upgrade checker loop. It waits 30 seconds after startup,
|
||||||
|
// performs an initial check, then re-checks every 6 hours.
|
||||||
|
//
|
||||||
|
// In distributed mode, periodic checks are guarded by a PostgreSQL advisory
|
||||||
|
// lock so only one frontend instance runs them. On-demand triggers (TriggerCheck)
|
||||||
|
// and the initial check always run locally for fast API response cache warming.
|
||||||
|
func (uc *UpgradeChecker) Run(ctx context.Context) {
|
||||||
|
defer close(uc.done)
|
||||||
|
|
||||||
|
// Initial delay: don't slow down startup
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-uc.stop:
|
||||||
|
return
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
}
|
||||||
|
|
||||||
|
// First check always runs locally (to warm the cache on this instance)
|
||||||
|
uc.runCheck(ctx)
|
||||||
|
|
||||||
|
if uc.db != nil {
|
||||||
|
// Distributed mode: use advisory lock for periodic checks.
|
||||||
|
// RunLeaderLoop ticks every checkInterval; only the lock holder executes.
|
||||||
|
go advisorylock.RunLeaderLoop(ctx, uc.db, advisorylock.KeyBackendUpgradeCheck, uc.checkInterval, func() {
|
||||||
|
uc.runCheck(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Still listen for on-demand triggers (from API / settings change)
|
||||||
|
// and stop signal — these run on every instance.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-uc.stop:
|
||||||
|
return
|
||||||
|
case <-uc.triggerCh:
|
||||||
|
uc.runCheck(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Standalone mode: simple ticker loop
|
||||||
|
ticker := time.NewTicker(uc.checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-uc.stop:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
uc.runCheck(ctx)
|
||||||
|
case <-uc.triggerCh:
|
||||||
|
uc.runCheck(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown stops the upgrade checker loop.
|
||||||
|
func (uc *UpgradeChecker) Shutdown() {
|
||||||
|
close(uc.stop)
|
||||||
|
<-uc.done
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerCheck forces an immediate upgrade check on this instance.
|
||||||
|
func (uc *UpgradeChecker) TriggerCheck() {
|
||||||
|
select {
|
||||||
|
case uc.triggerCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Already triggered, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAvailableUpgrades returns the cached upgrade check results.
|
||||||
|
func (uc *UpgradeChecker) GetAvailableUpgrades() map[string]gallery.UpgradeInfo {
|
||||||
|
uc.mu.RLock()
|
||||||
|
defer uc.mu.RUnlock()
|
||||||
|
|
||||||
|
// Return a copy to avoid races
|
||||||
|
result := make(map[string]gallery.UpgradeInfo, len(uc.lastUpgrades))
|
||||||
|
for k, v := range uc.lastUpgrades {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||||
|
upgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState)
|
||||||
|
|
||||||
|
uc.mu.Lock()
|
||||||
|
uc.lastCheckTime = time.Now()
|
||||||
|
if err != nil {
|
||||||
|
xlog.Debug("Backend upgrade check failed", "error", err)
|
||||||
|
uc.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
uc.lastUpgrades = upgrades
|
||||||
|
uc.mu.Unlock()
|
||||||
|
|
||||||
|
if len(upgrades) == 0 {
|
||||||
|
xlog.Debug("All backends up to date")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log available upgrades
|
||||||
|
for name, info := range upgrades {
|
||||||
|
if info.AvailableVersion != "" {
|
||||||
|
xlog.Info("Backend upgrade available",
|
||||||
|
"backend", name,
|
||||||
|
"installed", info.InstalledVersion,
|
||||||
|
"available", info.AvailableVersion)
|
||||||
|
} else {
|
||||||
|
xlog.Info("Backend upgrade available (new build)",
|
||||||
|
"backend", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-upgrade if enabled
|
||||||
|
if uc.appConfig.AutoUpgradeBackends {
|
||||||
|
for name, info := range upgrades {
|
||||||
|
xlog.Info("Auto-upgrading backend", "backend", name,
|
||||||
|
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||||
|
if err := gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||||
|
uc.galleries, name, nil); err != nil {
|
||||||
|
xlog.Error("Failed to auto-upgrade backend",
|
||||||
|
"backend", name, "error", err)
|
||||||
|
} else {
|
||||||
|
xlog.Info("Backend upgraded successfully", "backend", name,
|
||||||
|
"version", info.AvailableVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Re-check to update cache after upgrades
|
||||||
|
if freshUpgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState); err == nil {
|
||||||
|
uc.mu.Lock()
|
||||||
|
uc.lastUpgrades = freshUpgrades
|
||||||
|
uc.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,27 @@ import (
|
|||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SyncPinnedModelsToWatchdog reads pinned status from all model configs and updates the watchdog
|
||||||
|
func (a *Application) SyncPinnedModelsToWatchdog() {
|
||||||
|
cl := a.ModelConfigLoader()
|
||||||
|
if cl == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wd := a.modelLoader.GetWatchDog()
|
||||||
|
if wd == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
configs := cl.GetAllModelsConfigs()
|
||||||
|
var pinned []string
|
||||||
|
for _, cfg := range configs {
|
||||||
|
if cfg.IsPinned() {
|
||||||
|
pinned = append(pinned, cfg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wd.SetPinnedModels(pinned)
|
||||||
|
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Application) StopWatchdog() error {
|
func (a *Application) StopWatchdog() error {
|
||||||
if a.watchdogStop != nil {
|
if a.watchdogStop != nil {
|
||||||
close(a.watchdogStop)
|
close(a.watchdogStop)
|
||||||
@@ -44,6 +65,9 @@ func (a *Application) startWatchdog() error {
|
|||||||
// Set the watchdog on the model loader
|
// Set the watchdog on the model loader
|
||||||
a.modelLoader.SetWatchDog(wd)
|
a.modelLoader.SetWatchDog(wd)
|
||||||
|
|
||||||
|
// Sync pinned models from config to the watchdog
|
||||||
|
a.SyncPinnedModelsToWatchdog()
|
||||||
|
|
||||||
// Start watchdog goroutine if any periodic checks are enabled
|
// Start watchdog goroutine if any periodic checks are enabled
|
||||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||||
// But memory reclaimer needs the Run() loop for periodic checking
|
// But memory reclaimer needs the Run() loop for periodic checking
|
||||||
@@ -124,5 +148,8 @@ func (a *Application) RestartWatchdog() error {
|
|||||||
newWD.RestoreState(oldState)
|
newWD.RestoreState(oldState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-sync pinned models after restart
|
||||||
|
a.SyncPinnedModelsToWatchdog()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ import (
|
|||||||
|
|
||||||
func Detection(
|
func Detection(
|
||||||
sourceFile string,
|
sourceFile string,
|
||||||
|
prompt string,
|
||||||
|
points []float32,
|
||||||
|
boxes []float32,
|
||||||
|
threshold float32,
|
||||||
loader *model.ModelLoader,
|
loader *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
modelConfig config.ModelConfig,
|
modelConfig config.ModelConfig,
|
||||||
@@ -35,7 +39,11 @@ func Detection(
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
||||||
Src: sourceFile,
|
Src: sourceFile,
|
||||||
|
Prompt: prompt,
|
||||||
|
Points: points,
|
||||||
|
Boxes: boxes,
|
||||||
|
Threshold: threshold,
|
||||||
})
|
})
|
||||||
|
|
||||||
if appConfig.EnableTracing {
|
if appConfig.EnableTracing {
|
||||||
|
|||||||
@@ -36,6 +36,27 @@ type TokenUsage struct {
|
|||||||
Completion int
|
Completion int
|
||||||
TimingPromptProcessing float64
|
TimingPromptProcessing float64
|
||||||
TimingTokenGeneration float64
|
TimingTokenGeneration float64
|
||||||
|
ChatDeltas []*proto.ChatDelta // per-chunk deltas from C++ autoparser (only set during streaming)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasChatDeltaContent returns true if any chat delta carries content or reasoning text.
|
||||||
|
// Used to decide whether to prefer C++ autoparser deltas over Go-side tag extraction.
|
||||||
|
func (t TokenUsage) HasChatDeltaContent() bool {
|
||||||
|
for _, d := range t.ChatDeltas {
|
||||||
|
if d.Content != "" || d.ReasoningContent != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatDeltaReasoningAndContent extracts accumulated reasoning and content from chat deltas.
|
||||||
|
func (t TokenUsage) ChatDeltaReasoningAndContent() (reasoning, content string) {
|
||||||
|
for _, d := range t.ChatDeltas {
|
||||||
|
content += d.Content
|
||||||
|
reasoning += d.ReasoningContent
|
||||||
|
}
|
||||||
|
return reasoning, content
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
||||||
@@ -171,6 +192,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
|||||||
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attach per-chunk chat deltas to tokenUsage so the callback can use them
|
||||||
|
tokenUsage.ChatDeltas = reply.ChatDeltas
|
||||||
|
|
||||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||||
if len(reply.Logprobs) > 0 {
|
if len(reply.Logprobs) > 0 {
|
||||||
var parsedLogprobs schema.Logprobs
|
var parsedLogprobs schema.Logprobs
|
||||||
@@ -200,6 +224,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
|||||||
if len(msg) == 0 {
|
if len(msg) == 0 {
|
||||||
tokenCallback("", tokenUsage)
|
tokenCallback("", tokenUsage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear per-chunk deltas so they don't leak to the next chunk
|
||||||
|
tokenUsage.ChatDeltas = nil
|
||||||
})
|
})
|
||||||
if len(allChatDeltas) > 0 {
|
if len(allChatDeltas) > 0 {
|
||||||
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
. "github.com/mudler/LocalAI/core/backend"
|
. "github.com/mudler/LocalAI/core/backend"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@@ -107,3 +108,111 @@ var _ = Describe("LLM tests", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var _ = Describe("TokenUsage ChatDelta helpers", func() {
|
||||||
|
Describe("HasChatDeltaContent", func() {
|
||||||
|
It("should return false when ChatDeltas is nil", func() {
|
||||||
|
usage := TokenUsage{}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return false when ChatDeltas is empty", func() {
|
||||||
|
usage := TokenUsage{ChatDeltas: []*pb.ChatDelta{}}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return false when all deltas have empty content and reasoning", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{Content: "", ReasoningContent: ""},
|
||||||
|
{Content: ""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return true when a delta has content", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{Content: "hello"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return true when a delta has reasoning content", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{ReasoningContent: "thinking..."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return true when a delta has both content and reasoning", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{Content: "hello", ReasoningContent: "thinking..."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("ChatDeltaReasoningAndContent", func() {
|
||||||
|
It("should return empty strings when ChatDeltas is nil", func() {
|
||||||
|
usage := TokenUsage{}
|
||||||
|
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||||
|
Expect(reasoning).To(BeEmpty())
|
||||||
|
Expect(content).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should concatenate content from multiple deltas", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " world"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||||
|
Expect(content).To(Equal("Hello world"))
|
||||||
|
Expect(reasoning).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should concatenate reasoning from multiple deltas", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{ReasoningContent: "step 1"},
|
||||||
|
{ReasoningContent: " step 2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||||
|
Expect(reasoning).To(Equal("step 1 step 2"))
|
||||||
|
Expect(content).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should separate reasoning and content from mixed deltas", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{ReasoningContent: "thinking"},
|
||||||
|
{Content: "answer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||||
|
Expect(reasoning).To(Equal("thinking"))
|
||||||
|
Expect(content).To(Equal("answer"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle deltas with both fields set", func() {
|
||||||
|
usage := TokenUsage{
|
||||||
|
ChatDeltas: []*pb.ChatDelta{
|
||||||
|
{Content: "a", ReasoningContent: "r1"},
|
||||||
|
{Content: "b", ReasoningContent: "r2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||||
|
Expect(reasoning).To(Equal("r1r2"))
|
||||||
|
Expect(content).To(Equal("ab"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -40,10 +40,17 @@ type BackendsUninstall struct {
|
|||||||
BackendsCMDFlags `embed:""`
|
BackendsCMDFlags `embed:""`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BackendsUpgrade struct {
|
||||||
|
BackendArgs []string `arg:"" optional:"" name:"backends" help:"Backend names to upgrade (empty = upgrade all)"`
|
||||||
|
|
||||||
|
BackendsCMDFlags `embed:""`
|
||||||
|
}
|
||||||
|
|
||||||
type BackendsCMD struct {
|
type BackendsCMD struct {
|
||||||
List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"`
|
List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"`
|
||||||
Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"`
|
Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"`
|
||||||
Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"`
|
Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"`
|
||||||
|
Upgrade BackendsUpgrade `cmd:"" help:"Upgrade backends to latest versions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
||||||
@@ -64,11 +71,27 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for upgrades
|
||||||
|
upgrades, _ := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||||
|
|
||||||
for _, backend := range backends {
|
for _, backend := range backends {
|
||||||
|
versionStr := ""
|
||||||
|
if backend.Version != "" {
|
||||||
|
versionStr = " v" + backend.Version
|
||||||
|
}
|
||||||
if backend.Installed {
|
if backend.Installed {
|
||||||
fmt.Printf(" * %s@%s (installed)\n", backend.Gallery.Name, backend.Name)
|
if info, ok := upgrades[backend.Name]; ok {
|
||||||
|
upgradeStr := info.AvailableVersion
|
||||||
|
if upgradeStr == "" {
|
||||||
|
upgradeStr = "new build"
|
||||||
|
}
|
||||||
|
fmt.Printf(" * %s@%s%s (installed, upgrade available: %s)\n", backend.Gallery.Name, backend.Name, versionStr, upgradeStr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" * %s@%s%s (installed)\n", backend.Gallery.Name, backend.Name, versionStr)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf(" - %s@%s\n", backend.Gallery.Name, backend.Name)
|
fmt.Printf(" - %s@%s%s\n", backend.Gallery.Name, backend.Name, versionStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -111,6 +134,79 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bu *BackendsUpgrade) Run(ctx *cliContext.Context) error {
|
||||||
|
var galleries []config.Gallery
|
||||||
|
if err := json.Unmarshal([]byte(bu.BackendGalleries), &galleries); err != nil {
|
||||||
|
xlog.Error("unable to load galleries", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
systemState, err := system.GetSystemState(
|
||||||
|
system.WithBackendSystemPath(bu.BackendsSystemPath),
|
||||||
|
system.WithBackendPath(bu.BackendsPath),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
upgrades, err := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check for upgrades: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upgrades) == 0 {
|
||||||
|
fmt.Println("All backends are up to date.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter to specified backends if args given
|
||||||
|
toUpgrade := upgrades
|
||||||
|
if len(bu.BackendArgs) > 0 {
|
||||||
|
toUpgrade = make(map[string]gallery.UpgradeInfo)
|
||||||
|
for _, name := range bu.BackendArgs {
|
||||||
|
if info, ok := upgrades[name]; ok {
|
||||||
|
toUpgrade[name] = info
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Backend %s: no upgrade available\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toUpgrade) == 0 {
|
||||||
|
fmt.Println("No upgrades to apply.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelLoader := model.NewModelLoader(systemState)
|
||||||
|
for name, info := range toUpgrade {
|
||||||
|
versionStr := ""
|
||||||
|
if info.AvailableVersion != "" {
|
||||||
|
versionStr = " to v" + info.AvailableVersion
|
||||||
|
}
|
||||||
|
fmt.Printf("Upgrading %s%s...\n", name, versionStr)
|
||||||
|
|
||||||
|
progressBar := progressbar.NewOptions(
|
||||||
|
1000,
|
||||||
|
progressbar.OptionSetDescription(fmt.Sprintf("downloading %s", name)),
|
||||||
|
progressbar.OptionShowBytes(false),
|
||||||
|
progressbar.OptionClearOnFinish(),
|
||||||
|
)
|
||||||
|
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||||
|
v := int(percentage * 10)
|
||||||
|
if err := progressBar.Set(v); err != nil {
|
||||||
|
xlog.Error("error updating progress bar", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback); err != nil {
|
||||||
|
fmt.Printf("Failed to upgrade %s: %v\n", name, err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Backend %s upgraded successfully\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
|
func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
|
||||||
for _, backendName := range bu.BackendArgs {
|
for _, backendName := range bu.BackendArgs {
|
||||||
xlog.Info("uninstalling backend", "backend", backendName)
|
xlog.Info("uninstalling backend", "backend", backendName)
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type RunCMD struct {
|
|||||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||||
|
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||||
@@ -62,6 +63,7 @@ type RunCMD struct {
|
|||||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
||||||
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
||||||
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
|
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
|
||||||
|
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
||||||
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
||||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||||
@@ -295,6 +297,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
opts = append(opts, config.DisableWebUI)
|
opts = append(opts, config.DisableWebUI)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.OllamaAPIRootEndpoint {
|
||||||
|
opts = append(opts, config.EnableOllamaAPIRootEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
if r.DisableGalleryEndpoint {
|
if r.DisableGalleryEndpoint {
|
||||||
opts = append(opts, config.DisableGalleryEndpoint)
|
opts = append(opts, config.DisableGalleryEndpoint)
|
||||||
}
|
}
|
||||||
@@ -485,6 +491,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
opts = append(opts, config.EnableBackendGalleriesAutoload)
|
opts = append(opts, config.EnableBackendGalleriesAutoload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.AutoUpgradeBackends {
|
||||||
|
opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
|
||||||
|
}
|
||||||
|
|
||||||
if r.PreloadBackendOnly {
|
if r.PreloadBackendOnly {
|
||||||
_, err := application.New(opts...)
|
_, err := application.New(opts...)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -512,11 +512,9 @@ func (s *backendSupervisor) stopBackend(backend string) {
|
|||||||
|
|
||||||
// Network I/O outside the lock
|
// Network I/O outside the lock
|
||||||
client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
|
client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
|
||||||
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
|
||||||
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
|
if err := client.Free(context.Background()); err != nil {
|
||||||
if err := freeFunc.Free(context.Background()); err != nil {
|
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
|
||||||
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
|
xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
|
||||||
@@ -692,13 +690,13 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
|||||||
|
|
||||||
// backend.delete — stop backend + delete files (request-reply)
|
// backend.delete — stop backend + delete files (request-reply)
|
||||||
s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
|
s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
xlog.Info("Received NATS backend.delete event")
|
|
||||||
var req messaging.BackendDeleteRequest
|
var req messaging.BackendDeleteRequest
|
||||||
if err := json.Unmarshal(data, &req); err != nil {
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
||||||
replyJSON(reply, resp)
|
replyJSON(reply, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
xlog.Info("Received NATS backend.delete event", "backend", req.Backend)
|
||||||
|
|
||||||
// Stop if running this backend
|
// Stop if running this backend
|
||||||
if s.isRunning(req.Backend) {
|
if s.isRunning(req.Backend) {
|
||||||
@@ -774,10 +772,8 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
|||||||
if targetAddr != "" {
|
if targetAddr != "" {
|
||||||
// Best-effort gRPC Free()
|
// Best-effort gRPC Free()
|
||||||
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
|
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
|
||||||
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
if err := client.Free(context.Background()); err != nil {
|
||||||
if err := freeFunc.Free(context.Background()); err != nil {
|
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
|
||||||
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type ApplicationConfig struct {
|
|||||||
Federated bool
|
Federated bool
|
||||||
|
|
||||||
DisableWebUI bool
|
DisableWebUI bool
|
||||||
|
OllamaAPIRootEndpoint bool
|
||||||
EnforcePredownloadScans bool
|
EnforcePredownloadScans bool
|
||||||
OpaqueErrors bool
|
OpaqueErrors bool
|
||||||
UseSubtleKeyComparison bool
|
UseSubtleKeyComparison bool
|
||||||
@@ -56,6 +57,7 @@ type ApplicationConfig struct {
|
|||||||
ExternalGRPCBackends map[string]string
|
ExternalGRPCBackends map[string]string
|
||||||
|
|
||||||
AutoloadGalleries, AutoloadBackendGalleries bool
|
AutoloadGalleries, AutoloadBackendGalleries bool
|
||||||
|
AutoUpgradeBackends bool
|
||||||
|
|
||||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||||
@@ -263,6 +265,10 @@ var DisableWebUI = func(o *ApplicationConfig) {
|
|||||||
o.DisableWebUI = true
|
o.DisableWebUI = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var EnableOllamaAPIRootEndpoint = func(o *ApplicationConfig) {
|
||||||
|
o.OllamaAPIRootEndpoint = true
|
||||||
|
}
|
||||||
|
|
||||||
var DisableRuntimeSettings = func(o *ApplicationConfig) {
|
var DisableRuntimeSettings = func(o *ApplicationConfig) {
|
||||||
o.DisableRuntimeSettings = true
|
o.DisableRuntimeSettings = true
|
||||||
}
|
}
|
||||||
@@ -385,6 +391,10 @@ var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) {
|
|||||||
o.AutoloadBackendGalleries = true
|
o.AutoloadBackendGalleries = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithAutoUpgradeBackends(v bool) AppOption {
|
||||||
|
return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
|
||||||
|
}
|
||||||
|
|
||||||
var EnableFederated = func(o *ApplicationConfig) {
|
var EnableFederated = func(o *ApplicationConfig) {
|
||||||
o.Federated = true
|
o.Federated = true
|
||||||
}
|
}
|
||||||
@@ -857,6 +867,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
|||||||
backendGalleries := o.BackendGalleries
|
backendGalleries := o.BackendGalleries
|
||||||
autoloadGalleries := o.AutoloadGalleries
|
autoloadGalleries := o.AutoloadGalleries
|
||||||
autoloadBackendGalleries := o.AutoloadBackendGalleries
|
autoloadBackendGalleries := o.AutoloadBackendGalleries
|
||||||
|
autoUpgradeBackends := o.AutoUpgradeBackends
|
||||||
apiKeys := o.ApiKeys
|
apiKeys := o.ApiKeys
|
||||||
agentJobRetentionDays := o.AgentJobRetentionDays
|
agentJobRetentionDays := o.AgentJobRetentionDays
|
||||||
|
|
||||||
@@ -930,6 +941,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
|||||||
BackendGalleries: &backendGalleries,
|
BackendGalleries: &backendGalleries,
|
||||||
AutoloadGalleries: &autoloadGalleries,
|
AutoloadGalleries: &autoloadGalleries,
|
||||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||||
|
AutoUpgradeBackends: &autoUpgradeBackends,
|
||||||
ApiKeys: &apiKeys,
|
ApiKeys: &apiKeys,
|
||||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||||
@@ -1078,6 +1090,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
|||||||
if settings.AutoloadBackendGalleries != nil {
|
if settings.AutoloadBackendGalleries != nil {
|
||||||
o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
|
||||||
}
|
}
|
||||||
|
if settings.AutoUpgradeBackends != nil {
|
||||||
|
o.AutoUpgradeBackends = *settings.AutoUpgradeBackends
|
||||||
|
}
|
||||||
if settings.AgentJobRetentionDays != nil {
|
if settings.AgentJobRetentionDays != nil {
|
||||||
o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
|
o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,6 +119,13 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
|||||||
Expect(*rs.AgentJobRetentionDays).To(Equal(30))
|
Expect(*rs.AgentJobRetentionDays).To(Equal(30))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("should include auto_upgrade_backends", func() {
|
||||||
|
appConfig := &ApplicationConfig{AutoUpgradeBackends: true}
|
||||||
|
rs := appConfig.ToRuntimeSettings()
|
||||||
|
Expect(rs.AutoUpgradeBackends).ToNot(BeNil())
|
||||||
|
Expect(*rs.AutoUpgradeBackends).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
It("should use default timeouts when not set", func() {
|
It("should use default timeouts when not set", func() {
|
||||||
appConfig := &ApplicationConfig{}
|
appConfig := &ApplicationConfig{}
|
||||||
|
|
||||||
@@ -426,6 +433,14 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
|||||||
Expect(appConfig.AutoloadBackendGalleries).To(BeTrue())
|
Expect(appConfig.AutoloadBackendGalleries).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("should apply auto_upgrade_backends setting", func() {
|
||||||
|
appConfig := &ApplicationConfig{}
|
||||||
|
v := true
|
||||||
|
rs := &RuntimeSettings{AutoUpgradeBackends: &v}
|
||||||
|
appConfig.ApplyRuntimeSettings(rs)
|
||||||
|
Expect(appConfig.AutoUpgradeBackends).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
It("should apply agent settings", func() {
|
It("should apply agent settings", func() {
|
||||||
appConfig := &ApplicationConfig{}
|
appConfig := &ApplicationConfig{}
|
||||||
|
|
||||||
@@ -465,6 +480,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
|||||||
Federated: true,
|
Federated: true,
|
||||||
AutoloadGalleries: true,
|
AutoloadGalleries: true,
|
||||||
AutoloadBackendGalleries: false,
|
AutoloadBackendGalleries: false,
|
||||||
|
AutoUpgradeBackends: true,
|
||||||
AgentJobRetentionDays: 60,
|
AgentJobRetentionDays: 60,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,6 +512,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
|||||||
Expect(target.Federated).To(Equal(original.Federated))
|
Expect(target.Federated).To(Equal(original.Federated))
|
||||||
Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries))
|
Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries))
|
||||||
Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries))
|
Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries))
|
||||||
|
Expect(target.AutoUpgradeBackends).To(Equal(original.AutoUpgradeBackends))
|
||||||
Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays))
|
Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
"qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
"qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||||
"qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
"qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||||
"qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95},
|
"qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95},
|
||||||
|
"gemma-4": {"min_p":0,"presence_penalty":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||||
"gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
"gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||||
"gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
"gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||||
"medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
"medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||||
@@ -53,5 +54,5 @@
|
|||||||
"grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
"grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
||||||
"mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}
|
"mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}
|
||||||
},
|
},
|
||||||
"patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
|
"patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-4","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
|
||||||
}
|
}
|
||||||
|
|||||||
132
core/config/meta/build.go
Normal file
132
core/config/meta/build.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package meta
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedMetadata *ConfigMetadata
|
||||||
|
cacheMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildConfigMetadata reflects on the given struct type (ModelConfig),
|
||||||
|
// merges the enrichment registry, and returns the full ConfigMetadata.
|
||||||
|
// The result is cached in memory after the first call.
|
||||||
|
func BuildConfigMetadata(modelConfigType reflect.Type) *ConfigMetadata {
|
||||||
|
cacheMu.RLock()
|
||||||
|
if cachedMetadata != nil {
|
||||||
|
cacheMu.RUnlock()
|
||||||
|
return cachedMetadata
|
||||||
|
}
|
||||||
|
cacheMu.RUnlock()
|
||||||
|
|
||||||
|
cacheMu.Lock()
|
||||||
|
defer cacheMu.Unlock()
|
||||||
|
|
||||||
|
if cachedMetadata != nil {
|
||||||
|
return cachedMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
cachedMetadata = buildConfigMetadataUncached(modelConfigType, DefaultRegistry())
|
||||||
|
return cachedMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildConfigMetadataUncached does the actual work without caching.
|
||||||
|
func buildConfigMetadataUncached(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
|
||||||
|
fields := WalkModelConfig(modelConfigType)
|
||||||
|
|
||||||
|
for i := range fields {
|
||||||
|
override, ok := registry[fields[i].Path]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
applyOverride(&fields[i], override)
|
||||||
|
}
|
||||||
|
|
||||||
|
allSections := DefaultSections()
|
||||||
|
|
||||||
|
sectionOrder := make(map[string]int, len(allSections))
|
||||||
|
for _, s := range allSections {
|
||||||
|
sectionOrder[s.ID] = s.Order
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.SliceStable(fields, func(i, j int) bool {
|
||||||
|
si := sectionOrder[fields[i].Section]
|
||||||
|
sj := sectionOrder[fields[j].Section]
|
||||||
|
if si != sj {
|
||||||
|
return si < sj
|
||||||
|
}
|
||||||
|
return fields[i].Order < fields[j].Order
|
||||||
|
})
|
||||||
|
|
||||||
|
usedSections := make(map[string]bool)
|
||||||
|
for _, f := range fields {
|
||||||
|
usedSections[f.Section] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var sections []Section
|
||||||
|
for _, s := range allSections {
|
||||||
|
if usedSections[s.ID] {
|
||||||
|
sections = append(sections, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ConfigMetadata{
|
||||||
|
Sections: sections,
|
||||||
|
Fields: fields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyOverride merges non-zero override values into the field.
|
||||||
|
func applyOverride(f *FieldMeta, o FieldMetaOverride) {
|
||||||
|
if o.Section != "" {
|
||||||
|
f.Section = o.Section
|
||||||
|
}
|
||||||
|
if o.Label != "" {
|
||||||
|
f.Label = o.Label
|
||||||
|
}
|
||||||
|
if o.Description != "" {
|
||||||
|
f.Description = o.Description
|
||||||
|
}
|
||||||
|
if o.Component != "" {
|
||||||
|
f.Component = o.Component
|
||||||
|
}
|
||||||
|
if o.Placeholder != "" {
|
||||||
|
f.Placeholder = o.Placeholder
|
||||||
|
}
|
||||||
|
if o.Default != nil {
|
||||||
|
f.Default = o.Default
|
||||||
|
}
|
||||||
|
if o.Min != nil {
|
||||||
|
f.Min = o.Min
|
||||||
|
}
|
||||||
|
if o.Max != nil {
|
||||||
|
f.Max = o.Max
|
||||||
|
}
|
||||||
|
if o.Step != nil {
|
||||||
|
f.Step = o.Step
|
||||||
|
}
|
||||||
|
if o.Options != nil {
|
||||||
|
f.Options = o.Options
|
||||||
|
}
|
||||||
|
if o.AutocompleteProvider != "" {
|
||||||
|
f.AutocompleteProvider = o.AutocompleteProvider
|
||||||
|
}
|
||||||
|
if o.VRAMImpact {
|
||||||
|
f.VRAMImpact = true
|
||||||
|
}
|
||||||
|
if o.Advanced {
|
||||||
|
f.Advanced = true
|
||||||
|
}
|
||||||
|
if o.Order != 0 {
|
||||||
|
f.Order = o.Order
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildForTest builds metadata without caching, for use in tests.
|
||||||
|
func BuildForTest(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
|
||||||
|
return buildConfigMetadataUncached(modelConfigType, registry)
|
||||||
|
}
|
||||||
|
|
||||||
211
core/config/meta/build_test.go
Normal file
211
core/config/meta/build_test.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package meta_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/config/meta"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildConfigMetadata(t *testing.T) {
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||||
|
|
||||||
|
if len(md.Sections) == 0 {
|
||||||
|
t.Fatal("expected sections, got 0")
|
||||||
|
}
|
||||||
|
if len(md.Fields) == 0 {
|
||||||
|
t.Fatal("expected fields, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify sections are ordered
|
||||||
|
for i := 1; i < len(md.Sections); i++ {
|
||||||
|
if md.Sections[i].Order < md.Sections[i-1].Order {
|
||||||
|
t.Errorf("sections not ordered: %s (order=%d) before %s (order=%d)",
|
||||||
|
md.Sections[i-1].ID, md.Sections[i-1].Order,
|
||||||
|
md.Sections[i].ID, md.Sections[i].Order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryOverrides(t *testing.T) {
|
||||||
|
registry := map[string]meta.FieldMetaOverride{
|
||||||
|
"name": {
|
||||||
|
Label: "My Custom Label",
|
||||||
|
Description: "Custom description",
|
||||||
|
Component: "textarea",
|
||||||
|
Order: 999,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), registry)
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
f, ok := byPath["name"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("field 'name' not found")
|
||||||
|
}
|
||||||
|
if f.Label != "My Custom Label" {
|
||||||
|
t.Errorf("expected label 'My Custom Label', got %q", f.Label)
|
||||||
|
}
|
||||||
|
if f.Description != "Custom description" {
|
||||||
|
t.Errorf("expected description 'Custom description', got %q", f.Description)
|
||||||
|
}
|
||||||
|
if f.Component != "textarea" {
|
||||||
|
t.Errorf("expected component 'textarea', got %q", f.Component)
|
||||||
|
}
|
||||||
|
if f.Order != 999 {
|
||||||
|
t.Errorf("expected order 999, got %d", f.Order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnregisteredFieldsGetDefaults(t *testing.T) {
|
||||||
|
// Use empty registry - all fields should still get auto-generated metadata
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), map[string]meta.FieldMetaOverride{})
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// context_size should still exist with auto-generated label
|
||||||
|
f, ok := byPath["context_size"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("field 'context_size' not found")
|
||||||
|
}
|
||||||
|
if f.Label == "" {
|
||||||
|
t.Error("expected auto-generated label, got empty")
|
||||||
|
}
|
||||||
|
if f.UIType != "int" {
|
||||||
|
t.Errorf("expected UIType 'int', got %q", f.UIType)
|
||||||
|
}
|
||||||
|
if f.Component == "" {
|
||||||
|
t.Error("expected auto-generated component, got empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultRegistryOverridesApply(t *testing.T) {
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify enriched fields got their overrides
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
label string
|
||||||
|
description string
|
||||||
|
vramImpact bool
|
||||||
|
}{
|
||||||
|
{"context_size", "Context Size", "Maximum context window in tokens", true},
|
||||||
|
{"gpu_layers", "GPU Layers", "Number of layers to offload to GPU (-1 = all)", true},
|
||||||
|
{"backend", "Backend", "The inference backend to use (e.g. llama-cpp, vllm, diffusers)", false},
|
||||||
|
{"parameters.temperature", "Temperature", "Sampling temperature (higher = more creative, lower = more deterministic)", false},
|
||||||
|
{"template.chat", "Chat Template", "Go template for chat completion requests", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
f, ok := byPath[tt.path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", tt.path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Label != tt.label {
|
||||||
|
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||||
|
}
|
||||||
|
if f.Description != tt.description {
|
||||||
|
t.Errorf("field %q: expected description %q, got %q", tt.path, tt.description, f.Description)
|
||||||
|
}
|
||||||
|
if f.VRAMImpact != tt.vramImpact {
|
||||||
|
t.Errorf("field %q: expected vramImpact=%v, got %v", tt.path, tt.vramImpact, f.VRAMImpact)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticOptionsFields(t *testing.T) {
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields with static options should have Options populated and no AutocompleteProvider
|
||||||
|
staticFields := []string{"quantization", "cache_type_k", "cache_type_v", "diffusers.pipeline_type", "diffusers.scheduler_type"}
|
||||||
|
for _, path := range staticFields {
|
||||||
|
f, ok := byPath[path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(f.Options) == 0 {
|
||||||
|
t.Errorf("field %q: expected Options to be populated", path)
|
||||||
|
}
|
||||||
|
if f.AutocompleteProvider != "" {
|
||||||
|
t.Errorf("field %q: expected no AutocompleteProvider, got %q", path, f.AutocompleteProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDynamicProviderFields(t *testing.T) {
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields with dynamic providers should have AutocompleteProvider and no Options
|
||||||
|
dynamicFields := map[string]string{
|
||||||
|
"backend": meta.ProviderBackends,
|
||||||
|
"pipeline.llm": meta.ProviderModelsChat,
|
||||||
|
"pipeline.tts": meta.ProviderModelsTTS,
|
||||||
|
"pipeline.transcription": meta.ProviderModelsTranscript,
|
||||||
|
"pipeline.vad": meta.ProviderModelsVAD,
|
||||||
|
}
|
||||||
|
for path, expectedProvider := range dynamicFields {
|
||||||
|
f, ok := byPath[path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.AutocompleteProvider != expectedProvider {
|
||||||
|
t.Errorf("field %q: expected AutocompleteProvider %q, got %q", path, expectedProvider, f.AutocompleteProvider)
|
||||||
|
}
|
||||||
|
if len(f.Options) != 0 {
|
||||||
|
t.Errorf("field %q: expected no Options, got %d", path, len(f.Options))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVRAMImpactFields(t *testing.T) {
|
||||||
|
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||||
|
|
||||||
|
var vramFields []string
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
if f.VRAMImpact {
|
||||||
|
vramFields = append(vramFields, f.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vramFields) == 0 {
|
||||||
|
t.Error("expected some VRAM impact fields, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// context_size and gpu_layers should be marked
|
||||||
|
expected := map[string]bool{"context_size": true, "gpu_layers": true}
|
||||||
|
for _, path := range vramFields {
|
||||||
|
if expected[path] {
|
||||||
|
delete(expected, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for path := range expected {
|
||||||
|
t.Errorf("expected VRAM impact field %q not found", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
79
core/config/meta/constants.go
Normal file
79
core/config/meta/constants.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package meta
|
||||||
|
|
||||||
|
// Dynamic autocomplete provider constants (runtime lookup required).
|
||||||
|
const (
|
||||||
|
ProviderBackends = "backends"
|
||||||
|
ProviderModels = "models"
|
||||||
|
ProviderModelsChat = "models:chat"
|
||||||
|
ProviderModelsTTS = "models:tts"
|
||||||
|
ProviderModelsTranscript = "models:transcript"
|
||||||
|
ProviderModelsVAD = "models:vad"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Static option lists embedded directly in field metadata.
|
||||||
|
|
||||||
|
var QuantizationOptions = []FieldOption{
|
||||||
|
{Value: "q4_0", Label: "Q4_0"},
|
||||||
|
{Value: "q4_1", Label: "Q4_1"},
|
||||||
|
{Value: "q5_0", Label: "Q5_0"},
|
||||||
|
{Value: "q5_1", Label: "Q5_1"},
|
||||||
|
{Value: "q8_0", Label: "Q8_0"},
|
||||||
|
{Value: "q2_K", Label: "Q2_K"},
|
||||||
|
{Value: "q3_K_S", Label: "Q3_K_S"},
|
||||||
|
{Value: "q3_K_M", Label: "Q3_K_M"},
|
||||||
|
{Value: "q3_K_L", Label: "Q3_K_L"},
|
||||||
|
{Value: "q4_K_S", Label: "Q4_K_S"},
|
||||||
|
{Value: "q4_K_M", Label: "Q4_K_M"},
|
||||||
|
{Value: "q5_K_S", Label: "Q5_K_S"},
|
||||||
|
{Value: "q5_K_M", Label: "Q5_K_M"},
|
||||||
|
{Value: "q6_K", Label: "Q6_K"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var CacheTypeOptions = []FieldOption{
|
||||||
|
{Value: "f16", Label: "F16"},
|
||||||
|
{Value: "f32", Label: "F32"},
|
||||||
|
{Value: "q8_0", Label: "Q8_0"},
|
||||||
|
{Value: "q4_0", Label: "Q4_0"},
|
||||||
|
{Value: "q4_1", Label: "Q4_1"},
|
||||||
|
{Value: "q5_0", Label: "Q5_0"},
|
||||||
|
{Value: "q5_1", Label: "Q5_1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var DiffusersPipelineOptions = []FieldOption{
|
||||||
|
{Value: "StableDiffusionPipeline", Label: "StableDiffusionPipeline"},
|
||||||
|
{Value: "StableDiffusionImg2ImgPipeline", Label: "StableDiffusionImg2ImgPipeline"},
|
||||||
|
{Value: "StableDiffusionXLPipeline", Label: "StableDiffusionXLPipeline"},
|
||||||
|
{Value: "StableDiffusionXLImg2ImgPipeline", Label: "StableDiffusionXLImg2ImgPipeline"},
|
||||||
|
{Value: "StableDiffusionDepth2ImgPipeline", Label: "StableDiffusionDepth2ImgPipeline"},
|
||||||
|
{Value: "DiffusionPipeline", Label: "DiffusionPipeline"},
|
||||||
|
{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var UsecaseOptions = []FieldOption{
|
||||||
|
{Value: "chat", Label: "Chat"},
|
||||||
|
{Value: "completion", Label: "Completion"},
|
||||||
|
{Value: "edit", Label: "Edit"},
|
||||||
|
{Value: "embeddings", Label: "Embeddings"},
|
||||||
|
{Value: "rerank", Label: "Rerank"},
|
||||||
|
{Value: "image", Label: "Image"},
|
||||||
|
{Value: "transcript", Label: "Transcript"},
|
||||||
|
{Value: "tts", Label: "TTS"},
|
||||||
|
{Value: "sound_generation", Label: "Sound Generation"},
|
||||||
|
{Value: "tokenize", Label: "Tokenize"},
|
||||||
|
{Value: "vad", Label: "VAD"},
|
||||||
|
{Value: "video", Label: "Video"},
|
||||||
|
{Value: "detection", Label: "Detection"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var DiffusersSchedulerOptions = []FieldOption{
|
||||||
|
{Value: "ddim", Label: "DDIM"},
|
||||||
|
{Value: "ddpm", Label: "DDPM"},
|
||||||
|
{Value: "pndm", Label: "PNDM"},
|
||||||
|
{Value: "lms", Label: "LMS"},
|
||||||
|
{Value: "euler", Label: "Euler"},
|
||||||
|
{Value: "euler_a", Label: "Euler A"},
|
||||||
|
{Value: "dpm_multistep", Label: "DPM Multistep"},
|
||||||
|
{Value: "dpm_singlestep", Label: "DPM Singlestep"},
|
||||||
|
{Value: "heun", Label: "Heun"},
|
||||||
|
{Value: "unipc", Label: "UniPC"},
|
||||||
|
}
|
||||||
241
core/config/meta/reflect.go
Normal file
241
core/config/meta/reflect.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
package meta
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WalkModelConfig uses reflection to discover all exported, YAML-tagged fields
|
||||||
|
// in the given struct type (expected to be config.ModelConfig) and returns a
|
||||||
|
// slice of FieldMeta with sensible defaults derived from the type information.
|
||||||
|
func WalkModelConfig(t reflect.Type) []FieldMeta {
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
var fields []FieldMeta
|
||||||
|
walkStruct(t, "", &fields)
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// walkStruct recursively walks a struct type, collecting FieldMeta entries.
|
||||||
|
// prefix is the dot-path prefix for nested structs (e.g. "function.grammar.").
|
||||||
|
func walkStruct(t reflect.Type, prefix string, out *[]FieldMeta) {
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
if t.Kind() != reflect.Struct {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for sf := range t.Fields() {
|
||||||
|
if !sf.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
yamlTag := sf.Tag.Get("yaml")
|
||||||
|
if yamlTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
yamlKey, opts := parseTag(yamlTag)
|
||||||
|
|
||||||
|
// Handle inline embedding (e.g. LLMConfig `yaml:",inline"`)
|
||||||
|
if opts.contains("inline") {
|
||||||
|
ft := sf.Type
|
||||||
|
if ft.Kind() == reflect.Pointer {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
if ft.Kind() == reflect.Struct {
|
||||||
|
walkStruct(ft, prefix, out)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no yaml key and it's an embedded struct without inline, skip unknown pattern
|
||||||
|
if yamlKey == "" {
|
||||||
|
ft := sf.Type
|
||||||
|
if ft.Kind() == reflect.Pointer {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
// Anonymous struct without yaml tag - treat as inline
|
||||||
|
if sf.Anonymous && ft.Kind() == reflect.Struct {
|
||||||
|
walkStruct(ft, prefix, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Named field without yaml tag - skip
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ft := sf.Type
|
||||||
|
isPtr := ft.Kind() == reflect.Pointer
|
||||||
|
if isPtr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Named nested struct (not a special type) -> recurse with prefix
|
||||||
|
if ft.Kind() == reflect.Struct && !isSpecialType(ft) {
|
||||||
|
nestedPrefix := prefix + yamlKey + "."
|
||||||
|
walkStruct(ft, nestedPrefix, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leaf field
|
||||||
|
path := prefix + yamlKey
|
||||||
|
goType := sf.Type.String()
|
||||||
|
uiType, component := inferUIType(sf.Type)
|
||||||
|
section := inferSection(prefix)
|
||||||
|
label := labelFromKey(yamlKey)
|
||||||
|
|
||||||
|
*out = append(*out, FieldMeta{
|
||||||
|
Path: path,
|
||||||
|
YAMLKey: yamlKey,
|
||||||
|
GoType: goType,
|
||||||
|
UIType: uiType,
|
||||||
|
Pointer: isPtr,
|
||||||
|
Section: section,
|
||||||
|
Label: label,
|
||||||
|
Component: component,
|
||||||
|
Order: len(*out),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSpecialType returns true for struct types that should be treated as leaf
|
||||||
|
// values rather than recursed into (e.g. custom JSON marshalers).
|
||||||
|
func isSpecialType(t reflect.Type) bool {
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
name := t.Name()
|
||||||
|
// LogprobsValue, URI types are leaf values despite being structs
|
||||||
|
switch name {
|
||||||
|
case "LogprobsValue", "URI":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferUIType maps a Go reflect.Type to a UI type string and default component.
|
||||||
|
func inferUIType(t reflect.Type) (uiType, component string) {
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return "bool", "toggle"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return "int", "number"
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return "int", "number"
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return "float", "number"
|
||||||
|
case reflect.String:
|
||||||
|
return "string", "input"
|
||||||
|
case reflect.Slice:
|
||||||
|
elem := t.Elem()
|
||||||
|
if elem.Kind() == reflect.String {
|
||||||
|
return "[]string", "string-list"
|
||||||
|
}
|
||||||
|
if elem.Kind() == reflect.Pointer {
|
||||||
|
elem = elem.Elem()
|
||||||
|
}
|
||||||
|
if elem.Kind() == reflect.Struct {
|
||||||
|
return "[]object", "json-editor"
|
||||||
|
}
|
||||||
|
return "[]any", "json-editor"
|
||||||
|
case reflect.Map:
|
||||||
|
return "map", "map-editor"
|
||||||
|
case reflect.Struct:
|
||||||
|
// Special types treated as leaves
|
||||||
|
if isSpecialType(t) {
|
||||||
|
return "bool", "toggle" // LogprobsValue
|
||||||
|
}
|
||||||
|
return "object", "json-editor"
|
||||||
|
default:
|
||||||
|
return "any", "input"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferSection determines the config section from the dot-path prefix.
|
||||||
|
func inferSection(prefix string) string {
|
||||||
|
if prefix == "" {
|
||||||
|
return "general"
|
||||||
|
}
|
||||||
|
// Remove trailing dot
|
||||||
|
p := strings.TrimSuffix(prefix, ".")
|
||||||
|
|
||||||
|
// Use the top-level prefix to determine section
|
||||||
|
parts := strings.SplitN(p, ".", 2)
|
||||||
|
top := parts[0]
|
||||||
|
|
||||||
|
switch top {
|
||||||
|
case "parameters":
|
||||||
|
return "parameters"
|
||||||
|
case "template":
|
||||||
|
return "templates"
|
||||||
|
case "function":
|
||||||
|
return "functions"
|
||||||
|
case "reasoning":
|
||||||
|
return "reasoning"
|
||||||
|
case "diffusers":
|
||||||
|
return "diffusers"
|
||||||
|
case "tts":
|
||||||
|
return "tts"
|
||||||
|
case "pipeline":
|
||||||
|
return "pipeline"
|
||||||
|
case "grpc":
|
||||||
|
return "grpc"
|
||||||
|
case "agent":
|
||||||
|
return "agent"
|
||||||
|
case "mcp":
|
||||||
|
return "mcp"
|
||||||
|
case "feature_flags":
|
||||||
|
return "other"
|
||||||
|
case "limit_mm_per_prompt":
|
||||||
|
return "llm"
|
||||||
|
default:
|
||||||
|
return "other"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// labelFromKey converts a yaml key like "context_size" to "Context Size".
|
||||||
|
func labelFromKey(key string) string {
|
||||||
|
parts := strings.Split(key, "_")
|
||||||
|
for i, p := range parts {
|
||||||
|
if len(p) > 0 {
|
||||||
|
runes := []rune(p)
|
||||||
|
runes[0] = unicode.ToUpper(runes[0])
|
||||||
|
parts[i] = string(runes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// tagOptions is a set of comma-separated yaml tag options.
|
||||||
|
type tagOptions string
|
||||||
|
|
||||||
|
func (o tagOptions) contains(optName string) bool {
|
||||||
|
s := string(o)
|
||||||
|
for s != "" {
|
||||||
|
var name string
|
||||||
|
if name, s, _ = strings.Cut(s, ","); name == optName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTag splits a yaml struct tag into the key name and options.
|
||||||
|
func parseTag(tag string) (string, tagOptions) {
|
||||||
|
if tag == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
before, after, found := strings.Cut(tag, ",")
|
||||||
|
if found {
|
||||||
|
return before, tagOptions(after)
|
||||||
|
}
|
||||||
|
return tag, ""
|
||||||
|
}
|
||||||
|
|
||||||
208
core/config/meta/reflect_test.go
Normal file
208
core/config/meta/reflect_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package meta_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/config/meta"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWalkModelConfig(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
if len(fields) == 0 {
|
||||||
|
t.Fatal("expected fields from ModelConfig, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a lookup by path
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify some top-level fields exist
|
||||||
|
for _, path := range []string{"name", "backend", "cuda", "step"} {
|
||||||
|
if _, ok := byPath[path]; !ok {
|
||||||
|
t.Errorf("expected field %q not found", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify inline LLMConfig fields appear at top level (no prefix)
|
||||||
|
for _, path := range []string{"context_size", "gpu_layers", "threads", "mmap"} {
|
||||||
|
if _, ok := byPath[path]; !ok {
|
||||||
|
t.Errorf("expected inline LLMConfig field %q not found", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify nested struct fields have correct prefix
|
||||||
|
for _, path := range []string{
|
||||||
|
"template.chat",
|
||||||
|
"template.completion",
|
||||||
|
"template.use_tokenizer_template",
|
||||||
|
"function.grammar.parallel_calls",
|
||||||
|
"function.grammar.mixed_mode",
|
||||||
|
"diffusers.pipeline_type",
|
||||||
|
"diffusers.cuda",
|
||||||
|
"pipeline.llm",
|
||||||
|
"pipeline.tts",
|
||||||
|
"reasoning.disable",
|
||||||
|
"agent.max_iterations",
|
||||||
|
"grpc.attempts",
|
||||||
|
} {
|
||||||
|
if _, ok := byPath[path]; !ok {
|
||||||
|
t.Errorf("expected nested field %q not found", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify PredictionOptions fields have parameters. prefix
|
||||||
|
for _, path := range []string{
|
||||||
|
"parameters.temperature",
|
||||||
|
"parameters.top_p",
|
||||||
|
"parameters.top_k",
|
||||||
|
"parameters.max_tokens",
|
||||||
|
"parameters.seed",
|
||||||
|
} {
|
||||||
|
if _, ok := byPath[path]; !ok {
|
||||||
|
t.Errorf("expected parameters field %q not found", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify TTSConfig fields have tts. prefix
|
||||||
|
if _, ok := byPath["tts.voice"]; !ok {
|
||||||
|
t.Error("expected tts.voice field not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipsYAMLDashFields(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelConfigFile has yaml:"-" tag, should be skipped
|
||||||
|
for _, f := range fields {
|
||||||
|
if f.Path == "modelConfigFile" || f.Path == "modelTemplate" {
|
||||||
|
t.Errorf("field %q should have been skipped (yaml:\"-\")", f.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeMapping(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
uiType string
|
||||||
|
pointer bool
|
||||||
|
}{
|
||||||
|
{"name", "string", false},
|
||||||
|
{"cuda", "bool", false},
|
||||||
|
{"context_size", "int", true},
|
||||||
|
{"gpu_layers", "int", true},
|
||||||
|
{"threads", "int", true},
|
||||||
|
{"f16", "bool", true},
|
||||||
|
{"mmap", "bool", true},
|
||||||
|
{"stopwords", "[]string", false},
|
||||||
|
{"roles", "map", false},
|
||||||
|
{"parameters.temperature", "float", true},
|
||||||
|
{"parameters.top_k", "int", true},
|
||||||
|
{"function.grammar.parallel_calls", "bool", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
f, ok := byPath[tt.path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", tt.path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.UIType != tt.uiType {
|
||||||
|
t.Errorf("field %q: expected UIType %q, got %q", tt.path, tt.uiType, f.UIType)
|
||||||
|
}
|
||||||
|
if f.Pointer != tt.pointer {
|
||||||
|
t.Errorf("field %q: expected Pointer=%v, got %v", tt.path, tt.pointer, f.Pointer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSectionAssignment(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
section string
|
||||||
|
}{
|
||||||
|
{"name", "general"},
|
||||||
|
{"backend", "general"},
|
||||||
|
{"context_size", "general"}, // inline LLMConfig -> no prefix -> general
|
||||||
|
{"parameters.temperature", "parameters"},
|
||||||
|
{"template.chat", "templates"},
|
||||||
|
{"function.grammar.parallel_calls", "functions"},
|
||||||
|
{"diffusers.cuda", "diffusers"},
|
||||||
|
{"pipeline.llm", "pipeline"},
|
||||||
|
{"reasoning.disable", "reasoning"},
|
||||||
|
{"agent.max_iterations", "agent"},
|
||||||
|
{"grpc.attempts", "grpc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
f, ok := byPath[tt.path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", tt.path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Section != tt.section {
|
||||||
|
t.Errorf("field %q: expected section %q, got %q", tt.path, tt.section, f.Section)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLabelGeneration(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
|
||||||
|
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
label string
|
||||||
|
}{
|
||||||
|
{"context_size", "Context Size"},
|
||||||
|
{"gpu_layers", "Gpu Layers"},
|
||||||
|
{"name", "Name"},
|
||||||
|
{"cuda", "Cuda"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
f, ok := byPath[tt.path]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("field %q not found", tt.path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Label != tt.label {
|
||||||
|
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldCount(t *testing.T) {
|
||||||
|
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
// We expect a large number of fields (100+) given the config complexity
|
||||||
|
if len(fields) < 80 {
|
||||||
|
t.Errorf("expected at least 80 fields, got %d", len(fields))
|
||||||
|
}
|
||||||
|
t.Logf("Total fields discovered: %d", len(fields))
|
||||||
|
}
|
||||||
324
core/config/meta/registry.go
Normal file
324
core/config/meta/registry.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package meta
|
||||||
|
|
||||||
|
// DefaultRegistry returns enrichment overrides for the ~30 most commonly used
|
||||||
|
// config fields. Fields not listed here still appear with auto-generated
|
||||||
|
// labels and type-inferred components.
|
||||||
|
func DefaultRegistry() map[string]FieldMetaOverride {
|
||||||
|
f64 := func(v float64) *float64 { return &v }
|
||||||
|
|
||||||
|
return map[string]FieldMetaOverride{
|
||||||
|
// --- General ---
|
||||||
|
"name": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "Model Name",
|
||||||
|
Description: "Unique identifier for this model configuration",
|
||||||
|
Component: "input",
|
||||||
|
Order: 0,
|
||||||
|
},
|
||||||
|
"backend": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "Backend",
|
||||||
|
Description: "The inference backend to use (e.g. llama-cpp, vllm, diffusers)",
|
||||||
|
Component: "select",
|
||||||
|
AutocompleteProvider: ProviderBackends,
|
||||||
|
Order: 1,
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "Description",
|
||||||
|
Description: "Human-readable description of what this model does",
|
||||||
|
Component: "textarea",
|
||||||
|
Order: 2,
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "Usage",
|
||||||
|
Description: "Usage instructions or notes",
|
||||||
|
Component: "textarea",
|
||||||
|
Advanced: true,
|
||||||
|
Order: 3,
|
||||||
|
},
|
||||||
|
"cuda": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "CUDA",
|
||||||
|
Description: "Explicitly enable CUDA acceleration",
|
||||||
|
Order: 5,
|
||||||
|
},
|
||||||
|
"known_usecases": {
|
||||||
|
Section: "general",
|
||||||
|
Label: "Known Use Cases",
|
||||||
|
Description: "Capabilities this model supports",
|
||||||
|
Component: "string-list",
|
||||||
|
Options: UsecaseOptions,
|
||||||
|
Order: 6,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- LLM ---
|
||||||
|
"context_size": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Context Size",
|
||||||
|
Description: "Maximum context window in tokens",
|
||||||
|
Component: "number",
|
||||||
|
VRAMImpact: true,
|
||||||
|
Order: 10,
|
||||||
|
},
|
||||||
|
"gpu_layers": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "GPU Layers",
|
||||||
|
Description: "Number of layers to offload to GPU (-1 = all)",
|
||||||
|
Component: "number",
|
||||||
|
Min: f64(-1),
|
||||||
|
VRAMImpact: true,
|
||||||
|
Order: 11,
|
||||||
|
},
|
||||||
|
"threads": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Threads",
|
||||||
|
Description: "Number of CPU threads for inference",
|
||||||
|
Component: "number",
|
||||||
|
Min: f64(1),
|
||||||
|
Order: 12,
|
||||||
|
},
|
||||||
|
"f16": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "F16",
|
||||||
|
Description: "Use 16-bit floating point for key/value cache",
|
||||||
|
Order: 13,
|
||||||
|
},
|
||||||
|
"mmap": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Memory Map",
|
||||||
|
Description: "Use memory-mapped files for model loading",
|
||||||
|
Order: 14,
|
||||||
|
},
|
||||||
|
"mmlock": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Memory Lock",
|
||||||
|
Description: "Lock model memory to prevent swapping",
|
||||||
|
Advanced: true,
|
||||||
|
Order: 15,
|
||||||
|
},
|
||||||
|
"low_vram": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Low VRAM",
|
||||||
|
Description: "Optimize for systems with limited GPU memory",
|
||||||
|
VRAMImpact: true,
|
||||||
|
Order: 16,
|
||||||
|
},
|
||||||
|
"embeddings": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Embeddings",
|
||||||
|
Description: "Enable embedding generation mode",
|
||||||
|
Order: 17,
|
||||||
|
},
|
||||||
|
"quantization": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Quantization",
|
||||||
|
Description: "Quantization method (e.g. q4_0, q5_1, q8_0)",
|
||||||
|
Component: "select",
|
||||||
|
Options: QuantizationOptions,
|
||||||
|
Advanced: true,
|
||||||
|
Order: 20,
|
||||||
|
},
|
||||||
|
"flash_attention": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "Flash Attention",
|
||||||
|
Description: "Enable flash attention for faster inference",
|
||||||
|
Component: "input",
|
||||||
|
Advanced: true,
|
||||||
|
Order: 21,
|
||||||
|
},
|
||||||
|
"cache_type_k": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "KV Cache Type (K)",
|
||||||
|
Description: "Quantization type for key cache (e.g. f16, q8_0, q4_0)",
|
||||||
|
Component: "select",
|
||||||
|
Options: CacheTypeOptions,
|
||||||
|
VRAMImpact: true,
|
||||||
|
Advanced: true,
|
||||||
|
Order: 22,
|
||||||
|
},
|
||||||
|
"cache_type_v": {
|
||||||
|
Section: "llm",
|
||||||
|
Label: "KV Cache Type (V)",
|
||||||
|
Description: "Quantization type for value cache",
|
||||||
|
Component: "select",
|
||||||
|
Options: CacheTypeOptions,
|
||||||
|
VRAMImpact: true,
|
||||||
|
Advanced: true,
|
||||||
|
Order: 23,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- Parameters ---
|
||||||
|
"parameters.temperature": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Temperature",
|
||||||
|
Description: "Sampling temperature (higher = more creative, lower = more deterministic)",
|
||||||
|
Component: "slider",
|
||||||
|
Min: f64(0),
|
||||||
|
Max: f64(2),
|
||||||
|
Step: f64(0.05),
|
||||||
|
Order: 30,
|
||||||
|
},
|
||||||
|
"parameters.top_p": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Top P",
|
||||||
|
Description: "Nucleus sampling threshold",
|
||||||
|
Component: "slider",
|
||||||
|
Min: f64(0),
|
||||||
|
Max: f64(1),
|
||||||
|
Step: f64(0.01),
|
||||||
|
Order: 31,
|
||||||
|
},
|
||||||
|
"parameters.top_k": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Top K",
|
||||||
|
Description: "Top-K sampling: consider only the K most likely tokens",
|
||||||
|
Component: "number",
|
||||||
|
Min: f64(0),
|
||||||
|
Order: 32,
|
||||||
|
},
|
||||||
|
"parameters.max_tokens": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Max Tokens",
|
||||||
|
Description: "Maximum number of tokens to generate (0 = unlimited)",
|
||||||
|
Component: "number",
|
||||||
|
Min: f64(0),
|
||||||
|
Order: 33,
|
||||||
|
},
|
||||||
|
"parameters.repeat_penalty": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Repeat Penalty",
|
||||||
|
Description: "Penalize repeated tokens (1.0 = no penalty)",
|
||||||
|
Component: "number",
|
||||||
|
Min: f64(0),
|
||||||
|
Advanced: true,
|
||||||
|
Order: 34,
|
||||||
|
},
|
||||||
|
"parameters.seed": {
|
||||||
|
Section: "parameters",
|
||||||
|
Label: "Seed",
|
||||||
|
Description: "Random seed (-1 = random)",
|
||||||
|
Component: "number",
|
||||||
|
Advanced: true,
|
||||||
|
Order: 35,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- Templates ---
|
||||||
|
"template.chat": {
|
||||||
|
Section: "templates",
|
||||||
|
Label: "Chat Template",
|
||||||
|
Description: "Go template for chat completion requests",
|
||||||
|
Component: "code-editor",
|
||||||
|
Order: 40,
|
||||||
|
},
|
||||||
|
"template.chat_message": {
|
||||||
|
Section: "templates",
|
||||||
|
Label: "Chat Message Template",
|
||||||
|
Description: "Go template for individual chat messages",
|
||||||
|
Component: "code-editor",
|
||||||
|
Order: 41,
|
||||||
|
},
|
||||||
|
"template.completion": {
|
||||||
|
Section: "templates",
|
||||||
|
Label: "Completion Template",
|
||||||
|
Description: "Go template for completion requests",
|
||||||
|
Component: "code-editor",
|
||||||
|
Order: 42,
|
||||||
|
},
|
||||||
|
"template.use_tokenizer_template": {
|
||||||
|
Section: "templates",
|
||||||
|
Label: "Use Tokenizer Template",
|
||||||
|
Description: "Use the chat template from the model's tokenizer config",
|
||||||
|
Order: 43,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- Pipeline ---
|
||||||
|
"pipeline.llm": {
|
||||||
|
Section: "pipeline",
|
||||||
|
Label: "LLM Model",
|
||||||
|
Description: "Model to use for LLM inference in the pipeline",
|
||||||
|
Component: "model-select",
|
||||||
|
AutocompleteProvider: ProviderModelsChat,
|
||||||
|
Order: 60,
|
||||||
|
},
|
||||||
|
"pipeline.tts": {
|
||||||
|
Section: "pipeline",
|
||||||
|
Label: "TTS Model",
|
||||||
|
Description: "Model to use for text-to-speech in the pipeline",
|
||||||
|
Component: "model-select",
|
||||||
|
AutocompleteProvider: ProviderModelsTTS,
|
||||||
|
Order: 61,
|
||||||
|
},
|
||||||
|
"pipeline.transcription": {
|
||||||
|
Section: "pipeline",
|
||||||
|
Label: "Transcription Model",
|
||||||
|
Description: "Model to use for speech-to-text in the pipeline",
|
||||||
|
Component: "model-select",
|
||||||
|
AutocompleteProvider: ProviderModelsTranscript,
|
||||||
|
Order: 62,
|
||||||
|
},
|
||||||
|
"pipeline.vad": {
|
||||||
|
Section: "pipeline",
|
||||||
|
Label: "VAD Model",
|
||||||
|
Description: "Model to use for voice activity detection in the pipeline",
|
||||||
|
Component: "model-select",
|
||||||
|
AutocompleteProvider: ProviderModelsVAD,
|
||||||
|
Order: 63,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- Functions ---
|
||||||
|
"function.grammar.parallel_calls": {
|
||||||
|
Section: "functions",
|
||||||
|
Label: "Parallel Calls",
|
||||||
|
Description: "Allow the LLM to return multiple function calls in one response",
|
||||||
|
Order: 70,
|
||||||
|
},
|
||||||
|
"function.grammar.mixed_mode": {
|
||||||
|
Section: "functions",
|
||||||
|
Label: "Mixed Mode",
|
||||||
|
Description: "Allow the LLM to return both text and function calls",
|
||||||
|
Order: 71,
|
||||||
|
},
|
||||||
|
"function.grammar.disable": {
|
||||||
|
Section: "functions",
|
||||||
|
Label: "Disable Grammar",
|
||||||
|
Description: "Disable grammar-constrained generation for function calls",
|
||||||
|
Advanced: true,
|
||||||
|
Order: 72,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- TTS ---
|
||||||
|
"tts.voice": {
|
||||||
|
Section: "tts",
|
||||||
|
Label: "Voice",
|
||||||
|
Description: "Default voice for TTS output",
|
||||||
|
Component: "input",
|
||||||
|
Order: 90,
|
||||||
|
},
|
||||||
|
|
||||||
|
// --- Diffusers ---
|
||||||
|
"diffusers.pipeline_type": {
|
||||||
|
Section: "diffusers",
|
||||||
|
Label: "Pipeline Type",
|
||||||
|
Description: "Diffusers pipeline type (e.g. StableDiffusionPipeline)",
|
||||||
|
Component: "select",
|
||||||
|
Options: DiffusersPipelineOptions,
|
||||||
|
Order: 80,
|
||||||
|
},
|
||||||
|
"diffusers.scheduler_type": {
|
||||||
|
Section: "diffusers",
|
||||||
|
Label: "Scheduler Type",
|
||||||
|
Description: "Noise scheduler type",
|
||||||
|
Component: "select",
|
||||||
|
Options: DiffusersSchedulerOptions,
|
||||||
|
Order: 81,
|
||||||
|
},
|
||||||
|
"diffusers.cuda": {
|
||||||
|
Section: "diffusers",
|
||||||
|
Label: "CUDA",
|
||||||
|
Description: "Enable CUDA for diffusers",
|
||||||
|
Order: 82,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
83
core/config/meta/types.go
Normal file
83
core/config/meta/types.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package meta
|
||||||
|
|
||||||
|
// FieldMeta describes a single configuration field for UI rendering and agent discovery.
|
||||||
|
type FieldMeta struct {
|
||||||
|
Path string `json:"path"` // dot-path: "context_size", "function.grammar.parallel_calls"
|
||||||
|
YAMLKey string `json:"yaml_key"` // leaf yaml key
|
||||||
|
GoType string `json:"go_type"` // "*int", "string", "[]string"
|
||||||
|
UIType string `json:"ui_type"` // "string", "int", "float", "bool", "[]string", "map", "object"
|
||||||
|
Pointer bool `json:"pointer,omitempty"` // true = nil means "not set"
|
||||||
|
Section string `json:"section"` // "general", "llm", "templates", etc.
|
||||||
|
Label string `json:"label"` // human-readable label
|
||||||
|
Description string `json:"description,omitempty"` // help text
|
||||||
|
Component string `json:"component"` // "input", "number", "toggle", "select", "slider", etc.
|
||||||
|
Placeholder string `json:"placeholder,omitempty"`
|
||||||
|
Default any `json:"default,omitempty"`
|
||||||
|
Min *float64 `json:"min,omitempty"`
|
||||||
|
Max *float64 `json:"max,omitempty"`
|
||||||
|
Step *float64 `json:"step,omitempty"`
|
||||||
|
Options []FieldOption `json:"options,omitempty"`
|
||||||
|
|
||||||
|
AutocompleteProvider string `json:"autocomplete_provider,omitempty"` // "backends", "models:chat", etc.
|
||||||
|
VRAMImpact bool `json:"vram_impact,omitempty"`
|
||||||
|
Advanced bool `json:"advanced,omitempty"`
|
||||||
|
Order int `json:"order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FieldOption represents a choice in a select/enum field.
|
||||||
|
type FieldOption struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section groups related fields in the UI.
|
||||||
|
type Section struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
Order int `json:"order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigMetadata is the top-level response for the metadata API.
|
||||||
|
type ConfigMetadata struct {
|
||||||
|
Sections []Section `json:"sections"`
|
||||||
|
Fields []FieldMeta `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FieldMetaOverride holds registry overrides that are merged on top of
|
||||||
|
// the reflection-discovered defaults. Only non-zero fields override.
|
||||||
|
type FieldMetaOverride struct {
|
||||||
|
Section string
|
||||||
|
Label string
|
||||||
|
Description string
|
||||||
|
Component string
|
||||||
|
Placeholder string
|
||||||
|
Default any
|
||||||
|
Min *float64
|
||||||
|
Max *float64
|
||||||
|
Step *float64
|
||||||
|
Options []FieldOption
|
||||||
|
AutocompleteProvider string
|
||||||
|
VRAMImpact bool
|
||||||
|
Advanced bool
|
||||||
|
Order int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSections defines the well-known config sections in display order.
|
||||||
|
func DefaultSections() []Section {
|
||||||
|
return []Section{
|
||||||
|
{ID: "general", Label: "General", Icon: "settings", Order: 0},
|
||||||
|
{ID: "llm", Label: "LLM", Icon: "cpu", Order: 10},
|
||||||
|
{ID: "parameters", Label: "Parameters", Icon: "sliders", Order: 20},
|
||||||
|
{ID: "templates", Label: "Templates", Icon: "file-text", Order: 30},
|
||||||
|
{ID: "functions", Label: "Functions / Tools", Icon: "tool", Order: 40},
|
||||||
|
{ID: "reasoning", Label: "Reasoning", Icon: "brain", Order: 45},
|
||||||
|
{ID: "diffusers", Label: "Diffusers", Icon: "image", Order: 50},
|
||||||
|
{ID: "tts", Label: "TTS", Icon: "volume-2", Order: 55},
|
||||||
|
{ID: "pipeline", Label: "Pipeline", Icon: "git-merge", Order: 60},
|
||||||
|
{ID: "grpc", Label: "gRPC", Icon: "server", Order: 65},
|
||||||
|
{ID: "agent", Label: "Agent", Icon: "bot", Order: 70},
|
||||||
|
{ID: "mcp", Label: "MCP", Icon: "plug", Order: 75},
|
||||||
|
{ID: "other", Label: "Other", Icon: "more-horizontal", Order: 100},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -77,6 +77,8 @@ type ModelConfig struct {
|
|||||||
|
|
||||||
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
||||||
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
|
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
|
||||||
|
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
|
||||||
|
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
|
||||||
|
|
||||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||||
@@ -548,6 +550,16 @@ func (c *ModelConfig) GetModelTemplate() string {
|
|||||||
return c.modelTemplate
|
return c.modelTemplate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsDisabled returns true if the model is disabled
|
||||||
|
func (c *ModelConfig) IsDisabled() bool {
|
||||||
|
return c.Disabled != nil && *c.Disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPinned returns true if the model is pinned (excluded from idle unloading and eviction)
|
||||||
|
func (c *ModelConfig) IsPinned() bool {
|
||||||
|
return c.Pinned != nil && *c.Pinned
|
||||||
|
}
|
||||||
|
|
||||||
type ModelConfigUsecase int
|
type ModelConfigUsecase int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -705,7 +717,8 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
||||||
if c.Backend != "rfdetr" {
|
detectionBackends := []string{"rfdetr", "sam3-cpp"}
|
||||||
|
if !slices.Contains(detectionBackends, c.Backend) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type RuntimeSettings struct {
|
|||||||
// Backend management
|
// Backend management
|
||||||
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
||||||
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||||
|
AutoUpgradeBackends *bool `json:"auto_upgrade_backends,omitempty"` // Automatically upgrade backends when new versions are detected
|
||||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||||
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
|
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
|
||||||
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
||||||
|
|||||||
@@ -20,12 +20,19 @@ type BackendMetadata struct {
|
|||||||
GalleryURL string `json:"gallery_url,omitempty"`
|
GalleryURL string `json:"gallery_url,omitempty"`
|
||||||
// InstalledAt is the timestamp when the backend was installed
|
// InstalledAt is the timestamp when the backend was installed
|
||||||
InstalledAt string `json:"installed_at,omitempty"`
|
InstalledAt string `json:"installed_at,omitempty"`
|
||||||
|
// Version is the version of the backend at install time
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
// URI is the original URI used to install the backend
|
||||||
|
URI string `json:"uri,omitempty"`
|
||||||
|
// Digest is the OCI image digest at install time (for upgrade detection)
|
||||||
|
Digest string `json:"digest,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GalleryBackend struct {
|
type GalleryBackend struct {
|
||||||
Metadata `json:",inline" yaml:",inline"`
|
Metadata `json:",inline" yaml:",inline"`
|
||||||
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
||||||
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
||||||
|
Version string `json:"version,omitempty" yaml:"version,omitempty"`
|
||||||
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
||||||
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -71,6 +78,10 @@ func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if systemState.CapabilityFilterDisabled() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Meta backends are compatible if the system capability matches one of the keys
|
// Meta backends are compatible if the system capability matches one of the keys
|
||||||
if m.IsMeta() {
|
if m.IsMeta() {
|
||||||
capability := systemState.Capability(m.CapabilitiesMap)
|
capability := systemState.Capability(m.CapabilitiesMap)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/pkg/downloader"
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/oci"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
cp "github.com/otiai10/copy"
|
cp "github.com/otiai10/copy"
|
||||||
@@ -158,6 +159,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
|
|||||||
Name: name,
|
Name: name,
|
||||||
GalleryURL: backend.Gallery.URL,
|
GalleryURL: backend.Gallery.URL,
|
||||||
InstalledAt: time.Now().Format(time.RFC3339),
|
InstalledAt: time.Now().Format(time.RFC3339),
|
||||||
|
Version: bestBackend.Version,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
|
if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
|
||||||
@@ -279,6 +281,18 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
Name: name,
|
Name: name,
|
||||||
GalleryURL: config.Gallery.URL,
|
GalleryURL: config.Gallery.URL,
|
||||||
InstalledAt: time.Now().Format(time.RFC3339),
|
InstalledAt: time.Now().Format(time.RFC3339),
|
||||||
|
Version: config.Version,
|
||||||
|
URI: string(uri),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record the OCI digest for upgrade detection (non-fatal on failure)
|
||||||
|
if uri.LooksLikeOCI() {
|
||||||
|
digest, digestErr := oci.GetImageDigest(string(uri), "", nil, nil)
|
||||||
|
if digestErr != nil {
|
||||||
|
xlog.Warn("Failed to get OCI image digest for backend", "uri", string(uri), "error", digestErr)
|
||||||
|
} else {
|
||||||
|
metadata.Digest = digest
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Alias != "" {
|
if config.Alias != "" {
|
||||||
@@ -300,14 +314,29 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
|||||||
|
|
||||||
backend, ok := backends.Get(name)
|
backend, ok := backends.Get(name)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
// Not found by direct key — try matching by gallery name (metadata.Name)
|
||||||
|
// The UI may send gallery-style names like "localai@llama-cpp" which
|
||||||
|
// don't match the directory-based keys used in the backends map.
|
||||||
|
for _, b := range backends {
|
||||||
|
if b.Metadata != nil && b.Metadata.Name == name && !b.IsMeta {
|
||||||
|
backend = b
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if backend.IsSystem {
|
if backend.IsSystem {
|
||||||
return fmt.Errorf("system backend %q cannot be deleted", name)
|
return fmt.Errorf("system backend %q cannot be deleted", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
|
// Use the backend's actual Name (directory key) for path resolution,
|
||||||
|
// not the caller-supplied name which may be a gallery-style name.
|
||||||
|
dirName := backend.Name
|
||||||
|
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, dirName)
|
||||||
|
|
||||||
// check if the backend dir exists
|
// check if the backend dir exists
|
||||||
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
|
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
|
||||||
@@ -325,7 +354,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if metadata != nil && metadata.Alias == name {
|
if metadata != nil && (metadata.Alias == name || metadata.Alias == dirName) {
|
||||||
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
|
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
|
||||||
foundBackend = true
|
foundBackend = true
|
||||||
break
|
break
|
||||||
@@ -358,11 +387,13 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SystemBackend struct {
|
type SystemBackend struct {
|
||||||
Name string
|
Name string
|
||||||
RunFile string
|
RunFile string
|
||||||
IsMeta bool
|
IsMeta bool
|
||||||
IsSystem bool
|
IsSystem bool
|
||||||
Metadata *BackendMetadata
|
Metadata *BackendMetadata
|
||||||
|
UpgradeAvailable bool `json:"upgrade_available,omitempty"`
|
||||||
|
AvailableVersion string `json:"available_version,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemBackends map[string]SystemBackend
|
type SystemBackends map[string]SystemBackend
|
||||||
|
|||||||
118
core/gallery/backends_version_test.go
Normal file
118
core/gallery/backends_version_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package gallery_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Backend versioning", func() {
|
||||||
|
var tempDir string
|
||||||
|
var systemState *system.SystemState
|
||||||
|
var modelLoader *model.ModelLoader
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tempDir, err = os.MkdirTemp("", "gallery-version-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
systemState, err = system.GetSystemState(
|
||||||
|
system.WithBackendPath(tempDir),
|
||||||
|
)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
modelLoader = model.NewModelLoader(systemState)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("records version in metadata when installing a backend with a version", func() {
|
||||||
|
// Create a fake backend source directory with a run.sh
|
||||||
|
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer os.RemoveAll(srcDir)
|
||||||
|
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
backend := &gallery.GalleryBackend{}
|
||||||
|
backend.Name = "test-backend"
|
||||||
|
backend.URI = srcDir
|
||||||
|
backend.Version = "1.2.3"
|
||||||
|
|
||||||
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
// Read the metadata file and check version
|
||||||
|
metadataPath := filepath.Join(tempDir, "test-backend", "metadata.json")
|
||||||
|
data, err := os.ReadFile(metadataPath)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
err = json.Unmarshal(data, &metadata)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(metadata["version"]).To(Equal("1.2.3"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("records URI in metadata", func() {
|
||||||
|
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer os.RemoveAll(srcDir)
|
||||||
|
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
backend := &gallery.GalleryBackend{}
|
||||||
|
backend.Name = "test-backend-uri"
|
||||||
|
backend.URI = srcDir
|
||||||
|
backend.Version = "2.0.0"
|
||||||
|
|
||||||
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
metadataPath := filepath.Join(tempDir, "test-backend-uri", "metadata.json")
|
||||||
|
data, err := os.ReadFile(metadataPath)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
err = json.Unmarshal(data, &metadata)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(metadata["uri"]).To(Equal(srcDir))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("omits version key when version is empty", func() {
|
||||||
|
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer os.RemoveAll(srcDir)
|
||||||
|
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
backend := &gallery.GalleryBackend{}
|
||||||
|
backend.Name = "test-backend-noversion"
|
||||||
|
backend.URI = srcDir
|
||||||
|
// Version intentionally left empty
|
||||||
|
|
||||||
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
metadataPath := filepath.Join(tempDir, "test-backend-noversion", "metadata.json")
|
||||||
|
data, err := os.ReadFile(metadataPath)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
err = json.Unmarshal(data, &metadata)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
// omitempty should exclude the version key entirely
|
||||||
|
_, hasVersion := metadata["version"]
|
||||||
|
Expect(hasVersion).To(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
237
core/gallery/upgrade.go
Normal file
237
core/gallery/upgrade.go
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
package gallery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/oci"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
cp "github.com/otiai10/copy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpgradeInfo holds details about an available backend upgrade.
|
||||||
|
type UpgradeInfo struct {
|
||||||
|
BackendName string `json:"backend_name"`
|
||||||
|
InstalledVersion string `json:"installed_version"`
|
||||||
|
AvailableVersion string `json:"available_version"`
|
||||||
|
InstalledDigest string `json:"installed_digest,omitempty"`
|
||||||
|
AvailableDigest string `json:"available_digest,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckBackendUpgrades compares installed backends against gallery entries
|
||||||
|
// and returns a map of backend names to UpgradeInfo for those that have
|
||||||
|
// newer versions or different OCI digests available.
|
||||||
|
func CheckBackendUpgrades(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState) (map[string]UpgradeInfo, error) {
|
||||||
|
galleryBackends, err := AvailableBackends(galleries, systemState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list available backends: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
installedBackends, err := ListSystemBackends(systemState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list installed backends: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]UpgradeInfo)
|
||||||
|
|
||||||
|
for _, installed := range installedBackends {
|
||||||
|
// Skip system backends — they are managed outside the gallery
|
||||||
|
if installed.IsSystem {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if installed.Metadata == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find matching gallery entry by metadata name
|
||||||
|
galleryEntry := FindGalleryElement(galleryBackends, installed.Metadata.Name)
|
||||||
|
if galleryEntry == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
installedVersion := installed.Metadata.Version
|
||||||
|
galleryVersion := galleryEntry.Version
|
||||||
|
|
||||||
|
// If both sides have versions, compare them
|
||||||
|
if galleryVersion != "" && installedVersion != "" {
|
||||||
|
if galleryVersion != installedVersion {
|
||||||
|
result[installed.Metadata.Name] = UpgradeInfo{
|
||||||
|
BackendName: installed.Metadata.Name,
|
||||||
|
InstalledVersion: installedVersion,
|
||||||
|
AvailableVersion: galleryVersion,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Versions match — no upgrade needed
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gallery has a version but installed doesn't — this happens for backends
|
||||||
|
// installed before version tracking was added. Flag as upgradeable so
|
||||||
|
// users can re-install to pick up version metadata.
|
||||||
|
if galleryVersion != "" && installedVersion == "" {
|
||||||
|
result[installed.Metadata.Name] = UpgradeInfo{
|
||||||
|
BackendName: installed.Metadata.Name,
|
||||||
|
InstalledVersion: "",
|
||||||
|
AvailableVersion: galleryVersion,
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to OCI digest comparison when versions are unavailable
|
||||||
|
if downloader.URI(galleryEntry.URI).LooksLikeOCI() {
|
||||||
|
remoteDigest, err := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Warn("Failed to get remote OCI digest for upgrade check", "backend", installed.Metadata.Name, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If we have a stored digest, compare; otherwise any remote digest
|
||||||
|
// means we can't confirm we're up to date — flag as upgradeable
|
||||||
|
if installed.Metadata.Digest == "" || remoteDigest != installed.Metadata.Digest {
|
||||||
|
result[installed.Metadata.Name] = UpgradeInfo{
|
||||||
|
BackendName: installed.Metadata.Name,
|
||||||
|
InstalledDigest: installed.Metadata.Digest,
|
||||||
|
AvailableDigest: remoteDigest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// No version info and non-OCI URI — cannot determine, skip
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpgradeBackend upgrades a single backend to the latest gallery version using
|
||||||
|
// an atomic swap with backup-based rollback on failure.
|
||||||
|
func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, galleries []config.Gallery, backendName string, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
// Look up the installed backend
|
||||||
|
installedBackends, err := ListSystemBackends(systemState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list installed backends: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
installed, ok := installedBackends.Get(backendName)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("backend %q: %w", backendName, ErrBackendNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
if installed.IsSystem {
|
||||||
|
return fmt.Errorf("system backend %q cannot be upgraded via gallery", backendName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a meta backend, recursively upgrade the concrete backend it points to
|
||||||
|
if installed.Metadata != nil && installed.Metadata.MetaBackendFor != "" {
|
||||||
|
xlog.Info("Meta backend detected, upgrading concrete backend", "meta", backendName, "concrete", installed.Metadata.MetaBackendFor)
|
||||||
|
return UpgradeBackend(ctx, systemState, modelLoader, galleries, installed.Metadata.MetaBackendFor, downloadStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the gallery entry
|
||||||
|
galleryBackends, err := AvailableBackends(galleries, systemState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list available backends: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
galleryEntry := FindGalleryElement(galleryBackends, backendName)
|
||||||
|
if galleryEntry == nil {
|
||||||
|
return fmt.Errorf("no gallery entry found for backend %q", backendName)
|
||||||
|
}
|
||||||
|
|
||||||
|
backendPath := filepath.Join(systemState.Backend.BackendsPath, backendName)
|
||||||
|
tmpPath := backendPath + ".upgrade-tmp"
|
||||||
|
backupPath := backendPath + ".backup"
|
||||||
|
|
||||||
|
// Clean up any stale tmp/backup dirs from prior attempts
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
os.RemoveAll(backupPath)
|
||||||
|
|
||||||
|
// Step 1: Download the new backend into the tmp directory
|
||||||
|
if err := os.MkdirAll(tmpPath, 0750); err != nil {
|
||||||
|
return fmt.Errorf("failed to create upgrade tmp dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := downloader.URI(galleryEntry.URI)
|
||||||
|
if uri.LooksLikeDir() {
|
||||||
|
if err := cp.Copy(string(uri), tmpPath); err != nil {
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
return fmt.Errorf("failed to copy backend from directory: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := uri.DownloadFileWithContext(ctx, tmpPath, "", 1, 1, downloadStatus); err != nil {
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
return fmt.Errorf("failed to download backend: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Validate — check that run.sh exists in the new content
|
||||||
|
newRunFile := filepath.Join(tmpPath, runFile)
|
||||||
|
if _, err := os.Stat(newRunFile); os.IsNotExist(err) {
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
return fmt.Errorf("upgrade validation failed: run.sh not found in new backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Atomic swap — rename current to backup, then tmp to current
|
||||||
|
if err := os.Rename(backendPath, backupPath); err != nil {
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
return fmt.Errorf("failed to move current backend to backup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(tmpPath, backendPath); err != nil {
|
||||||
|
// Restore backup on failure
|
||||||
|
xlog.Error("Failed to move new backend into place, restoring backup", "error", err)
|
||||||
|
if restoreErr := os.Rename(backupPath, backendPath); restoreErr != nil {
|
||||||
|
xlog.Error("Failed to restore backup", "error", restoreErr)
|
||||||
|
}
|
||||||
|
os.RemoveAll(tmpPath)
|
||||||
|
return fmt.Errorf("failed to move new backend into place: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Write updated metadata, preserving alias from old metadata
|
||||||
|
var oldAlias string
|
||||||
|
if installed.Metadata != nil {
|
||||||
|
oldAlias = installed.Metadata.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
newMetadata := &BackendMetadata{
|
||||||
|
Name: backendName,
|
||||||
|
Version: galleryEntry.Version,
|
||||||
|
URI: galleryEntry.URI,
|
||||||
|
InstalledAt: time.Now().Format(time.RFC3339),
|
||||||
|
Alias: oldAlias,
|
||||||
|
}
|
||||||
|
|
||||||
|
if galleryEntry.Gallery.URL != "" {
|
||||||
|
newMetadata.GalleryURL = galleryEntry.Gallery.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record OCI digest if applicable (non-fatal on failure)
|
||||||
|
if uri.LooksLikeOCI() {
|
||||||
|
digest, digestErr := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
|
||||||
|
if digestErr != nil {
|
||||||
|
xlog.Warn("Failed to get OCI image digest after upgrade", "uri", galleryEntry.URI, "error", digestErr)
|
||||||
|
} else {
|
||||||
|
newMetadata.Digest = digest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeBackendMetadata(backendPath, newMetadata); err != nil {
|
||||||
|
// Metadata write failure is not worth rolling back the entire upgrade
|
||||||
|
xlog.Error("Failed to write metadata after upgrade", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Re-register backends so the model loader picks up any changes
|
||||||
|
if err := RegisterBackends(systemState, modelLoader); err != nil {
|
||||||
|
xlog.Warn("Failed to re-register backends after upgrade", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 6: Remove backup
|
||||||
|
os.RemoveAll(backupPath)
|
||||||
|
|
||||||
|
xlog.Info("Backend upgraded successfully", "backend", backendName, "version", galleryEntry.Version)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
219
core/gallery/upgrade_test.go
Normal file
219
core/gallery/upgrade_test.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package gallery_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
. "github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Upgrade Detection and Execution", func() {
|
||||||
|
var (
|
||||||
|
tempDir string
|
||||||
|
backendsPath string
|
||||||
|
galleryPath string
|
||||||
|
systemState *system.SystemState
|
||||||
|
galleries []config.Gallery
|
||||||
|
)
|
||||||
|
|
||||||
|
// installBackendWithVersion creates a fake installed backend directory with
|
||||||
|
// the given name, version, and optional run.sh content.
|
||||||
|
installBackendWithVersion := func(name, version string, runContent ...string) {
|
||||||
|
dir := filepath.Join(backendsPath, name)
|
||||||
|
Expect(os.MkdirAll(dir, 0750)).To(Succeed())
|
||||||
|
|
||||||
|
content := "#!/bin/sh\necho ok"
|
||||||
|
if len(runContent) > 0 {
|
||||||
|
content = runContent[0]
|
||||||
|
}
|
||||||
|
Expect(os.WriteFile(filepath.Join(dir, "run.sh"), []byte(content), 0755)).To(Succeed())
|
||||||
|
|
||||||
|
metadata := BackendMetadata{
|
||||||
|
Name: name,
|
||||||
|
Version: version,
|
||||||
|
InstalledAt: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
data, err := json.MarshalIndent(metadata, "", " ")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(os.WriteFile(filepath.Join(dir, "metadata.json"), data, 0644)).To(Succeed())
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeGalleryYAML writes a gallery YAML file with the given backends.
|
||||||
|
writeGalleryYAML := func(backends []GalleryBackend) {
|
||||||
|
data, err := yaml.Marshal(backends)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(os.WriteFile(galleryPath, data, 0644)).To(Succeed())
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tempDir, err = os.MkdirTemp("", "upgrade-test-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
backendsPath = tempDir
|
||||||
|
|
||||||
|
galleryPath = filepath.Join(tempDir, "gallery.yaml")
|
||||||
|
|
||||||
|
// Write a default empty gallery
|
||||||
|
writeGalleryYAML([]GalleryBackend{})
|
||||||
|
|
||||||
|
galleries = []config.Gallery{
|
||||||
|
{
|
||||||
|
Name: "test-gallery",
|
||||||
|
URL: "file://" + galleryPath,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
systemState, err = system.GetSystemState(
|
||||||
|
system.WithBackendPath(backendsPath),
|
||||||
|
)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("CheckBackendUpgrades", func() {
|
||||||
|
It("should detect upgrade when gallery version differs from installed version", func() {
|
||||||
|
// Install a backend at v1.0.0
|
||||||
|
installBackendWithVersion("my-backend", "1.0.0")
|
||||||
|
|
||||||
|
// Gallery advertises v2.0.0
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: filepath.Join(tempDir, "some-source"),
|
||||||
|
Version: "2.0.0",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(upgrades).To(HaveKey("my-backend"))
|
||||||
|
Expect(upgrades["my-backend"].InstalledVersion).To(Equal("1.0.0"))
|
||||||
|
Expect(upgrades["my-backend"].AvailableVersion).To(Equal("2.0.0"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should NOT flag upgrade when versions match", func() {
|
||||||
|
installBackendWithVersion("my-backend", "2.0.0")
|
||||||
|
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: filepath.Join(tempDir, "some-source"),
|
||||||
|
Version: "2.0.0",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(upgrades).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should skip backends without version info and without OCI digest", func() {
|
||||||
|
// Install without version
|
||||||
|
installBackendWithVersion("my-backend", "")
|
||||||
|
|
||||||
|
// Gallery also without version
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: filepath.Join(tempDir, "some-source"),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(upgrades).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("UpgradeBackend", func() {
|
||||||
|
It("should replace backend directory and update metadata", func() {
|
||||||
|
// Install v1
|
||||||
|
installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")
|
||||||
|
|
||||||
|
// Create a source directory with v2 content
|
||||||
|
srcDir := filepath.Join(tempDir, "v2-source")
|
||||||
|
Expect(os.MkdirAll(srcDir, 0750)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho v2"), 0755)).To(Succeed())
|
||||||
|
|
||||||
|
// Gallery points to the v2 source dir
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: srcDir,
|
||||||
|
Version: "2.0.0",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ml := model.NewModelLoader(systemState)
|
||||||
|
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
// Verify run.sh was updated
|
||||||
|
content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(string(content)).To(Equal("#!/bin/sh\necho v2"))
|
||||||
|
|
||||||
|
// Verify metadata was updated
|
||||||
|
metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
var meta BackendMetadata
|
||||||
|
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
||||||
|
Expect(meta.Version).To(Equal("2.0.0"))
|
||||||
|
Expect(meta.Name).To(Equal("my-backend"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should restore backup on failure", func() {
|
||||||
|
// Install v1
|
||||||
|
installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")
|
||||||
|
|
||||||
|
// Gallery points to a nonexistent path (no run.sh will be found)
|
||||||
|
nonExistentDir := filepath.Join(tempDir, "does-not-exist")
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: nonExistentDir,
|
||||||
|
Version: "2.0.0",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ml := model.NewModelLoader(systemState)
|
||||||
|
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
|
||||||
|
// Verify v1 is still intact
|
||||||
|
content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(string(content)).To(Equal("#!/bin/sh\necho v1"))
|
||||||
|
|
||||||
|
// Verify metadata still says v1
|
||||||
|
metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
var meta BackendMetadata
|
||||||
|
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
||||||
|
Expect(meta.Version).To(Equal("1.0.0"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -52,9 +52,42 @@ var quietPaths = []string{"/api/operations", "/api/resources", "/healthz", "/rea
|
|||||||
// @license.name MIT
|
// @license.name MIT
|
||||||
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
|
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
|
||||||
// @BasePath /
|
// @BasePath /
|
||||||
|
// @schemes http https
|
||||||
// @securityDefinitions.apikey BearerAuth
|
// @securityDefinitions.apikey BearerAuth
|
||||||
// @in header
|
// @in header
|
||||||
// @name Authorization
|
// @name Authorization
|
||||||
|
// @tag.name inference
|
||||||
|
// @tag.description Chat completions, text completions, edits, and responses (OpenAI-compatible)
|
||||||
|
// @tag.name embeddings
|
||||||
|
// @tag.description Vector embeddings (OpenAI-compatible)
|
||||||
|
// @tag.name audio
|
||||||
|
// @tag.description Text-to-speech, transcription, voice activity detection, sound generation
|
||||||
|
// @tag.name images
|
||||||
|
// @tag.description Image generation and inpainting
|
||||||
|
// @tag.name video
|
||||||
|
// @tag.description Video generation from prompts
|
||||||
|
// @tag.name detection
|
||||||
|
// @tag.description Object detection in images
|
||||||
|
// @tag.name tokenize
|
||||||
|
// @tag.description Tokenization and token metrics
|
||||||
|
// @tag.name models
|
||||||
|
// @tag.description Model gallery browsing, installation, deletion, and listing
|
||||||
|
// @tag.name backends
|
||||||
|
// @tag.description Backend gallery browsing, installation, deletion, and listing
|
||||||
|
// @tag.name config
|
||||||
|
// @tag.description Model configuration metadata, autocomplete, PATCH updates, VRAM estimation
|
||||||
|
// @tag.name monitoring
|
||||||
|
// @tag.description Prometheus metrics, backend status, system information
|
||||||
|
// @tag.name mcp
|
||||||
|
// @tag.description Model Context Protocol — tool-augmented chat with MCP servers
|
||||||
|
// @tag.name agent-jobs
|
||||||
|
// @tag.description Agent task and job management
|
||||||
|
// @tag.name p2p
|
||||||
|
// @tag.description Peer-to-peer networking nodes and tokens
|
||||||
|
// @tag.name rerank
|
||||||
|
// @tag.description Document reranking
|
||||||
|
// @tag.name instructions
|
||||||
|
// @tag.description API instruction discovery — browse instruction areas and get endpoint guides
|
||||||
|
|
||||||
func API(application *application.Application) (*echo.Echo, error) {
|
func API(application *application.Application) (*echo.Echo, error) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
@@ -358,9 +391,13 @@ func API(application *application.Application) (*echo.Echo, error) {
|
|||||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||||
|
routes.RegisterOllamaRoutes(e, requestExtractor, application)
|
||||||
|
if application.ApplicationConfig().OllamaAPIRootEndpoint {
|
||||||
|
routes.RegisterOllamaRootEndpoint(e)
|
||||||
|
}
|
||||||
if !application.ApplicationConfig().DisableWebUI {
|
if !application.ApplicationConfig().DisableWebUI {
|
||||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||||
|
|
||||||
// Serve React SPA from / with SPA fallback via 404 handler
|
// Serve React SPA from / with SPA fallback via 404 handler
|
||||||
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")
|
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")
|
||||||
|
|||||||
@@ -956,8 +956,7 @@ parameters:
|
|||||||
It("returns the models list", func() {
|
It("returns the models list", func() {
|
||||||
models, err := client.ListModels(context.TODO())
|
models, err := client.ListModels(context.TODO())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// A model called "bert" can be present in the model directory depending on the order of the tests
|
Expect(len(models.Models)).To(BeNumerically(">=", 7))
|
||||||
Expect(len(models.Models)).To(BeNumerically(">=", 8))
|
|
||||||
})
|
})
|
||||||
It("can generate completions via ggml", func() {
|
It("can generate completions via ggml", func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
@@ -979,6 +978,42 @@ parameters:
|
|||||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("does not duplicate the first content token in streaming chat completions", Label("llama-gguf", "llama-gguf-stream"), func() {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
Skip("test supported only on linux")
|
||||||
|
}
|
||||||
|
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{
|
||||||
|
Model: "testmodel.ggml",
|
||||||
|
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
var contentParts []string
|
||||||
|
for {
|
||||||
|
chunk, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if len(chunk.Choices) > 0 {
|
||||||
|
delta := chunk.Choices[0].Delta.Content
|
||||||
|
if delta != "" {
|
||||||
|
contentParts = append(contentParts, delta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(contentParts).ToNot(BeEmpty(), "Expected streaming content tokens")
|
||||||
|
// The first content token should appear exactly once.
|
||||||
|
// A bug in grpc-server.cpp caused the role-init array element
|
||||||
|
// to get the same ChatDelta stamped, duplicating the first token.
|
||||||
|
if len(contentParts) >= 2 {
|
||||||
|
Expect(contentParts[0]).ToNot(Equal(contentParts[1]),
|
||||||
|
"First content token was duplicated: %v", contentParts[:2])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
It("returns logprobs in chat completions when requested", func() {
|
It("returns logprobs in chat completions when requested", func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
Skip("test only on linux")
|
Skip("test only on linux")
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package anthropic
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@@ -21,6 +23,7 @@ import (
|
|||||||
// MessagesEndpoint is the Anthropic Messages API endpoint
|
// MessagesEndpoint is the Anthropic Messages API endpoint
|
||||||
// https://docs.anthropic.com/claude/reference/messages_post
|
// https://docs.anthropic.com/claude/reference/messages_post
|
||||||
// @Summary Generate a message response for the given messages and model.
|
// @Summary Generate a message response for the given messages and model.
|
||||||
|
// @Tags inference
|
||||||
// @Param request body schema.AnthropicRequest true "query params"
|
// @Param request body schema.AnthropicRequest true "query params"
|
||||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||||
// @Router /v1/messages [post]
|
// @Router /v1/messages [post]
|
||||||
@@ -357,7 +360,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
// Send initial content_block_start event
|
// Send initial content_block_start event
|
||||||
contentBlockStart := schema.AnthropicStreamEvent{
|
contentBlockStart := schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_start",
|
Type: "content_block_start",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
|
ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
|
||||||
}
|
}
|
||||||
sendAnthropicSSE(c, contentBlockStart)
|
sendAnthropicSSE(c, contentBlockStart)
|
||||||
@@ -365,7 +368,33 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
// Collect tool calls for MCP execution
|
// Collect tool calls for MCP execution
|
||||||
var collectedToolCalls []functions.FuncCallResults
|
var collectedToolCalls []functions.FuncCallResults
|
||||||
|
|
||||||
|
// SSE keepalive: send comment pings every 3s until the first token arrives.
|
||||||
|
// This prevents clients (e.g. Claude Code) from timing out while the model loads or processes the prompt.
|
||||||
|
firstTokenReceived := make(chan struct{})
|
||||||
|
keepaliveDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(keepaliveDone)
|
||||||
|
ticker := time.NewTicker(3 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-firstTokenReceived:
|
||||||
|
return
|
||||||
|
case <-c.Request().Context().Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
fmt.Fprintf(c.Response().Writer, "event: ping\ndata: {\"type\": \"ping\"}\n\n")
|
||||||
|
c.Response().Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
firstTokenOnce := sync.Once{}
|
||||||
|
|
||||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||||
|
firstTokenOnce.Do(func() {
|
||||||
|
close(firstTokenReceived)
|
||||||
|
<-keepaliveDone // wait for keepalive goroutine to exit before writing
|
||||||
|
})
|
||||||
accumulatedContent += token
|
accumulatedContent += token
|
||||||
|
|
||||||
if shouldUseFn {
|
if shouldUseFn {
|
||||||
@@ -376,7 +405,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
if !inToolCall && currentBlockIndex == 0 {
|
if !inToolCall && currentBlockIndex == 0 {
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_stop",
|
Type: "content_block_stop",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
})
|
})
|
||||||
currentBlockIndex++
|
currentBlockIndex++
|
||||||
inToolCall = true
|
inToolCall = true
|
||||||
@@ -386,7 +415,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
tc := toolCalls[i]
|
tc := toolCalls[i]
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_start",
|
Type: "content_block_start",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
ContentBlock: &schema.AnthropicContentBlock{
|
ContentBlock: &schema.AnthropicContentBlock{
|
||||||
Type: "tool_use",
|
Type: "tool_use",
|
||||||
ID: fmt.Sprintf("toolu_%s_%d", id, i),
|
ID: fmt.Sprintf("toolu_%s_%d", id, i),
|
||||||
@@ -395,7 +424,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
})
|
})
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_delta",
|
Type: "content_block_delta",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
Delta: &schema.AnthropicStreamDelta{
|
Delta: &schema.AnthropicStreamDelta{
|
||||||
Type: "input_json_delta",
|
Type: "input_json_delta",
|
||||||
PartialJSON: tc.Arguments,
|
PartialJSON: tc.Arguments,
|
||||||
@@ -403,7 +432,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
})
|
})
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_stop",
|
Type: "content_block_stop",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
})
|
})
|
||||||
currentBlockIndex++
|
currentBlockIndex++
|
||||||
}
|
}
|
||||||
@@ -413,10 +442,10 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !inToolCall {
|
if !inToolCall && token != "" {
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_delta",
|
Type: "content_block_delta",
|
||||||
Index: 0,
|
Index: intPtr(0),
|
||||||
Delta: &schema.AnthropicStreamDelta{
|
Delta: &schema.AnthropicStreamDelta{
|
||||||
Type: "text_delta",
|
Type: "text_delta",
|
||||||
Text: token,
|
Text: token,
|
||||||
@@ -432,6 +461,11 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
openAIReq.Metadata = input.Metadata
|
openAIReq.Metadata = input.Metadata
|
||||||
|
|
||||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||||
|
|
||||||
|
// Stop the keepalive goroutine now that inference is done
|
||||||
|
firstTokenOnce.Do(func() { close(firstTokenReceived) })
|
||||||
|
<-keepaliveDone
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
@@ -444,9 +478,68 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also check chat deltas for tool calls
|
// Check chat deltas from C++ autoparser — when active, the raw
|
||||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
// message is cleared and content/tool calls arrive via ChatDeltas.
|
||||||
collectedToolCalls = deltaToolCalls
|
if len(chatDeltas) > 0 {
|
||||||
|
deltaContent := functions.ContentFromChatDeltas(chatDeltas)
|
||||||
|
deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas)
|
||||||
|
|
||||||
|
// Emit text content from ChatDeltas only when the tokenCallback
|
||||||
|
// didn't already stream it (autoparser clears raw text, so
|
||||||
|
// accumulatedContent will be empty in that case).
|
||||||
|
if deltaContent != "" && !inToolCall && accumulatedContent == "" {
|
||||||
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "content_block_delta",
|
||||||
|
Index: intPtr(0),
|
||||||
|
Delta: &schema.AnthropicStreamDelta{
|
||||||
|
Type: "text_delta",
|
||||||
|
Text: deltaContent,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit tool_use blocks from ChatDeltas
|
||||||
|
if len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||||
|
collectedToolCalls = deltaToolCalls
|
||||||
|
|
||||||
|
if !inToolCall && currentBlockIndex == 0 {
|
||||||
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: intPtr(currentBlockIndex),
|
||||||
|
})
|
||||||
|
currentBlockIndex++
|
||||||
|
inToolCall = true
|
||||||
|
}
|
||||||
|
for i, tc := range deltaToolCalls {
|
||||||
|
toolCallID := tc.ID
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = fmt.Sprintf("toolu_%s_%d", id, i)
|
||||||
|
}
|
||||||
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "content_block_start",
|
||||||
|
Index: intPtr(currentBlockIndex),
|
||||||
|
ContentBlock: &schema.AnthropicContentBlock{
|
||||||
|
Type: "tool_use",
|
||||||
|
ID: toolCallID,
|
||||||
|
Name: tc.Name,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "content_block_delta",
|
||||||
|
Index: intPtr(currentBlockIndex),
|
||||||
|
Delta: &schema.AnthropicStreamDelta{
|
||||||
|
Type: "input_json_delta",
|
||||||
|
PartialJSON: tc.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: intPtr(currentBlockIndex),
|
||||||
|
})
|
||||||
|
currentBlockIndex++
|
||||||
|
toolCallsEmitted++
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MCP streaming tool execution: if we collected MCP tool calls, execute and loop
|
// MCP streaming tool execution: if we collected MCP tool calls, execute and loop
|
||||||
@@ -516,7 +609,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
// Close the text content block
|
// Close the text content block
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_stop",
|
Type: "content_block_stop",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
})
|
})
|
||||||
currentBlockIndex++
|
currentBlockIndex++
|
||||||
inToolCall = true
|
inToolCall = true
|
||||||
@@ -528,7 +621,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
}
|
}
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_start",
|
Type: "content_block_start",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
ContentBlock: &schema.AnthropicContentBlock{
|
ContentBlock: &schema.AnthropicContentBlock{
|
||||||
Type: "tool_use",
|
Type: "tool_use",
|
||||||
ID: toolCallID,
|
ID: toolCallID,
|
||||||
@@ -537,7 +630,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
})
|
})
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_delta",
|
Type: "content_block_delta",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
Delta: &schema.AnthropicStreamDelta{
|
Delta: &schema.AnthropicStreamDelta{
|
||||||
Type: "input_json_delta",
|
Type: "input_json_delta",
|
||||||
PartialJSON: fc.Arguments,
|
PartialJSON: fc.Arguments,
|
||||||
@@ -545,7 +638,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
})
|
})
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_stop",
|
Type: "content_block_stop",
|
||||||
Index: currentBlockIndex,
|
Index: intPtr(currentBlockIndex),
|
||||||
})
|
})
|
||||||
currentBlockIndex++
|
currentBlockIndex++
|
||||||
toolCallsEmitted++
|
toolCallsEmitted++
|
||||||
@@ -557,7 +650,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
|||||||
if !inToolCall {
|
if !inToolCall {
|
||||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
Type: "content_block_stop",
|
Type: "content_block_stop",
|
||||||
Index: 0,
|
Index: intPtr(0),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -598,6 +691,8 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
|||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func intPtr(i int) *int { return &i }
|
||||||
|
|
||||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||||
data, err := json.Marshal(event)
|
data, err := json.Marshal(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
||||||
// @Summary Generates audio from the input text.
|
// @Summary Generates audio from the input text.
|
||||||
|
// @Tags audio
|
||||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||||
// @Success 200 {string} binary "Response"
|
// @Success 200 {string} binary "Response"
|
||||||
// @Router /v1/sound-generation [post]
|
// @Router /v1/sound-generation [post]
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||||
// @Summary Generates audio from the input text.
|
// @Summary Generates audio from the input text.
|
||||||
|
// @Tags audio
|
||||||
// @Param voice-id path string true "Account ID"
|
// @Param voice-id path string true "Account ID"
|
||||||
// @Param request body schema.TTSRequest true "query params"
|
// @Param request body schema.TTSRequest true "query params"
|
||||||
// @Success 200 {string} binary "Response"
|
// @Success 200 {string} binary "Response"
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
// JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/)
|
// JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/)
|
||||||
// @Summary Reranks a list of phrases by relevance to a given text query.
|
// @Summary Reranks a list of phrases by relevance to a given text query.
|
||||||
|
// @Tags rerank
|
||||||
// @Param request body schema.JINARerankRequest true "query params"
|
// @Param request body schema.JINARerankRequest true "query params"
|
||||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||||
// @Router /v1/rerank [post]
|
// @Router /v1/rerank [post]
|
||||||
|
|||||||
@@ -30,6 +30,15 @@ func getJobService(app *application.Application, c echo.Context) *agentpool.Agen
|
|||||||
return jobSvc
|
return jobSvc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateTaskEndpoint creates a new agent task definition.
|
||||||
|
// @Summary Create a new agent task
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body schema.Task true "Task definition"
|
||||||
|
// @Success 201 {object} map[string]string "id"
|
||||||
|
// @Failure 400 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/tasks [post]
|
||||||
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var task schema.Task
|
var task schema.Task
|
||||||
@@ -46,6 +55,17 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateTaskEndpoint updates an existing agent task.
|
||||||
|
// @Summary Update an agent task
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Task ID"
|
||||||
|
// @Param request body schema.Task true "Updated task definition"
|
||||||
|
// @Success 200 {object} map[string]string "message"
|
||||||
|
// @Failure 400 {object} map[string]string "error"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/tasks/{id} [put]
|
||||||
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -65,6 +85,14 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteTaskEndpoint deletes an agent task.
|
||||||
|
// @Summary Delete an agent task
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Task ID"
|
||||||
|
// @Success 200 {object} map[string]string "message"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/tasks/{id} [delete]
|
||||||
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -79,6 +107,13 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListTasksEndpoint lists all agent tasks for the current user.
|
||||||
|
// @Summary List agent tasks
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||||
|
// @Success 200 {object} []schema.Task "tasks"
|
||||||
|
// @Router /api/agent/tasks [get]
|
||||||
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
jobSvc := getJobService(app, c)
|
jobSvc := getJobService(app, c)
|
||||||
@@ -121,6 +156,14 @@ func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTaskEndpoint returns a single agent task by ID.
|
||||||
|
// @Summary Get an agent task
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Task ID"
|
||||||
|
// @Success 200 {object} schema.Task "task"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/tasks/{id} [get]
|
||||||
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -133,6 +176,15 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteJobEndpoint creates and runs a new job for a task.
|
||||||
|
// @Summary Execute an agent job
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body schema.JobExecutionRequest true "Job execution request"
|
||||||
|
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||||
|
// @Failure 400 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/jobs/execute [post]
|
||||||
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var req schema.JobExecutionRequest
|
var req schema.JobExecutionRequest
|
||||||
@@ -168,6 +220,14 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetJobEndpoint returns a single job by ID.
|
||||||
|
// @Summary Get an agent job
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Job ID"
|
||||||
|
// @Success 200 {object} schema.Job "job"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/jobs/{id} [get]
|
||||||
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -180,6 +240,16 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListJobsEndpoint lists jobs, optionally filtered by task or status.
|
||||||
|
// @Summary List agent jobs
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id query string false "Filter by task ID"
|
||||||
|
// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
|
||||||
|
// @Param limit query integer false "Max number of jobs to return"
|
||||||
|
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||||
|
// @Success 200 {object} []schema.Job "jobs"
|
||||||
|
// @Router /api/agent/jobs [get]
|
||||||
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var taskID *string
|
var taskID *string
|
||||||
@@ -241,6 +311,15 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelJobEndpoint cancels a running job.
|
||||||
|
// @Summary Cancel an agent job
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Job ID"
|
||||||
|
// @Success 200 {object} map[string]string "message"
|
||||||
|
// @Failure 400 {object} map[string]string "error"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/jobs/{id}/cancel [post]
|
||||||
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -255,6 +334,14 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteJobEndpoint deletes a job by ID.
|
||||||
|
// @Summary Delete an agent job
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Produce json
|
||||||
|
// @Param id path string true "Job ID"
|
||||||
|
// @Success 200 {object} map[string]string "message"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/jobs/{id} [delete]
|
||||||
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -269,6 +356,17 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteTaskByNameEndpoint looks up a task by name and executes it.
|
||||||
|
// @Summary Execute an agent task by name
|
||||||
|
// @Tags agent-jobs
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param name path string true "Task name"
|
||||||
|
// @Param parameters body object false "Optional template parameters"
|
||||||
|
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||||
|
// @Failure 400 {object} map[string]string "error"
|
||||||
|
// @Failure 404 {object} map[string]string "error"
|
||||||
|
// @Router /api/agent/tasks/{name}/execute [post]
|
||||||
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
name := c.Param("name")
|
name := c.Param("name")
|
||||||
|
|||||||
489
core/http/endpoints/localai/api_instructions.go
Normal file
489
core/http/endpoints/localai/api_instructions.go
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/swagger"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
const swaggerDefsPrefix = "#/definitions/"
|
||||||
|
|
||||||
|
// instructionDef is a lightweight instruction definition that maps to swagger tags.
|
||||||
|
type instructionDef struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Intro string `json:"-"` // brief context not in swagger
|
||||||
|
}
|
||||||
|
|
||||||
|
var instructionDefs = []instructionDef{
|
||||||
|
{
|
||||||
|
Name: "chat-inference",
|
||||||
|
Description: "OpenAI-compatible chat completions, text completions, and embeddings",
|
||||||
|
Tags: []string{"inference", "embeddings"},
|
||||||
|
Intro: "Set \"stream\": true for SSE streaming. Supports tool/function calling when the model config has function templates configured.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "audio",
|
||||||
|
Description: "Text-to-speech, voice activity detection, transcription, and sound generation",
|
||||||
|
Tags: []string{"audio"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "images",
|
||||||
|
Description: "Image generation and inpainting",
|
||||||
|
Tags: []string{"images"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "model-management",
|
||||||
|
Description: "Browse the gallery, install, delete, and manage models and backends",
|
||||||
|
Tags: []string{"models", "backends"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "config-management",
|
||||||
|
Description: "Discover, read, and modify model configuration fields with VRAM estimation",
|
||||||
|
Tags: []string{"config"},
|
||||||
|
Intro: "Fields with static options include an \"options\" array in metadata. Fields with dynamic values have an \"autocomplete_provider\" for runtime lookup.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "monitoring",
|
||||||
|
Description: "System metrics, backend status, API and backend traces, backend process logs, and system information",
|
||||||
|
Tags: []string{"monitoring"},
|
||||||
|
Intro: "Includes real-time backend log streaming via WebSocket at /ws/backend-logs/:modelId.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "mcp",
|
||||||
|
Description: "Model Context Protocol — tool-augmented chat with MCP servers",
|
||||||
|
Tags: []string{"mcp"},
|
||||||
|
Intro: "The model's config must define MCP servers. The endpoint handles tool execution automatically.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "agents",
|
||||||
|
Description: "Agent task and job management for CI/automation workflows",
|
||||||
|
Tags: []string{"agent-jobs"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "video",
|
||||||
|
Description: "Video generation from text prompts",
|
||||||
|
Tags: []string{"video"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// swaggerState holds parsed swagger spec data, initialised once.
|
||||||
|
type swaggerState struct {
|
||||||
|
once sync.Once
|
||||||
|
spec map[string]any // full parsed swagger JSON
|
||||||
|
ready bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var swState swaggerState
|
||||||
|
|
||||||
|
func (s *swaggerState) init() {
|
||||||
|
s.once.Do(func() {
|
||||||
|
var spec map[string]any
|
||||||
|
if err := json.Unmarshal(swagger.SwaggerJSON, &spec); err != nil {
|
||||||
|
xlog.Error("failed to parse embedded swagger spec", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.spec = spec
|
||||||
|
s.ready = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterSwaggerByTags returns a swagger fragment containing only paths whose
|
||||||
|
// operations carry at least one of the given tags, plus the definitions they
|
||||||
|
// reference.
|
||||||
|
func filterSwaggerByTags(spec map[string]any, tags []string) map[string]any {
|
||||||
|
tagSet := make(map[string]bool, len(tags))
|
||||||
|
for _, t := range tags {
|
||||||
|
tagSet[t] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
paths, _ := spec["paths"].(map[string]any)
|
||||||
|
allDefs, _ := spec["definitions"].(map[string]any)
|
||||||
|
|
||||||
|
filteredPaths := make(map[string]any)
|
||||||
|
for path, methods := range paths {
|
||||||
|
methodMap, ok := methods.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredMethods := make(map[string]any)
|
||||||
|
for method, opRaw := range methodMap {
|
||||||
|
op, ok := opRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opTags, _ := op["tags"].([]any)
|
||||||
|
for _, t := range opTags {
|
||||||
|
if ts, ok := t.(string); ok && tagSet[ts] {
|
||||||
|
filteredMethods[method] = op
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filteredMethods) > 0 {
|
||||||
|
filteredPaths[path] = filteredMethods
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all $ref definitions used by the filtered paths.
|
||||||
|
neededDefs := make(map[string]bool)
|
||||||
|
collectRefs(filteredPaths, neededDefs)
|
||||||
|
|
||||||
|
// Resolve nested refs from definitions themselves.
|
||||||
|
changed := true
|
||||||
|
for changed {
|
||||||
|
changed = false
|
||||||
|
for name := range neededDefs {
|
||||||
|
if def, ok := allDefs[name]; ok {
|
||||||
|
before := len(neededDefs)
|
||||||
|
collectRefs(def, neededDefs)
|
||||||
|
if len(neededDefs) > before {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredDefs := make(map[string]any)
|
||||||
|
for name := range neededDefs {
|
||||||
|
if def, ok := allDefs[name]; ok {
|
||||||
|
filteredDefs[name] = def
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{
|
||||||
|
"paths": filteredPaths,
|
||||||
|
}
|
||||||
|
if len(filteredDefs) > 0 {
|
||||||
|
result["definitions"] = filteredDefs
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectRefs walks a JSON structure and collects all $ref definition names.
|
||||||
|
func collectRefs(v any, refs map[string]bool) {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if ref, ok := val["$ref"].(string); ok {
|
||||||
|
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||||
|
refs[ref[len(swaggerDefsPrefix):]] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, child := range val {
|
||||||
|
collectRefs(child, refs)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, child := range val {
|
||||||
|
collectRefs(child, refs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// swaggerToMarkdown renders a filtered swagger fragment into concise markdown.
|
||||||
|
func swaggerToMarkdown(skillName, intro string, fragment map[string]any) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("# ")
|
||||||
|
b.WriteString(skillName)
|
||||||
|
b.WriteString("\n")
|
||||||
|
if intro != "" {
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(intro)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
paths, _ := fragment["paths"].(map[string]any)
|
||||||
|
defs, _ := fragment["definitions"].(map[string]any)
|
||||||
|
|
||||||
|
// Sort paths for stable output.
|
||||||
|
sortedPaths := make([]string, 0, len(paths))
|
||||||
|
for p := range paths {
|
||||||
|
sortedPaths = append(sortedPaths, p)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedPaths)
|
||||||
|
|
||||||
|
for _, path := range sortedPaths {
|
||||||
|
methods, ok := paths[path].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sortedMethods := sortMethods(methods)
|
||||||
|
for _, method := range sortedMethods {
|
||||||
|
op, ok := methods[method].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
summary, _ := op["summary"].(string)
|
||||||
|
b.WriteString(fmt.Sprintf("\n## %s %s\n", strings.ToUpper(method), path))
|
||||||
|
if summary != "" {
|
||||||
|
b.WriteString(summary)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameters
|
||||||
|
params, _ := op["parameters"].([]any)
|
||||||
|
bodyParams, nonBodyParams := splitParams(params)
|
||||||
|
|
||||||
|
if len(nonBodyParams) > 0 {
|
||||||
|
b.WriteString("\n**Parameters:**\n")
|
||||||
|
b.WriteString("| Name | In | Type | Required | Description |\n")
|
||||||
|
b.WriteString("|------|----|------|----------|-------------|\n")
|
||||||
|
for _, p := range nonBodyParams {
|
||||||
|
pm, ok := p.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, _ := pm["name"].(string)
|
||||||
|
in, _ := pm["in"].(string)
|
||||||
|
typ, _ := pm["type"].(string)
|
||||||
|
req, _ := pm["required"].(bool)
|
||||||
|
desc, _ := pm["description"].(string)
|
||||||
|
b.WriteString(fmt.Sprintf("| %s | %s | %s | %v | %s |\n", name, in, typ, req, desc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bodyParams) > 0 {
|
||||||
|
for _, p := range bodyParams {
|
||||||
|
pm, ok := p.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
schema, _ := pm["schema"].(map[string]any)
|
||||||
|
refName := resolveRefName(schema)
|
||||||
|
if refName != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("\n**Request body** (`%s`):\n", refName))
|
||||||
|
renderSchemaFields(&b, refName, defs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Responses
|
||||||
|
responses, _ := op["responses"].(map[string]any)
|
||||||
|
if len(responses) > 0 {
|
||||||
|
sortedCodes := make([]string, 0, len(responses))
|
||||||
|
for code := range responses {
|
||||||
|
sortedCodes = append(sortedCodes, code)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedCodes)
|
||||||
|
for _, code := range sortedCodes {
|
||||||
|
resp, ok := responses[code].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
desc, _ := resp["description"].(string)
|
||||||
|
respSchema, _ := resp["schema"].(map[string]any)
|
||||||
|
refName := resolveRefName(respSchema)
|
||||||
|
if refName != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("\n**Response %s** (`%s`): %s\n", code, refName, desc))
|
||||||
|
renderSchemaFields(&b, refName, defs)
|
||||||
|
} else if desc != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("\n**Response %s**: %s\n", code, desc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortMethods returns HTTP methods in a conventional order.
|
||||||
|
func sortMethods(methods map[string]any) []string {
|
||||||
|
order := map[string]int{"get": 0, "post": 1, "put": 2, "patch": 3, "delete": 4}
|
||||||
|
keys := make([]string, 0, len(methods))
|
||||||
|
for k := range methods {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Slice(keys, func(i, j int) bool {
|
||||||
|
oi, oki := order[keys[i]]
|
||||||
|
oj, okj := order[keys[j]]
|
||||||
|
if !oki {
|
||||||
|
oi = 99
|
||||||
|
}
|
||||||
|
if !okj {
|
||||||
|
oj = 99
|
||||||
|
}
|
||||||
|
return oi < oj
|
||||||
|
})
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitParams separates body parameters from non-body parameters.
|
||||||
|
func splitParams(params []any) (body, nonBody []any) {
|
||||||
|
for _, p := range params {
|
||||||
|
pm, ok := p.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if in, _ := pm["in"].(string); in == "body" {
|
||||||
|
body = append(body, p)
|
||||||
|
} else {
|
||||||
|
nonBody = append(nonBody, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRefName extracts the definition name from a $ref or returns "".
|
||||||
|
func resolveRefName(schema map[string]any) string {
|
||||||
|
if schema == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if ref, ok := schema["$ref"].(string); ok {
|
||||||
|
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||||
|
return ref[len(swaggerDefsPrefix):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderSchemaFields writes a markdown field table for a definition.
|
||||||
|
func renderSchemaFields(b *strings.Builder, defName string, defs map[string]any) {
|
||||||
|
if defs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
def, ok := defs[defName].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
props, ok := def["properties"].(map[string]any)
|
||||||
|
if !ok || len(props) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort fields
|
||||||
|
fields := make([]string, 0, len(props))
|
||||||
|
for f := range props {
|
||||||
|
fields = append(fields, f)
|
||||||
|
}
|
||||||
|
sort.Strings(fields)
|
||||||
|
|
||||||
|
b.WriteString("| Field | Type | Description |\n")
|
||||||
|
b.WriteString("|-------|------|-------------|\n")
|
||||||
|
for _, field := range fields {
|
||||||
|
prop, ok := props[field].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
typ := schemaTypeString(prop)
|
||||||
|
desc, _ := prop["description"].(string)
|
||||||
|
b.WriteString(fmt.Sprintf("| %s | %s | %s |\n", field, typ, desc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// schemaTypeString returns a human-readable type string for a schema property.
|
||||||
|
func schemaTypeString(prop map[string]any) string {
|
||||||
|
if ref := resolveRefName(prop); ref != "" {
|
||||||
|
return ref
|
||||||
|
}
|
||||||
|
typ, _ := prop["type"].(string)
|
||||||
|
if typ == "array" {
|
||||||
|
items, _ := prop["items"].(map[string]any)
|
||||||
|
if items != nil {
|
||||||
|
if ref := resolveRefName(items); ref != "" {
|
||||||
|
return "[]" + ref
|
||||||
|
}
|
||||||
|
it, _ := items["type"].(string)
|
||||||
|
if it != "" {
|
||||||
|
return "[]" + it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "[]any"
|
||||||
|
}
|
||||||
|
if typ != "" {
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
return "object"
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIInstructionResponse is the JSON response for a single instruction (?format=json).
|
||||||
|
type APIInstructionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
SwaggerFragment map[string]any `json:"swagger_fragment,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAPIInstructionsEndpoint returns all instructions (compact list without guides).
|
||||||
|
// @Summary List available API instruction areas
|
||||||
|
// @Description Returns a compact list of instruction areas with descriptions and URLs for detailed guides
|
||||||
|
// @Tags instructions
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} map[string]any "instructions list with hint"
|
||||||
|
// @Router /api/instructions [get]
|
||||||
|
func ListAPIInstructionsEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
type compactInstruction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
instructions := make([]compactInstruction, len(instructionDefs))
|
||||||
|
for i, s := range instructionDefs {
|
||||||
|
instructions[i] = compactInstruction{
|
||||||
|
Name: s.Name,
|
||||||
|
Description: s.Description,
|
||||||
|
Tags: s.Tags,
|
||||||
|
URL: "/api/instructions/" + s.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
|
"instructions": instructions,
|
||||||
|
"hint": "Fetch GET {url} for a markdown API guide. Add ?format=json for a raw OpenAPI fragment.",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAPIInstructionEndpoint returns a single instruction by name.
|
||||||
|
// @Summary Get an instruction's API guide or OpenAPI fragment
|
||||||
|
// @Description Returns a markdown guide (default) or filtered OpenAPI fragment (format=json) for a named instruction
|
||||||
|
// @Tags instructions
|
||||||
|
// @Produce json
|
||||||
|
// @Produce text/markdown
|
||||||
|
// @Param name path string true "Instruction name (e.g. chat-inference, config-management)"
|
||||||
|
// @Param format query string false "Response format: json for OpenAPI fragment, omit for markdown"
|
||||||
|
// @Success 200 {object} APIInstructionResponse "instruction documentation"
|
||||||
|
// @Failure 404 {object} map[string]string "instruction not found"
|
||||||
|
// @Router /api/instructions/{name} [get]
|
||||||
|
func GetAPIInstructionEndpoint() echo.HandlerFunc {
|
||||||
|
byName := make(map[string]*instructionDef, len(instructionDefs))
|
||||||
|
for i := range instructionDefs {
|
||||||
|
byName[instructionDefs[i].Name] = &instructionDefs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
name := c.Param("name")
|
||||||
|
inst, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]any{"error": "instruction not found: " + name})
|
||||||
|
}
|
||||||
|
|
||||||
|
swState.init()
|
||||||
|
if !swState.ready {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "swagger spec not available"})
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment := filterSwaggerByTags(swState.spec, inst.Tags)
|
||||||
|
|
||||||
|
format := c.QueryParam("format")
|
||||||
|
if format == "json" {
|
||||||
|
return c.JSON(http.StatusOK, APIInstructionResponse{
|
||||||
|
Name: inst.Name,
|
||||||
|
Description: inst.Description,
|
||||||
|
Tags: inst.Tags,
|
||||||
|
SwaggerFragment: fragment,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
guide := swaggerToMarkdown(inst.Name, inst.Intro, fragment)
|
||||||
|
return c.Blob(http.StatusOK, "text/markdown; charset=utf-8", []byte(guide))
|
||||||
|
}
|
||||||
|
}
|
||||||
222
core/http/endpoints/localai/api_instructions_test.go
Normal file
222
core/http/endpoints/localai/api_instructions_test.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package localai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("API Instructions Endpoints", func() {
|
||||||
|
var app *echo.Echo
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
app = echo.New()
|
||||||
|
app.GET("/api/instructions", ListAPIInstructionsEndpoint())
|
||||||
|
app.GET("/api/instructions/:name", GetAPIInstructionEndpoint())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GET /api/instructions", func() {
|
||||||
|
It("should return all instruction definitions", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(resp).To(HaveKey("hint"))
|
||||||
|
Expect(resp).To(HaveKey("instructions"))
|
||||||
|
|
||||||
|
instructions, ok := resp["instructions"].([]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(instructions).To(HaveLen(9))
|
||||||
|
|
||||||
|
// Verify each instruction has required fields and correct URL format
|
||||||
|
for _, s := range instructions {
|
||||||
|
inst, ok := s.(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(inst["name"]).NotTo(BeEmpty())
|
||||||
|
Expect(inst["description"]).NotTo(BeEmpty())
|
||||||
|
Expect(inst["tags"]).NotTo(BeNil())
|
||||||
|
Expect(inst["url"]).To(HavePrefix("/api/instructions/"))
|
||||||
|
Expect(inst["url"]).To(Equal("/api/instructions/" + inst["name"].(string)))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should include known instruction names", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
|
||||||
|
instructions := resp["instructions"].([]any)
|
||||||
|
names := make([]string, len(instructions))
|
||||||
|
for i, s := range instructions {
|
||||||
|
names[i] = s.(map[string]any)["name"].(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(names).To(ContainElements(
|
||||||
|
"chat-inference",
|
||||||
|
"config-management",
|
||||||
|
"model-management",
|
||||||
|
"monitoring",
|
||||||
|
"agents",
|
||||||
|
))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GET /api/instructions/:name", func() {
|
||||||
|
It("should return 404 for unknown instruction", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/nonexistent", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp["error"]).To(ContainSubstring("instruction not found"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return markdown by default", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(rec.Header().Get("Content-Type")).To(ContainSubstring("text/markdown"))
|
||||||
|
|
||||||
|
body, err := io.ReadAll(rec.Body)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
md := string(body)
|
||||||
|
|
||||||
|
Expect(md).To(HavePrefix("# chat-inference"))
|
||||||
|
// Should contain at least one endpoint heading
|
||||||
|
Expect(md).To(MatchRegexp(`## (GET|POST|PUT|PATCH|DELETE) /`))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should include intro text for instructions that have one", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(rec.Body)
|
||||||
|
// chat-inference has an intro about streaming
|
||||||
|
Expect(string(body)).To(ContainSubstring("stream"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return JSON fragment when format=json", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp["name"]).To(Equal("chat-inference"))
|
||||||
|
Expect(resp["tags"]).To(ContainElements("inference", "embeddings"))
|
||||||
|
|
||||||
|
fragment, ok := resp["swagger_fragment"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(fragment).To(HaveKey("paths"))
|
||||||
|
|
||||||
|
paths, ok := fragment["paths"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(paths).NotTo(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should include referenced definitions in JSON fragment", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
|
||||||
|
fragment := resp["swagger_fragment"].(map[string]any)
|
||||||
|
Expect(fragment).To(HaveKey("definitions"))
|
||||||
|
|
||||||
|
defs, ok := fragment["definitions"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(defs).NotTo(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should only include paths matching the instruction tags in JSON fragment", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/config-management?format=json", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
|
||||||
|
fragment := resp["swagger_fragment"].(map[string]any)
|
||||||
|
paths := fragment["paths"].(map[string]any)
|
||||||
|
Expect(paths).NotTo(BeEmpty())
|
||||||
|
|
||||||
|
// Every operation in every path should have the "config" tag
|
||||||
|
for _, methods := range paths {
|
||||||
|
methodMap := methods.(map[string]any)
|
||||||
|
for _, opRaw := range methodMap {
|
||||||
|
op := opRaw.(map[string]any)
|
||||||
|
tags, _ := op["tags"].([]any)
|
||||||
|
tagStrs := make([]string, len(tags))
|
||||||
|
for i, t := range tags {
|
||||||
|
tagStrs[i] = t.(string)
|
||||||
|
}
|
||||||
|
Expect(tagStrs).To(ContainElement("config"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should produce stable output across calls", func() {
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec1, req1)
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec2, req2)
|
||||||
|
|
||||||
|
body1, _ := io.ReadAll(rec1.Body)
|
||||||
|
body2, _ := io.ReadAll(rec2.Body)
|
||||||
|
Expect(string(body1)).To(Equal(string(body2)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return markdown for every defined instruction", func() {
|
||||||
|
// First get the list
|
||||||
|
listReq := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||||
|
listRec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(listRec, listReq)
|
||||||
|
|
||||||
|
var listResp map[string]any
|
||||||
|
Expect(json.Unmarshal(listRec.Body.Bytes(), &listResp)).To(Succeed())
|
||||||
|
|
||||||
|
instructions := listResp["instructions"].([]any)
|
||||||
|
for _, s := range instructions {
|
||||||
|
name := s.(map[string]any)["name"].(string)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/instructions/"+name, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||||
|
"instruction %q should return 200", name)
|
||||||
|
body, _ := io.ReadAll(rec.Body)
|
||||||
|
Expect(strings.TrimSpace(string(body))).NotTo(BeEmpty(),
|
||||||
|
"instruction %q should return non-empty markdown", name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -15,28 +15,37 @@ import (
|
|||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UpgradeInfoProvider is an interface for querying cached backend upgrade information.
|
||||||
|
type UpgradeInfoProvider interface {
|
||||||
|
GetAvailableUpgrades() map[string]gallery.UpgradeInfo
|
||||||
|
TriggerCheck()
|
||||||
|
}
|
||||||
|
|
||||||
type BackendEndpointService struct {
|
type BackendEndpointService struct {
|
||||||
galleries []config.Gallery
|
galleries []config.Gallery
|
||||||
backendPath string
|
backendPath string
|
||||||
backendSystemPath string
|
backendSystemPath string
|
||||||
backendApplier *galleryop.GalleryService
|
backendApplier *galleryop.GalleryService
|
||||||
|
upgradeChecker UpgradeInfoProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
type GalleryBackend struct {
|
type GalleryBackend struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService {
|
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService, upgradeChecker UpgradeInfoProvider) BackendEndpointService {
|
||||||
return BackendEndpointService{
|
return BackendEndpointService{
|
||||||
galleries: galleries,
|
galleries: galleries,
|
||||||
backendPath: systemState.Backend.BackendsPath,
|
backendPath: systemState.Backend.BackendsPath,
|
||||||
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
||||||
backendApplier: backendApplier,
|
backendApplier: backendApplier,
|
||||||
|
upgradeChecker: upgradeChecker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOpStatusEndpoint returns the job status
|
// GetOpStatusEndpoint returns the job status
|
||||||
// @Summary Returns the job status
|
// @Summary Returns the job status
|
||||||
|
// @Tags backends
|
||||||
// @Success 200 {object} galleryop.OpStatus "Response"
|
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||||
// @Router /backends/jobs/{uuid} [get]
|
// @Router /backends/jobs/{uuid} [get]
|
||||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||||
@@ -51,6 +60,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// GetAllStatusEndpoint returns all the jobs status progress
|
// GetAllStatusEndpoint returns all the jobs status progress
|
||||||
// @Summary Returns all the jobs status progress
|
// @Summary Returns all the jobs status progress
|
||||||
|
// @Tags backends
|
||||||
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||||
// @Router /backends/jobs [get]
|
// @Router /backends/jobs [get]
|
||||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||||
@@ -61,6 +71,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||||
// @Summary Install backends to LocalAI.
|
// @Summary Install backends to LocalAI.
|
||||||
|
// @Tags backends
|
||||||
// @Param request body GalleryBackend true "query params"
|
// @Param request body GalleryBackend true "query params"
|
||||||
// @Success 200 {object} schema.BackendResponse "Response"
|
// @Success 200 {object} schema.BackendResponse "Response"
|
||||||
// @Router /backends/apply [post]
|
// @Router /backends/apply [post]
|
||||||
@@ -88,6 +99,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// DeleteBackendEndpoint lets delete backends from a LocalAI instance
|
// DeleteBackendEndpoint lets delete backends from a LocalAI instance
|
||||||
// @Summary delete backends from LocalAI.
|
// @Summary delete backends from LocalAI.
|
||||||
|
// @Tags backends
|
||||||
// @Param name path string true "Backend name"
|
// @Param name path string true "Backend name"
|
||||||
// @Success 200 {object} schema.BackendResponse "Response"
|
// @Success 200 {object} schema.BackendResponse "Response"
|
||||||
// @Router /backends/delete/{name} [post]
|
// @Router /backends/delete/{name} [post]
|
||||||
@@ -112,6 +124,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// ListBackendsEndpoint list the available backends configured in LocalAI
|
// ListBackendsEndpoint list the available backends configured in LocalAI
|
||||||
// @Summary List all Backends
|
// @Summary List all Backends
|
||||||
|
// @Tags backends
|
||||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||||
// @Router /backends [get]
|
// @Router /backends [get]
|
||||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||||
@@ -126,6 +139,7 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||||
// @Summary List all Galleries
|
// @Summary List all Galleries
|
||||||
|
// @Tags backends
|
||||||
// @Success 200 {object} []config.Gallery "Response"
|
// @Success 200 {object} []config.Gallery "Response"
|
||||||
// @Router /backends/galleries [get]
|
// @Router /backends/galleries [get]
|
||||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||||
@@ -140,8 +154,65 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUpgradesEndpoint returns the cached backend upgrade information
|
||||||
|
// @Summary Get available backend upgrades
|
||||||
|
// @Tags backends
|
||||||
|
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||||
|
// @Router /backends/upgrades [get]
|
||||||
|
func (mgs *BackendEndpointService) GetUpgradesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if mgs.upgradeChecker == nil {
|
||||||
|
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||||
|
}
|
||||||
|
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUpgradesEndpoint forces an immediate upgrade check
|
||||||
|
// @Summary Force backend upgrade check
|
||||||
|
// @Tags backends
|
||||||
|
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||||
|
// @Router /backends/upgrades/check [post]
|
||||||
|
func (mgs *BackendEndpointService) CheckUpgradesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if mgs.upgradeChecker == nil {
|
||||||
|
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||||
|
}
|
||||||
|
mgs.upgradeChecker.TriggerCheck()
|
||||||
|
// Return current cached results (the triggered check runs async)
|
||||||
|
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpgradeBackendEndpoint triggers an upgrade for a specific backend
|
||||||
|
// @Summary Upgrade a backend
|
||||||
|
// @Tags backends
|
||||||
|
// @Param name path string true "Backend name"
|
||||||
|
// @Success 200 {object} schema.BackendResponse "Response"
|
||||||
|
// @Router /backends/upgrade/{name} [post]
|
||||||
|
func (mgs *BackendEndpointService) UpgradeBackendEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
backendName := c.Param("name")
|
||||||
|
|
||||||
|
uuid, err := uuid.NewUUID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||||
|
ID: uuid.String(),
|
||||||
|
GalleryElementName: backendName,
|
||||||
|
Galleries: mgs.galleries,
|
||||||
|
Upgrade: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI
|
// ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI
|
||||||
// @Summary List all available Backends
|
// @Summary List all available Backends
|
||||||
|
// @Tags backends
|
||||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||||
// @Router /backends/available [get]
|
// @Router /backends/available [get]
|
||||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||||
|
|||||||
179
core/http/endpoints/localai/backend_logs.go
Normal file
179
core/http/endpoints/localai/backend_logs.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var backendLogsUpgrader = websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return true // no origin header = same-origin or non-browser
|
||||||
|
}
|
||||||
|
u, err := url.Parse(origin)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return u.Host == r.Host
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// backendLogsConn wraps a websocket connection with a mutex for safe concurrent writes
|
||||||
|
type backendLogsConn struct {
|
||||||
|
*websocket.Conn
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *backendLogsConn) writeJSON(v any) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal error: %w", err)
|
||||||
|
}
|
||||||
|
return c.Conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *backendLogsConn) writePing() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
return c.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListBackendLogsEndpoint returns model IDs that have log buffers
|
||||||
|
// @Summary List models with backend logs
|
||||||
|
// @Description Returns a sorted list of model IDs that have captured backend process output
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {array} string "Model IDs with logs"
|
||||||
|
// @Router /api/backend-logs [get]
|
||||||
|
func ListBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
return c.JSON(200, ml.BackendLogs().ListModels())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBackendLogsEndpoint returns log lines for a specific model
|
||||||
|
// @Summary Get backend logs for a model
|
||||||
|
// @Description Returns all captured log lines (stdout/stderr) for the specified model's backend process
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Produce json
|
||||||
|
// @Param modelId path string true "Model ID"
|
||||||
|
// @Success 200 {array} model.BackendLogLine "Log lines"
|
||||||
|
// @Router /api/backend-logs/{modelId} [get]
|
||||||
|
func GetBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
modelID := c.Param("modelId")
|
||||||
|
return c.JSON(200, ml.BackendLogs().GetLines(modelID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBackendLogsEndpoint clears log lines for a specific model
|
||||||
|
// @Summary Clear backend logs for a model
|
||||||
|
// @Description Removes all captured log lines for the specified model's backend process
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Param modelId path string true "Model ID"
|
||||||
|
// @Success 204 "Logs cleared"
|
||||||
|
// @Router /api/backend-logs/{modelId}/clear [post]
|
||||||
|
func ClearBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
ml.BackendLogs().Clear(c.Param("modelId"))
|
||||||
|
return c.NoContent(204)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackendLogsWebSocketEndpoint streams backend logs in real-time over WebSocket
|
||||||
|
// @Summary Stream backend logs via WebSocket
|
||||||
|
// @Description Opens a WebSocket connection for real-time backend log streaming. Sends an initial batch of existing lines (type "initial"), then streams new lines as they appear (type "line"). Supports ping/pong keepalive.
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Param modelId path string true "Model ID"
|
||||||
|
// @Router /ws/backend-logs/{modelId} [get]
|
||||||
|
func BackendLogsWebSocketEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
modelID := c.Param("modelId")
|
||||||
|
|
||||||
|
ws, err := backendLogsUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
ws.SetReadLimit(4096)
|
||||||
|
|
||||||
|
// Set up ping/pong for keepalive
|
||||||
|
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||||
|
ws.SetPongHandler(func(string) error {
|
||||||
|
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
conn := &backendLogsConn{Conn: ws}
|
||||||
|
|
||||||
|
// Send existing lines as initial batch
|
||||||
|
existingLines := ml.BackendLogs().GetLines(modelID)
|
||||||
|
initialMsg := map[string]any{
|
||||||
|
"type": "initial",
|
||||||
|
"lines": existingLines,
|
||||||
|
}
|
||||||
|
if err := conn.writeJSON(initialMsg); err != nil {
|
||||||
|
xlog.Debug("WebSocket backend-logs initial write failed", "error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to new lines
|
||||||
|
lineCh, unsubscribe := ml.BackendLogs().Subscribe(modelID)
|
||||||
|
defer unsubscribe()
|
||||||
|
|
||||||
|
// Handle close from client side
|
||||||
|
closeCh := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, _, err := ws.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
close(closeCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Ping ticker for keepalive
|
||||||
|
pingTicker := time.NewTicker(30 * time.Second)
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
|
||||||
|
// Forward new lines to WebSocket
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case line, ok := <-lineCh:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lineMsg := map[string]any{
|
||||||
|
"type": "line",
|
||||||
|
"line": line,
|
||||||
|
}
|
||||||
|
if err := conn.writeJSON(lineMsg); err != nil {
|
||||||
|
xlog.Debug("WebSocket backend-logs write error", "error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case <-pingTicker.C:
|
||||||
|
if err := conn.writePing(); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case <-closeCh:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
196
core/http/endpoints/localai/backend_logs_test.go
Normal file
196
core/http/endpoints/localai/backend_logs_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
package localai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Backend Logs Endpoints", func() {
|
||||||
|
var (
|
||||||
|
app *echo.Echo
|
||||||
|
tempDir string
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tempDir, err = os.MkdirTemp("", "backend-logs-test-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
modelsPath := filepath.Join(tempDir, "models")
|
||||||
|
Expect(os.MkdirAll(modelsPath, 0750)).To(Succeed())
|
||||||
|
|
||||||
|
systemState, err := system.GetSystemState(
|
||||||
|
system.WithModelPath(modelsPath),
|
||||||
|
)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
modelLoader = model.NewModelLoader(systemState)
|
||||||
|
|
||||||
|
app = echo.New()
|
||||||
|
app.GET("/api/backend-logs", ListBackendLogsEndpoint(modelLoader))
|
||||||
|
app.GET("/api/backend-logs/:modelId", GetBackendLogsEndpoint(modelLoader))
|
||||||
|
app.POST("/api/backend-logs/:modelId/clear", ClearBackendLogsEndpoint(modelLoader))
|
||||||
|
app.GET("/ws/backend-logs/:modelId", BackendLogsWebSocketEndpoint(modelLoader))
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("REST endpoints", func() {
|
||||||
|
It("should return empty list of models with logs", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
|
||||||
|
Expect(models).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should list models that have logs", func() {
|
||||||
|
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
|
||||||
|
Expect(models).To(ContainElement("my-model"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return log lines for a model", func() {
|
||||||
|
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "line one")
|
||||||
|
modelLoader.BackendLogs().AppendLine("my-model", "stderr", "line two")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var lines []model.BackendLogLine
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
|
||||||
|
Expect(lines).To(HaveLen(2))
|
||||||
|
Expect(lines[0].Text).To(Equal("line one"))
|
||||||
|
Expect(lines[0].Stream).To(Equal("stdout"))
|
||||||
|
Expect(lines[1].Text).To(Equal("line two"))
|
||||||
|
Expect(lines[1].Stream).To(Equal("stderr"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return empty log lines for unknown model", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/unknown-model", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should clear logs for a model", func() {
|
||||||
|
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/backend-logs/my-model/clear", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||||
|
|
||||||
|
// Verify logs are cleared
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
var lines []model.BackendLogLine
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
|
||||||
|
Expect(lines).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("WebSocket endpoint", func() {
|
||||||
|
It("should send initial lines and stream new lines", func() {
|
||||||
|
// Seed some existing lines before connecting
|
||||||
|
modelLoader.BackendLogs().AppendLine("ws-model", "stdout", "existing line")
|
||||||
|
|
||||||
|
// Start a real HTTP server for WebSocket
|
||||||
|
srv := httptest.NewServer(app)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
// Dial the WebSocket
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/ws-model"
|
||||||
|
dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
|
||||||
|
conn, _, err := dialer.Dial(wsURL, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Read the initial message
|
||||||
|
var initialMsg map[string]any
|
||||||
|
err = conn.ReadJSON(&initialMsg)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(initialMsg["type"]).To(Equal("initial"))
|
||||||
|
|
||||||
|
initialLines, ok := initialMsg["lines"].([]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(initialLines).To(HaveLen(1))
|
||||||
|
|
||||||
|
firstLine := initialLines[0].(map[string]any)
|
||||||
|
Expect(firstLine["text"]).To(Equal("existing line"))
|
||||||
|
|
||||||
|
// Now append a new line and verify it streams through
|
||||||
|
modelLoader.BackendLogs().AppendLine("ws-model", "stderr", "streamed line")
|
||||||
|
|
||||||
|
var lineMsg map[string]any
|
||||||
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
err = conn.ReadJSON(&lineMsg)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(lineMsg["type"]).To(Equal("line"))
|
||||||
|
|
||||||
|
lineData, ok := lineMsg["line"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(lineData["text"]).To(Equal("streamed line"))
|
||||||
|
Expect(lineData["stream"]).To(Equal("stderr"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle connection close gracefully", func() {
|
||||||
|
srv := httptest.NewServer(app)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/close-model"
|
||||||
|
dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
|
||||||
|
conn, _, err := dialer.Dial(wsURL, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
// Read initial message
|
||||||
|
var initialMsg map[string]any
|
||||||
|
err = conn.ReadJSON(&initialMsg)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(initialMsg["type"]).To(Equal("initial"))
|
||||||
|
|
||||||
|
// Close the connection from client side
|
||||||
|
conn.Close()
|
||||||
|
|
||||||
|
// Give the server goroutine time to detect the close
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// No panic or hang — the test passing is the assertion
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
// BackendMonitorEndpoint returns the status of the specified backend
|
// BackendMonitorEndpoint returns the status of the specified backend
|
||||||
// @Summary Backend monitor endpoint
|
// @Summary Backend monitor endpoint
|
||||||
|
// @Tags monitoring
|
||||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||||
// @Success 200 {object} proto.StatusResponse "Response"
|
// @Success 200 {object} proto.StatusResponse "Response"
|
||||||
// @Router /backend/monitor [get]
|
// @Router /backend/monitor [get]
|
||||||
@@ -29,7 +30,8 @@ func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BackendShutdownEndpoint shuts down the specified backend
|
// BackendShutdownEndpoint shuts down the specified backend
|
||||||
// @Summary Backend monitor endpoint
|
// @Summary Backend shutdown endpoint
|
||||||
|
// @Tags monitoring
|
||||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||||
// @Router /backend/shutdown [post]
|
// @Router /backend/shutdown [post]
|
||||||
func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
|
func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
|
||||||
|
|||||||
244
core/http/endpoints/localai/config_meta.go
Normal file
244
core/http/endpoints/localai/config_meta.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"dario.cat/mergo"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/config/meta"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigMetadataEndpoint returns field metadata for config fields.
|
||||||
|
// Without ?section, returns just the section index (lightweight).
|
||||||
|
// With ?section=<id>, returns fields for that section only.
|
||||||
|
// With ?section=all, returns all fields grouped by section.
|
||||||
|
// @Summary List model configuration field metadata
|
||||||
|
// @Description Returns config field metadata. Use ?section=<id> to filter by section, or omit for a section index.
|
||||||
|
// @Tags config
|
||||||
|
// @Produce json
|
||||||
|
// @Param section query string false "Section ID to filter (e.g. 'general', 'llm', 'parameters') or 'all' for everything"
|
||||||
|
// @Success 200 {object} map[string]any "Section index or filtered field metadata"
|
||||||
|
// @Router /api/models/config-metadata [get]
|
||||||
|
func ConfigMetadataEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
sectionParam := c.QueryParam("section")
|
||||||
|
|
||||||
|
// No section param: return lightweight section index.
|
||||||
|
if sectionParam == "" {
|
||||||
|
sections := meta.DefaultSections()
|
||||||
|
type sectionInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
index := make([]sectionInfo, len(sections))
|
||||||
|
for i, s := range sections {
|
||||||
|
index[i] = sectionInfo{
|
||||||
|
ID: s.ID,
|
||||||
|
Label: s.Label,
|
||||||
|
URL: "/api/models/config-metadata?section=" + s.ID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
|
"hint": "Fetch a section URL to see its fields. Use ?section=all for everything.",
|
||||||
|
"sections": index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
md := meta.BuildConfigMetadata(reflect.TypeOf(config.ModelConfig{}))
|
||||||
|
|
||||||
|
// section=all: return everything.
|
||||||
|
if sectionParam == "all" {
|
||||||
|
return c.JSON(http.StatusOK, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter to requested section.
|
||||||
|
var filtered []meta.FieldMeta
|
||||||
|
for _, f := range md.Fields {
|
||||||
|
if f.Section == sectionParam {
|
||||||
|
filtered = append(filtered, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown section: " + sectionParam})
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutocompleteEndpoint handles dynamic autocomplete lookups for config fields.
|
||||||
|
// Static option lists (quantizations, cache types, diffusers pipelines/schedulers)
|
||||||
|
// are embedded directly in the field metadata Options; only truly dynamic values
|
||||||
|
// that require runtime lookup are served here.
|
||||||
|
// @Summary Get dynamic autocomplete values for a config field
|
||||||
|
// @Description Returns runtime-resolved values for dynamic providers (backends, models)
|
||||||
|
// @Tags config
|
||||||
|
// @Produce json
|
||||||
|
// @Param provider path string true "Provider name (backends, models, models:chat, models:tts, models:transcript, models:vad)"
|
||||||
|
// @Success 200 {object} map[string]any "values array"
|
||||||
|
// @Router /api/models/config-metadata/autocomplete/{provider} [get]
|
||||||
|
func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
provider := c.Param("provider")
|
||||||
|
var values []string
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case provider == meta.ProviderBackends:
|
||||||
|
installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
|
||||||
|
if err == nil {
|
||||||
|
for name := range installedBackends {
|
||||||
|
values = append(values, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(values)
|
||||||
|
|
||||||
|
case provider == meta.ProviderModels:
|
||||||
|
modelConfigs := cl.GetAllModelsConfigs()
|
||||||
|
for _, cfg := range modelConfigs {
|
||||||
|
values = append(values, cfg.Name)
|
||||||
|
}
|
||||||
|
modelsWithoutConfig, _ := galleryop.ListModels(cl, ml, config.NoFilterFn, galleryop.LOOSE_ONLY)
|
||||||
|
values = append(values, modelsWithoutConfig...)
|
||||||
|
sort.Strings(values)
|
||||||
|
|
||||||
|
case strings.HasPrefix(provider, "models:"):
|
||||||
|
capability := strings.TrimPrefix(provider, "models:")
|
||||||
|
var filterFn config.ModelConfigFilterFn
|
||||||
|
switch capability {
|
||||||
|
case "chat":
|
||||||
|
filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT)
|
||||||
|
case "tts":
|
||||||
|
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS)
|
||||||
|
case "vad":
|
||||||
|
filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD)
|
||||||
|
case "transcript":
|
||||||
|
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
|
||||||
|
default:
|
||||||
|
filterFn = config.NoFilterFn
|
||||||
|
}
|
||||||
|
filteredConfigs := cl.GetModelConfigsByFilter(filterFn)
|
||||||
|
for _, cfg := range filteredConfigs {
|
||||||
|
values = append(values, cfg.Name)
|
||||||
|
}
|
||||||
|
sort.Strings(values)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown provider: " + provider})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]any{"values": values})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatchConfigEndpoint handles PATCH requests to partially update a model config
|
||||||
|
// using nested JSON merge.
|
||||||
|
// @Summary Partially update a model configuration
|
||||||
|
// @Description Deep-merges the JSON patch body into the existing model config
|
||||||
|
// @Tags config
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param name path string true "Model name"
|
||||||
|
// @Success 200 {object} map[string]any "success message"
|
||||||
|
// @Router /api/models/config-json/{name} [patch]
|
||||||
|
func PatchConfigEndpoint(cl *config.ModelConfigLoader, _ *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
modelName := c.Param("name")
|
||||||
|
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||||
|
modelName = decoded
|
||||||
|
}
|
||||||
|
if modelName == "" {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "model name is required"})
|
||||||
|
}
|
||||||
|
|
||||||
|
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||||
|
if !exists {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]any{"error": "model configuration not found"})
|
||||||
|
}
|
||||||
|
|
||||||
|
patchBody, err := io.ReadAll(c.Request().Body)
|
||||||
|
if err != nil || len(patchBody) == 0 {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "request body is empty or unreadable"})
|
||||||
|
}
|
||||||
|
|
||||||
|
var patchMap map[string]any
|
||||||
|
if err := json.Unmarshal(patchBody, &patchMap); err != nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid JSON: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the raw YAML from disk rather than serializing the in-memory config.
|
||||||
|
// The in-memory config has SetDefaults() applied, which would persist
|
||||||
|
// runtime-only defaults (top_p, temperature, mirostat, etc.) to the file.
|
||||||
|
configPath := modelConfig.GetModelConfigFile()
|
||||||
|
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, map[string]any{"error": "config path not trusted: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
diskYAML, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to read config file: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
var existingMap map[string]any
|
||||||
|
if err := yaml.Unmarshal(diskYAML, &existingMap); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to parse existing config: " + err.Error()})
|
||||||
|
}
|
||||||
|
if existingMap == nil {
|
||||||
|
existingMap = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergo.Merge(&existingMap, patchMap, mergo.WithOverride); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to merge configs: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal once and reuse for both validation and writing
|
||||||
|
yamlData, err := yaml.Marshal(existingMap)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal YAML"})
|
||||||
|
}
|
||||||
|
|
||||||
|
var updatedConfig config.ModelConfig
|
||||||
|
if err := yaml.Unmarshal(yamlData, &updatedConfig); err != nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "merged config is invalid: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid, err := updatedConfig.Validate(); !valid {
|
||||||
|
errMsg := "validation failed"
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": errMsg})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to write config file"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to reload configs: " + err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||||
|
xlog.Warn("Failed to preload after PATCH", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
|
"success": true,
|
||||||
|
"message": fmt.Sprintf("Model '%s' updated successfully", modelName),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
292
core/http/endpoints/localai/config_meta_test.go
Normal file
292
core/http/endpoints/localai/config_meta_test.go
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
package localai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Config Metadata Endpoints", func() {
|
||||||
|
var (
|
||||||
|
app *echo.Echo
|
||||||
|
tempDir string
|
||||||
|
configLoader *config.ModelConfigLoader
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
appConfig *config.ApplicationConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tempDir, err = os.MkdirTemp("", "config-meta-test-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
systemState, err := system.GetSystemState(
|
||||||
|
system.WithModelPath(tempDir),
|
||||||
|
)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
appConfig = config.NewApplicationConfig(
|
||||||
|
config.WithSystemState(systemState),
|
||||||
|
)
|
||||||
|
configLoader = config.NewModelConfigLoader(tempDir)
|
||||||
|
modelLoader = model.NewModelLoader(systemState)
|
||||||
|
|
||||||
|
app = echo.New()
|
||||||
|
app.GET("/api/models/config-metadata", ConfigMetadataEndpoint())
|
||||||
|
app.GET("/api/models/config-metadata/autocomplete/:provider", AutocompleteEndpoint(configLoader, modelLoader, appConfig))
|
||||||
|
app.PATCH("/api/models/config-json/:name", PatchConfigEndpoint(configLoader, modelLoader, appConfig))
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GET /api/models/config-metadata", func() {
|
||||||
|
It("should return section index when no section param", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp).To(HaveKey("hint"))
|
||||||
|
Expect(resp).To(HaveKey("sections"))
|
||||||
|
|
||||||
|
sections, ok := resp["sections"].([]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(sections).NotTo(BeEmpty())
|
||||||
|
|
||||||
|
// Verify known section IDs are present
|
||||||
|
ids := make([]string, len(sections))
|
||||||
|
for i, s := range sections {
|
||||||
|
sec := s.(map[string]any)
|
||||||
|
Expect(sec).To(HaveKey("id"))
|
||||||
|
Expect(sec).To(HaveKey("label"))
|
||||||
|
Expect(sec).To(HaveKey("url"))
|
||||||
|
ids[i] = sec["id"].(string)
|
||||||
|
}
|
||||||
|
Expect(ids).To(ContainElements("general", "parameters"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return all fields when section=all", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=all", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp).To(HaveKey("fields"))
|
||||||
|
|
||||||
|
fields, ok := resp["fields"].([]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(len(fields)).To(BeNumerically(">=", 80))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should filter by section", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=general", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var fields []map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &fields)).To(Succeed())
|
||||||
|
Expect(fields).NotTo(BeEmpty())
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
Expect(f["section"]).To(Equal("general"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return 404 for unknown section", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=nonexistent", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GET /api/models/config-metadata/autocomplete/:provider", func() {
|
||||||
|
It("should return values for backends provider", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/backends", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp).To(HaveKey("values"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return model names for models provider", func() {
|
||||||
|
// Seed a model config
|
||||||
|
seedConfig := `name: test-model
|
||||||
|
backend: llama-cpp
|
||||||
|
`
|
||||||
|
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||||
|
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/models", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
|
||||||
|
values, ok := resp["values"].([]any)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(values).To(ContainElement("test-model"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return 404 for unknown provider", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/unknown", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("PATCH /api/models/config-json/:name", func() {
|
||||||
|
It("should return 404 for nonexistent model", func() {
|
||||||
|
body := bytes.NewBufferString(`{"backend": "bar"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/nonexistent", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return 400 for empty body", func() {
|
||||||
|
// Seed a model config
|
||||||
|
seedConfig := `name: test-model
|
||||||
|
backend: llama-cpp
|
||||||
|
`
|
||||||
|
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||||
|
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", nil)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return 400 for invalid JSON", func() {
|
||||||
|
seedConfig := `name: test-model
|
||||||
|
backend: llama-cpp
|
||||||
|
`
|
||||||
|
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||||
|
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`not json`)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should merge a field update and persist to disk", func() {
|
||||||
|
seedConfig := `name: test-model
|
||||||
|
backend: llama-cpp
|
||||||
|
`
|
||||||
|
configPath := filepath.Join(tempDir, "test-model.yaml")
|
||||||
|
Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
|
||||||
|
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"backend": "vllm"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp["success"]).To(BeTrue())
|
||||||
|
|
||||||
|
// Verify the reloaded config has the updated value
|
||||||
|
updatedConfig, exists := configLoader.GetModelConfig("test-model")
|
||||||
|
Expect(exists).To(BeTrue())
|
||||||
|
Expect(updatedConfig.Backend).To(Equal("vllm"))
|
||||||
|
|
||||||
|
// Verify the file on disk was updated
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(string(data)).To(ContainSubstring("vllm"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should not persist runtime defaults (SetDefaults values) to disk", func() {
|
||||||
|
// Create a minimal pipeline config - no sampling params
|
||||||
|
seedConfig := `name: gpt-realtime
|
||||||
|
pipeline:
|
||||||
|
vad: silero-vad
|
||||||
|
transcription: whisper-base
|
||||||
|
llm: llama3
|
||||||
|
tts: piper
|
||||||
|
`
|
||||||
|
configPath := filepath.Join(tempDir, "gpt-realtime.yaml")
|
||||||
|
Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
|
||||||
|
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||||
|
|
||||||
|
// PATCH with a small change to the pipeline
|
||||||
|
body := bytes.NewBufferString(`{"pipeline": {"tts": "vibevoice"}}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/gpt-realtime", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
// Read the file from disk and verify no spurious defaults leaked
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
fileContent := string(data)
|
||||||
|
|
||||||
|
// The patched value should be present
|
||||||
|
Expect(fileContent).To(ContainSubstring("vibevoice"))
|
||||||
|
|
||||||
|
// Runtime-only defaults from SetDefaults() should NOT be in the file
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("top_p"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("top_k"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("temperature"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("mirostat"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("mmap"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("mmlock"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("threads"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("low_vram"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("embeddings"))
|
||||||
|
Expect(fileContent).NotTo(ContainSubstring("f16"))
|
||||||
|
|
||||||
|
// Original fields should still be present
|
||||||
|
Expect(fileContent).To(ContainSubstring("gpt-realtime"))
|
||||||
|
Expect(fileContent).To(ContainSubstring("silero-vad"))
|
||||||
|
Expect(fileContent).To(ContainSubstring("whisper-base"))
|
||||||
|
Expect(fileContent).To(ContainSubstring("llama3"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package localai
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/backend"
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
@@ -13,6 +15,7 @@ import (
|
|||||||
|
|
||||||
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
|
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
|
||||||
// @Summary Detects objects in the input image.
|
// @Summary Detects objects in the input image.
|
||||||
|
// @Tags detection
|
||||||
// @Param request body schema.DetectionRequest true "query params"
|
// @Param request body schema.DetectionRequest true "query params"
|
||||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||||
// @Router /v1/detection [post]
|
// @Router /v1/detection [post]
|
||||||
@@ -36,7 +39,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := backend.Detection(image, ml, appConfig, *cfg)
|
res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -45,12 +48,18 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
|||||||
Detections: make([]schema.Detection, len(res.Detections)),
|
Detections: make([]schema.Detection, len(res.Detections)),
|
||||||
}
|
}
|
||||||
for i, detection := range res.Detections {
|
for i, detection := range res.Detections {
|
||||||
|
var mask string
|
||||||
|
if len(detection.Mask) > 0 {
|
||||||
|
mask = base64.StdEncoding.EncodeToString(detection.Mask)
|
||||||
|
}
|
||||||
response.Detections[i] = schema.Detection{
|
response.Detections[i] = schema.Detection{
|
||||||
X: detection.X,
|
X: detection.X,
|
||||||
Y: detection.Y,
|
Y: detection.Y,
|
||||||
Width: detection.Width,
|
Width: detection.Width,
|
||||||
Height: detection.Height,
|
Height: detection.Height,
|
||||||
ClassName: detection.ClassName,
|
ClassName: detection.ClassName,
|
||||||
|
Confidence: detection.Confidence,
|
||||||
|
Mask: mask,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGaller
|
|||||||
|
|
||||||
// GetOpStatusEndpoint returns the job status
|
// GetOpStatusEndpoint returns the job status
|
||||||
// @Summary Returns the job status
|
// @Summary Returns the job status
|
||||||
|
// @Tags models
|
||||||
// @Success 200 {object} galleryop.OpStatus "Response"
|
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||||
// @Router /models/jobs/{uuid} [get]
|
// @Router /models/jobs/{uuid} [get]
|
||||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||||
@@ -54,6 +55,7 @@ func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
|||||||
|
|
||||||
// GetAllStatusEndpoint returns all the jobs status progress
|
// GetAllStatusEndpoint returns all the jobs status progress
|
||||||
// @Summary Returns all the jobs status progress
|
// @Summary Returns all the jobs status progress
|
||||||
|
// @Tags models
|
||||||
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||||
// @Router /models/jobs [get]
|
// @Router /models/jobs [get]
|
||||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||||
@@ -64,6 +66,7 @@ func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc
|
|||||||
|
|
||||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||||
// @Summary Install models to LocalAI.
|
// @Summary Install models to LocalAI.
|
||||||
|
// @Tags models
|
||||||
// @Param request body GalleryModel true "query params"
|
// @Param request body GalleryModel true "query params"
|
||||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||||
// @Router /models/apply [post]
|
// @Router /models/apply [post]
|
||||||
@@ -93,6 +96,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
|
|||||||
|
|
||||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||||
// @Summary delete models to LocalAI.
|
// @Summary delete models to LocalAI.
|
||||||
|
// @Tags models
|
||||||
// @Param name path string true "Model name"
|
// @Param name path string true "Model name"
|
||||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||||
// @Router /models/delete/{name} [post]
|
// @Router /models/delete/{name} [post]
|
||||||
@@ -118,7 +122,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
|
|||||||
|
|
||||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||||
// @Summary List installable models.
|
// @Summary List installable models.
|
||||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
// @Tags models
|
||||||
|
// @Success 200 {object} []gallery.Metadata "Response"
|
||||||
// @Router /models/available [get]
|
// @Router /models/available [get]
|
||||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
@@ -149,6 +154,7 @@ func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState
|
|||||||
|
|
||||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||||
// @Summary List all Galleries
|
// @Summary List all Galleries
|
||||||
|
// @Tags models
|
||||||
// @Success 200 {object} []config.Gallery "Response"
|
// @Success 200 {object} []config.Gallery "Response"
|
||||||
// @Router /models/galleries [get]
|
// @Router /models/galleries [get]
|
||||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
|
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
|
||||||
//
|
//
|
||||||
// @Summary Get TokenMetrics for Active Slot.
|
// @Summary Get TokenMetrics for Active Slot.
|
||||||
|
// @Tags tokenize
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce audio/x-wav
|
// @Produce audio/x-wav
|
||||||
// @Success 200 {string} binary "generated audio/wav file"
|
// @Success 200 {string} binary "generated audio/wav file"
|
||||||
|
|||||||
@@ -119,48 +119,20 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
|||||||
return c.JSON(http.StatusBadRequest, response)
|
return c.JSON(http.StatusBadRequest, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check content type to determine how to parse
|
// Detect format once and reuse for both typed and map parsing
|
||||||
contentType := c.Request().Header.Get("Content-Type")
|
contentType := c.Request().Header.Get("Content-Type")
|
||||||
var modelConfig config.ModelConfig
|
trimmed := strings.TrimSpace(string(body))
|
||||||
|
isJSON := strings.Contains(contentType, "application/json") ||
|
||||||
|
(!strings.Contains(contentType, "yaml") && len(trimmed) > 0 && trimmed[0] == '{')
|
||||||
|
|
||||||
if strings.Contains(contentType, "application/json") {
|
var modelConfig config.ModelConfig
|
||||||
// Parse JSON
|
if isJSON {
|
||||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||||
response := ModelResponse{
|
return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse JSON: " + err.Error()})
|
||||||
Success: false,
|
|
||||||
Error: "Failed to parse JSON: " + err.Error(),
|
|
||||||
}
|
|
||||||
return c.JSON(http.StatusBadRequest, response)
|
|
||||||
}
|
|
||||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
|
||||||
// Parse YAML
|
|
||||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
|
||||||
response := ModelResponse{
|
|
||||||
Success: false,
|
|
||||||
Error: "Failed to parse YAML: " + err.Error(),
|
|
||||||
}
|
|
||||||
return c.JSON(http.StatusBadRequest, response)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Try to auto-detect format
|
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||||
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
|
return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse YAML: " + err.Error()})
|
||||||
// Looks like JSON
|
|
||||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
|
||||||
response := ModelResponse{
|
|
||||||
Success: false,
|
|
||||||
Error: "Failed to parse JSON: " + err.Error(),
|
|
||||||
}
|
|
||||||
return c.JSON(http.StatusBadRequest, response)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Assume YAML
|
|
||||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
|
||||||
response := ModelResponse{
|
|
||||||
Success: false,
|
|
||||||
Error: "Failed to parse YAML: " + err.Error(),
|
|
||||||
}
|
|
||||||
return c.JSON(http.StatusBadRequest, response)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,10 +145,9 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
|||||||
return c.JSON(http.StatusBadRequest, response)
|
return c.JSON(http.StatusBadRequest, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set defaults
|
// Validate without calling SetDefaults() — runtime defaults should not
|
||||||
modelConfig.SetDefaults(appConfig.ToConfigLoaderOptions()...)
|
// be persisted to disk. SetDefaults() is called when loading configs
|
||||||
|
// for inference via LoadModelConfigsFromPath().
|
||||||
// Validate the configuration
|
|
||||||
if valid, _ := modelConfig.Validate(); !valid {
|
if valid, _ := modelConfig.Validate(); !valid {
|
||||||
response := ModelResponse{
|
response := ModelResponse{
|
||||||
Success: false,
|
Success: false,
|
||||||
@@ -195,8 +166,21 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
|||||||
return c.JSON(http.StatusBadRequest, response)
|
return c.JSON(http.StatusBadRequest, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal to YAML for storage
|
// Write only the user-provided fields to disk by parsing the original
|
||||||
yamlData, err := yaml.Marshal(&modelConfig)
|
// body into a map (not the typed struct, which includes Go zero values).
|
||||||
|
var bodyMap map[string]any
|
||||||
|
if isJSON {
|
||||||
|
_ = json.Unmarshal(body, &bodyMap)
|
||||||
|
} else {
|
||||||
|
_ = yaml.Unmarshal(body, &bodyMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
var yamlData []byte
|
||||||
|
if bodyMap != nil {
|
||||||
|
yamlData, err = yaml.Marshal(bodyMap)
|
||||||
|
} else {
|
||||||
|
yamlData, err = yaml.Marshal(&modelConfig)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response := ModelResponse{
|
response := ModelResponse{
|
||||||
Success: false,
|
Success: false,
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ type MCPErrorEvent struct {
|
|||||||
// which handles MCP tool injection and server-side execution.
|
// which handles MCP tool injection and server-side execution.
|
||||||
// Both streaming and non-streaming modes use standard OpenAI response format.
|
// Both streaming and non-streaming modes use standard OpenAI response format.
|
||||||
// @Summary MCP chat completions with automatic tool execution
|
// @Summary MCP chat completions with automatic tool execution
|
||||||
|
// @Tags mcp
|
||||||
// @Param request body schema.OpenAIRequest true "query params"
|
// @Param request body schema.OpenAIRequest true "query params"
|
||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/mcp/chat/completions [post]
|
// @Router /v1/mcp/chat/completions [post]
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
|
|
||||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||||
// @Summary Prometheus metrics endpoint
|
// @Summary Prometheus metrics endpoint
|
||||||
// @Param request body config.Gallery true "Gallery details"
|
// @Tags monitoring
|
||||||
|
// @Produce text/plain
|
||||||
|
// @Success 200 {string} string "Prometheus metrics"
|
||||||
// @Router /metrics [get]
|
// @Router /metrics [get]
|
||||||
func LocalAIMetricsEndpoint() echo.HandlerFunc {
|
func LocalAIMetricsEndpoint() echo.HandlerFunc {
|
||||||
return echo.WrapHandler(promhttp.Handler())
|
return echo.WrapHandler(promhttp.Handler())
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
// ShowP2PNodes returns the P2P Nodes
|
// ShowP2PNodes returns the P2P Nodes
|
||||||
// @Summary Returns available P2P nodes
|
// @Summary Returns available P2P nodes
|
||||||
|
// @Tags p2p
|
||||||
// @Success 200 {object} []schema.P2PNodesResponse "Response"
|
// @Success 200 {object} []schema.P2PNodesResponse "Response"
|
||||||
// @Router /api/p2p [get]
|
// @Router /api/p2p [get]
|
||||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
@@ -24,6 +25,7 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
|||||||
|
|
||||||
// ShowP2PToken returns the P2P token
|
// ShowP2PToken returns the P2P token
|
||||||
// @Summary Show the P2P token
|
// @Summary Show the P2P token
|
||||||
|
// @Tags p2p
|
||||||
// @Success 200 {string} string "Response"
|
// @Success 200 {string} string "Response"
|
||||||
// @Router /api/p2p/token [get]
|
// @Router /api/p2p/token [get]
|
||||||
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
|||||||
144
core/http/endpoints/localai/pin_model.go
Normal file
144
core/http/endpoints/localai/pin_model.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TogglePinnedModelEndpoint handles pinning or unpinning a model.
|
||||||
|
// Pinned models are excluded from idle unloading, LRU eviction, and memory-pressure eviction.
|
||||||
|
//
|
||||||
|
// @Summary Toggle model pinned status
|
||||||
|
// @Description Pin or unpin a model. Pinned models stay loaded and are excluded from automatic eviction.
|
||||||
|
// @Tags config
|
||||||
|
// @Param name path string true "Model name"
|
||||||
|
// @Param action path string true "Action: 'pin' or 'unpin'"
|
||||||
|
// @Success 200 {object} ModelResponse
|
||||||
|
// @Failure 400 {object} ModelResponse
|
||||||
|
// @Failure 404 {object} ModelResponse
|
||||||
|
// @Failure 500 {object} ModelResponse
|
||||||
|
// @Router /api/models/toggle-pinned/{name}/{action} [put]
|
||||||
|
func TogglePinnedModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, syncPinnedFn func()) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
modelName := c.Param("name")
|
||||||
|
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||||
|
modelName = decoded
|
||||||
|
}
|
||||||
|
if modelName == "" {
|
||||||
|
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model name is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
action := c.Param("action")
|
||||||
|
if action != "pin" && action != "unpin" {
|
||||||
|
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Action must be 'pin' or 'unpin'",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing model config
|
||||||
|
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||||
|
if !exists {
|
||||||
|
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the config file path
|
||||||
|
configPath := modelConfig.GetModelConfigFile()
|
||||||
|
if configPath == "" {
|
||||||
|
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration file not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the path is trusted
|
||||||
|
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration not trusted: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing config file
|
||||||
|
configData, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to read configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the YAML config as a generic map to preserve all fields
|
||||||
|
var configMap map[string]interface{}
|
||||||
|
if err := yaml.Unmarshal(configData, &configMap); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to parse configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the pinned field
|
||||||
|
pinned := action == "pin"
|
||||||
|
if pinned {
|
||||||
|
configMap["pinned"] = true
|
||||||
|
} else {
|
||||||
|
// Remove the pinned key entirely when unpinning (clean YAML)
|
||||||
|
delete(configMap, "pinned")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back to YAML
|
||||||
|
updatedData, err := yaml.Marshal(configMap)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to serialize configuration: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write updated config back to file
|
||||||
|
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to write configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload model configurations from disk
|
||||||
|
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to reload configurations: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync pinned models to the watchdog
|
||||||
|
if syncPinnedFn != nil {
|
||||||
|
syncPinnedFn()
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Model '%s' has been %sned successfully.", modelName, action)
|
||||||
|
if pinned {
|
||||||
|
msg += " The model will be excluded from automatic eviction."
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, ModelResponse{
|
||||||
|
Success: true,
|
||||||
|
Message: msg,
|
||||||
|
Filename: configPath,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
// SystemInformations returns the system informations
|
// SystemInformations returns the system informations
|
||||||
// @Summary Show the LocalAI instance information
|
// @Summary Show the LocalAI instance information
|
||||||
|
// @Tags monitoring
|
||||||
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
||||||
// @Router /system [get]
|
// @Router /system [get]
|
||||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
|||||||
148
core/http/endpoints/localai/toggle_model.go
Normal file
148
core/http/endpoints/localai/toggle_model.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToggleModelEndpoint handles enabling or disabling a model from being loaded on demand.
|
||||||
|
// When disabled, the model remains in the collection but will not be loaded when requested.
|
||||||
|
//
|
||||||
|
// @Summary Toggle model enabled/disabled status
|
||||||
|
// @Description Enable or disable a model from being loaded on demand. Disabled models remain installed but cannot be loaded.
|
||||||
|
// @Tags config
|
||||||
|
// @Param name path string true "Model name"
|
||||||
|
// @Param action path string true "Action: 'enable' or 'disable'"
|
||||||
|
// @Success 200 {object} ModelResponse
|
||||||
|
// @Failure 400 {object} ModelResponse
|
||||||
|
// @Failure 404 {object} ModelResponse
|
||||||
|
// @Failure 500 {object} ModelResponse
|
||||||
|
// @Router /api/models/{name}/{action} [put]
|
||||||
|
func ToggleStateModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
modelName := c.Param("name")
|
||||||
|
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||||
|
modelName = decoded
|
||||||
|
}
|
||||||
|
if modelName == "" {
|
||||||
|
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model name is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
action := c.Param("action")
|
||||||
|
if action != "enable" && action != "disable" {
|
||||||
|
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Action must be 'enable' or 'disable'",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing model config
|
||||||
|
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||||
|
if !exists {
|
||||||
|
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the config file path
|
||||||
|
configPath := modelConfig.GetModelConfigFile()
|
||||||
|
if configPath == "" {
|
||||||
|
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration file not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the path is trusted
|
||||||
|
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Model configuration not trusted: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing config file
|
||||||
|
configData, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to read configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the YAML config as a generic map to preserve all fields
|
||||||
|
var configMap map[string]interface{}
|
||||||
|
if err := yaml.Unmarshal(configData, &configMap); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to parse configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the disabled field
|
||||||
|
disabled := action == "disable"
|
||||||
|
if disabled {
|
||||||
|
configMap["disabled"] = true
|
||||||
|
} else {
|
||||||
|
// Remove the disabled key entirely when enabling (clean YAML)
|
||||||
|
delete(configMap, "disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back to YAML
|
||||||
|
updatedData, err := yaml.Marshal(configMap)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to serialize configuration: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write updated config back to file
|
||||||
|
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to write configuration file: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload model configurations from disk
|
||||||
|
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "Failed to reload configurations: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// If disabling, also shutdown the model if it's currently running
|
||||||
|
if disabled {
|
||||||
|
if err := ml.ShutdownModel(modelName); err != nil {
|
||||||
|
// Log but don't fail - the config was saved successfully
|
||||||
|
fmt.Printf("Warning: Failed to shutdown model '%s' during disable: %v\n", modelName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Model '%s' has been %sd successfully.", modelName, action)
|
||||||
|
if disabled {
|
||||||
|
msg += " The model will not be loaded on demand until re-enabled."
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, ModelResponse{
|
||||||
|
Success: true,
|
||||||
|
Message: msg,
|
||||||
|
Filename: configPath,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
// TokenizeEndpoint exposes a REST API to tokenize the content
|
// TokenizeEndpoint exposes a REST API to tokenize the content
|
||||||
// @Summary Tokenize the input.
|
// @Summary Tokenize the input.
|
||||||
|
// @Tags tokenize
|
||||||
// @Param request body schema.TokenizeRequest true "Request"
|
// @Param request body schema.TokenizeRequest true "Request"
|
||||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||||
// @Router /v1/tokenize [post]
|
// @Router /v1/tokenize [post]
|
||||||
|
|||||||
59
core/http/endpoints/localai/traces.go
Normal file
59
core/http/endpoints/localai/traces.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/http/middleware"
|
||||||
|
"github.com/mudler/LocalAI/core/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAPITracesEndpoint returns all API request/response traces
|
||||||
|
// @Summary List API request/response traces
|
||||||
|
// @Description Returns captured API exchange traces (request/response pairs) in reverse chronological order
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} map[string]any "Traced API exchanges"
|
||||||
|
// @Router /api/traces [get]
|
||||||
|
func GetAPITracesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
return c.JSON(200, middleware.GetTraces())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAPITracesEndpoint clears all API traces
|
||||||
|
// @Summary Clear API traces
|
||||||
|
// @Description Removes all captured API request/response traces from the buffer
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Success 204 "Traces cleared"
|
||||||
|
// @Router /api/traces/clear [post]
|
||||||
|
func ClearAPITracesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
middleware.ClearTraces()
|
||||||
|
return c.NoContent(204)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBackendTracesEndpoint returns all backend operation traces
|
||||||
|
// @Summary List backend operation traces
|
||||||
|
// @Description Returns captured backend traces (LLM calls, embeddings, TTS, etc.) in reverse chronological order
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} map[string]any "Backend operation traces"
|
||||||
|
// @Router /api/backend-traces [get]
|
||||||
|
func GetBackendTracesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
return c.JSON(200, trace.GetBackendTraces())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBackendTracesEndpoint clears all backend traces
|
||||||
|
// @Summary Clear backend traces
|
||||||
|
// @Description Removes all captured backend operation traces from the buffer
|
||||||
|
// @Tags monitoring
|
||||||
|
// @Success 204 "Traces cleared"
|
||||||
|
// @Router /api/backend-traces/clear [post]
|
||||||
|
func ClearBackendTracesEndpoint() echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
trace.ClearBackendTraces()
|
||||||
|
return c.NoContent(204)
|
||||||
|
}
|
||||||
|
}
|
||||||
55
core/http/endpoints/localai/traces_test.go
Normal file
55
core/http/endpoints/localai/traces_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package localai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Traces Endpoints", func() {
|
||||||
|
var app *echo.Echo
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
app = echo.New()
|
||||||
|
app.GET("/api/traces", GetAPITracesEndpoint())
|
||||||
|
app.POST("/api/traces/clear", ClearAPITracesEndpoint())
|
||||||
|
app.GET("/api/backend-traces", GetBackendTracesEndpoint())
|
||||||
|
app.POST("/api/backend-traces/clear", ClearBackendTracesEndpoint())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return API traces", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/traces", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should clear API traces", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/traces/clear", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return backend traces", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/backend-traces", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should clear backend traces", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/backend-traces/clear", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||||
//
|
//
|
||||||
// @Summary Generates audio from the input text.
|
// @Summary Generates audio from the input text.
|
||||||
|
// @Tags audio
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce audio/x-wav
|
// @Produce audio/x-wav
|
||||||
// @Param request body schema.TTSRequest true "query params"
|
// @Param request body schema.TTSRequest true "query params"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user