mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-21 07:05:07 -04:00
Compare commits
37 Commits
v4.1.3
...
feat/backe
| Author | SHA1 | Date |
|---|---|---|
| | 5fe87cb0d5 | |
| | 6dd37a95c4 | |
| | ee00a10836 | |
| | 948f3bfaa4 | |
| | 1e083cd870 | |
| | b19e60d03a | |
| | 4d463e9f0d | |
| | ae4ae5f425 | |
| | 7c1865b307 | |
| | 62a674ce12 | |
| | c39213443b | |
| | 606f462da4 | |
| | 5c35e85fe2 | |
| | 062e0d0d00 | |
| | d4cd6c284f | |
| | 3bb8b65d31 | |
| | 9748a1cbc6 | |
| | 6bc76dda6d | |
| | e1a6010874 | |
| | 706cf5d43c | |
| | 13a6ed709c | |
| | 85be4ff03c | |
| | b0d9ce4905 | |
| | 7081b54c09 | |
| | 2b05420f95 | |
| | b64347b6aa | |
| | e00ce981f0 | |
| | 285f7d4340 | |
| | ea6e850809 | |
| | b7247fc148 | |
| | 39c6b3ed66 | |
| | 0e9d1a6588 | |
| | 510d6759fe | |
| | 154fa000d3 | |
| | 0526e60f8d | |
| | db600fb5b2 | |
| | 9ac1bdc587 | |
111
.agents/adding-gallery-models.md
Normal file
@@ -0,0 +1,111 @@
# Adding GGUF Models from HuggingFace to the Gallery

When adding a GGUF model from HuggingFace to the LocalAI model gallery, follow this guide.

## Gallery file

All models are defined in `gallery/index.yaml`. Find the appropriate section (embedding models near other embeddings, chat models near similar chat models) and add a new entry.

## Getting the SHA256

GGUF files on HuggingFace expose their SHA256 via the `x-linked-etag` HTTP header. Fetch it with:

```bash
curl -sI "https://huggingface.co/<org>/<repo>/resolve/main/<filename>.gguf" | grep -i x-linked-etag
```

The value (without quotes) is the SHA256 hash. Example:

```bash
curl -sI "https://huggingface.co/ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/resolve/main/embeddinggemma-300m-qat-Q8_0.gguf" | grep -i x-linked-etag
# x-linked-etag: "6fa0c02a9c302be6f977521d399b4de3a46310a4f2621ee0063747881b673f67"
```

**Important**: Pay attention to exact filename casing. HuggingFace filenames are case-sensitive (e.g., `Q8_0` vs `q8_0`). Check the repo's file listing to get the exact name.

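The header lookup above can also be scripted. The following is only a sketch: `sha256_from_headers` is a hypothetical helper (not part of LocalAI), and it assumes the `curl -sI` output has been captured as text and that the `x-linked-etag` header is present for the file in question.

```python
import re


def sha256_from_headers(raw_headers: str) -> str:
    """Extract the SHA256 from a raw `curl -sI` header dump.

    Hypothetical helper for illustration only.
    """
    for line in raw_headers.splitlines():
        name, sep, value = line.partition(":")
        # Header names are case-insensitive, which is why the guide uses `grep -i`.
        if sep and name.strip().lower() == "x-linked-etag":
            digest = value.strip().strip('"').lower()
            # A SHA256 digest is exactly 64 hexadecimal characters.
            if not re.fullmatch(r"[0-9a-f]{64}", digest):
                raise ValueError(f"not a SHA256 hex digest: {digest!r}")
            return digest
    raise KeyError("x-linked-etag header not found")


# Header value copied from the example in this guide.
sample = (
    "HTTP/2 302\r\n"
    'x-linked-etag: "6fa0c02a9c302be6f977521d399b4de3a46310a4f2621ee0063747881b673f67"\r\n'
)
print(sha256_from_headers(sample))
```

Validating the length and character set before pasting the value into `gallery/index.yaml` catches the case where a different ETag format is returned.
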
## Entry format: Embedding models

Embedding models use `gallery/virtual.yaml` as the base config and set `embeddings: true`:

```yaml
- name: "model-name"
  url: github:mudler/LocalAI/gallery/virtual.yaml@master
  urls:
    - https://huggingface.co/<original-model-org>/<original-model-name>
    - https://huggingface.co/<gguf-org>/<gguf-repo-name>
  description: |
    Short description of the model, its size, and capabilities.
  tags:
    - embeddings
  overrides:
    backend: llama-cpp
    embeddings: true
    parameters:
      model: <filename>.gguf
  files:
    - filename: <filename>.gguf
      uri: huggingface://<gguf-org>/<gguf-repo-name>/<filename>.gguf
      sha256: <sha256-hash>
```

## Entry format: Chat/LLM models

Chat models typically reference a template config (e.g., `gallery/gemma.yaml`, `gallery/chatml.yaml`) that defines the prompt format. Use YAML anchors (`&name` / `*name`) if adding multiple quantization variants of the same model:

```yaml
- &model-anchor
  url: "github:mudler/LocalAI/gallery/<template>.yaml@master"
  name: "model-name"
  icon: https://example.com/icon.png
  license: <license>
  urls:
    - https://huggingface.co/<org>/<model>
    - https://huggingface.co/<gguf-org>/<gguf-repo>
  description: |
    Model description.
  tags:
    - llm
    - gguf
    - gpu
    - cpu
  overrides:
    parameters:
      model: <filename>-Q4_K_M.gguf
  files:
    - filename: <filename>-Q4_K_M.gguf
      sha256: <sha256>
      uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q4_K_M.gguf
```

To add a variant (e.g., different quantization), use YAML merge:

```yaml
- !!merge <<: *model-anchor
  name: "model-name-q8"
  overrides:
    parameters:
      model: <filename>-Q8_0.gguf
  files:
    - filename: <filename>-Q8_0.gguf
      sha256: <sha256>
      uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q8_0.gguf
```

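One detail worth keeping in mind for variant entries is that the YAML merge key (`<<`) is shallow: keys written in the variant replace the anchor's keys wholesale and nested mappings are not combined. The sketch below mimics that behaviour with plain Python dictionaries instead of a YAML parser. The `backend` key inside the anchor's `overrides`, the concrete file names, and the `merge_entry` helper are illustrative assumptions rather than anything taken from `gallery/index.yaml`.

```python
def merge_entry(anchor: dict, explicit: dict) -> dict:
    """Emulate `<<: *anchor`: explicit keys win, and the merge is not recursive."""
    return {**anchor, **explicit}


# Stand-in for the `&model-anchor` entry (values are placeholders).
anchor = {
    "name": "model-name",
    "tags": ["llm", "gguf", "gpu", "cpu"],
    "overrides": {
        "backend": "llama-cpp",  # assumed extra key, to show what gets dropped
        "parameters": {"model": "model-Q4_K_M.gguf"},
    },
    "files": [{"filename": "model-Q4_K_M.gguf"}],
}

# Keys written explicitly in the variant entry.
variant_keys = {
    "name": "model-name-q8",
    "overrides": {"parameters": {"model": "model-Q8_0.gguf"}},
    "files": [{"filename": "model-Q8_0.gguf"}],
}

variant = merge_entry(anchor, variant_keys)
print(variant["tags"])       # inherited from the anchor
print(variant["overrides"])  # replaced as a whole: no "backend" key any more
```

Under this reading, any setting under `overrides` that a variant still needs has to be repeated in the variant entry.
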
## Available template configs

Look at existing `.yaml` files in `gallery/` to find the right prompt template for your model architecture:

- `gemma.yaml`: Gemma-family models (gemma, embeddinggemma, etc.)
- `chatml.yaml`: ChatML format (many Mistral/OpenHermes models)
- `deepseek.yaml`: DeepSeek models
- `virtual.yaml`: Minimal base (good for embedding models that don't need chat templates)

## Checklist

1. **Find the GGUF file** on HuggingFace and note the exact filename (case-sensitive)
2. **Get the SHA256** using the `curl -sI` + `x-linked-etag` method above
3. **Choose the right template** config from `gallery/` based on model architecture
4. **Add the entry** to `gallery/index.yaml` near similar models
5. **Set `embeddings: true`** if it's an embedding model
6. **Include both URLs**: the original model page and the GGUF repo
7. **Write a description** that mentions model size, capabilities, and quantization type
170
.github/workflows/backend.yml
vendored
@@ -105,6 +105,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -561,6 +574,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -965,6 +991,32 @@ jobs:
|
||||
backend: "mlx-distributed"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1108,6 +1160,32 @@ jobs:
|
||||
backend: "stablediffusion-ggml"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-sam3-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1644,6 +1722,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# SYCL additional backends
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
@@ -1842,6 +1946,59 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sam3-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1894,6 +2051,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-sam3-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# whisper
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
@@ -34,6 +34,10 @@ jobs:
|
||||
variable: "ACESTEP_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/acestep-cpp/Makefile"
|
||||
- repository: "PABannier/sam3.cpp"
|
||||
variable: "SAM3_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/sam3-cpp/Makefile"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
23
.github/workflows/test-extra.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -528,3 +529,25 @@ jobs:
|
||||
- name: Test voxtral
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
||||
tests-kokoros:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.kokoros == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake pkg-config protobuf-compiler clang libclang-dev
|
||||
sudo apt-get install -y espeak-ng libespeak-ng-dev libsonic-dev libpcaudio-dev libopus-dev libssl-dev
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
- name: Build kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros kokoros-grpc
|
||||
- name: Test kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros test
|
||||
|
||||
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
[submodule "docs/themes/hugo-theme-relearn"]
	path = docs/themes/hugo-theme-relearn
	url = https://github.com/McShelby/hugo-theme-relearn.git
[submodule "backend/rust/kokoros/sources/Kokoros"]
	path = backend/rust/kokoros/sources/Kokoros
	url = https://github.com/lucasjinreal/Kokoros
@@ -13,6 +13,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
17
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -148,7 +148,6 @@ test-models/testmodel.ggml:
|
||||
mkdir -p test-dir
|
||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
||||
wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
||||
cp tests/models_fixtures/* test-models
|
||||
|
||||
@@ -429,9 +428,11 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/qwen-asr
|
||||
$(MAKE) -C backend/python/nemo
|
||||
$(MAKE) -C backend/python/voxcpm
|
||||
$(MAKE) -C backend/python/faster-whisper
|
||||
$(MAKE) -C backend/python/whisperx
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -449,9 +450,11 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/qwen-asr test
|
||||
$(MAKE) -C backend/python/nemo test
|
||||
$(MAKE) -C backend/python/voxcpm test
|
||||
$(MAKE) -C backend/python/faster-whisper test
|
||||
$(MAKE) -C backend/python/whisperx test
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
IMAGE_TYPE?=core
|
||||
@@ -587,6 +590,12 @@ BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
|
||||
# Rust backends
|
||||
BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
|
||||
# C++ backends (Go wrapper with purego)
|
||||
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
define docker-build-backend
|
||||
@@ -645,12 +654,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
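The `BACKEND_*` definitions added in the Makefile hunks above (for example `kokoros|rust|.|false|true`) pack several values into one pipe-separated string. As a rough sketch, assuming the five fields line up with the parameters named in the helper's usage comment (BACKEND_NAME, DOCKERFILE_TYPE, BUILD_CONTEXT, PROGRESS_FLAG, NEEDS_BACKEND_ARG), they can be unpacked as follows. `BackendDef` and `parse_backend_def` are hypothetical names, and the `backend/Dockerfile.<type>` path pattern is inferred from the Dockerfile paths visible in this diff rather than from the Makefile macro itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BackendDef:
    name: str
    dockerfile_type: str
    build_context: str
    progress_flag: bool
    needs_backend_arg: bool

    @property
    def dockerfile(self) -> str:
        # Assumed pattern, modelled on backend/Dockerfile.rust and the
        # ./backend/Dockerfile.golang entries in the workflow matrix.
        return f"backend/Dockerfile.{self.dockerfile_type}"


def parse_backend_def(spec: str) -> BackendDef:
    """Split a `name|type|context|progress|needs_arg` definition into fields."""
    parts = spec.split("|")
    if len(parts) != 5:
        raise ValueError(f"expected 5 pipe-separated fields, got {len(parts)}: {spec!r}")
    name, dockerfile_type, context, progress, needs_arg = parts
    return BackendDef(name, dockerfile_type, context, progress == "true", needs_arg == "true")


for spec in ("kokoros|rust|.|false|true", "sam3-cpp|golang|.|false|true"):
    backend = parse_backend_def(spec)
    print(backend.name, backend.dockerfile, backend.needs_backend_arg)
```
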
@@ -196,6 +196,7 @@ See the full [Backend & Model Compatibility Table](https://localai.io/model-comp
|
||||
- [Build from source](https://localai.io/basics/build/)
|
||||
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||
|
||||
|
||||
39
backend/Dockerfile.rust
Normal file
@@ -0,0 +1,39 @@
ARG BASE_IMAGE=ubuntu:24.04

FROM ${BASE_IMAGE} AS builder
ARG BACKEND=kokoros
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETARCH
ARG TARGETVARIANT

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        git ccache \
        ca-certificates \
        make cmake wget \
        curl unzip \
        clang \
        pkg-config \
        libssl-dev \
        espeak-ng libespeak-ng-dev \
        libsonic-dev libpcaudio-dev \
        libopus-dev \
        protobuf-compiler && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"

COPY . /LocalAI

RUN git config --global --add safe.directory /LocalAI

RUN make -C /LocalAI/backend/rust/${BACKEND} build

FROM scratch
ARG BACKEND=kokoros

COPY --from=builder /LocalAI/backend/rust/${BACKEND}/package/. ./
@@ -444,6 +444,10 @@ message Message {

message DetectOptions {
  string src = 1;
  string prompt = 2; // Text prompt (for SAM 3 PCS mode)
  repeated float points = 3; // Point coordinates as [x1, y1, label1, x2, y2, label2, ...] (label: 1=pos, 0=neg)
  repeated float boxes = 4; // Box coordinates as [x1, y1, x2, y2, ...]
  float threshold = 5; // Detection confidence threshold
}

message Detection {
@@ -453,6 +457,7 @@ message Detection {
  float height = 4;
  float confidence = 5;
  string class_name = 6;
  bytes mask = 7; // PNG-encoded binary segmentation mask
}

message DetectResponse {
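The new `points` and `boxes` fields in `DetectOptions` are flat float arrays, so a client has to pack and unpack them in strides of three and four. The sketch below follows the layout described in the field comments. The helper names are hypothetical and no generated protobuf code is involved.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, int]          # (x, y, label), label 1 = positive, 0 = negative
Box = Tuple[float, float, float, float]   # (x1, y1, x2, y2)


def flatten_points(points: Sequence[Point]) -> List[float]:
    """Pack points as [x1, y1, label1, x2, y2, label2, ...]."""
    flat: List[float] = []
    for x, y, label in points:
        if label not in (0, 1):
            raise ValueError(f"label must be 0 or 1, got {label!r}")
        flat.extend([float(x), float(y), float(label)])
    return flat


def unflatten_points(flat: Sequence[float]) -> List[Point]:
    """Inverse of flatten_points; the array length must be a multiple of 3."""
    if len(flat) % 3 != 0:
        raise ValueError("points array length must be a multiple of 3")
    return [(flat[i], flat[i + 1], int(flat[i + 2])) for i in range(0, len(flat), 3)]


def unflatten_boxes(flat: Sequence[float]) -> List[Box]:
    """Split [x1, y1, x2, y2, ...] into boxes; the length must be a multiple of 4."""
    if len(flat) % 4 != 0:
        raise ValueError("boxes array length must be a multiple of 4")
    return [tuple(flat[i:i + 4]) for i in range(0, len(flat), 4)]


clicks = [(120.0, 80.0, 1), (40.0, 200.0, 0)]
packed = flatten_points(clicks)
print(packed)
print(unflatten_points(packed))
print(unflatten_boxes([10.0, 20.0, 110.0, 220.0]))
```

Checking the stride on both sides is cheap and turns a misaligned array into an explicit error instead of silently shifted coordinates.
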
@@ -1,5 +1,5 @@
 
-LLAMA_VERSION?=d0a6dfeb28a09831d904fc4d910ddb740da82834
+LLAMA_VERSION?=e62fa13c2497b2cd1958cb496e9489e86bbd5182
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
 
 CMAKE_ARGS?=
|
||||
@@ -1614,6 +1614,7 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
@@ -1715,12 +1716,23 @@ public:
             }
         };
 
-        // Process first result
+        // Process first result.
+        // When TASK_RESPONSE_TYPE_OAI_CHAT is used, the first token may
+        // produce a JSON array with a role-init element followed by the
+        // actual content element. We must only attach chat deltas to the
+        // content element: attaching to both would duplicate the first
+        // token since oaicompat_msg_diffs is the same for both.
         json first_res_json = first_result->to_json();
         if (first_res_json.is_array()) {
             for (const auto & res : first_res_json) {
                 auto reply = build_reply_from_json(res, first_result.get());
-                attach_chat_deltas(reply, first_result.get());
+                // Skip chat deltas for role-init elements (have "role" in
+                // delta but no content/reasoning diffs of their own).
+                bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
+                    res["choices"][0].value("delta", json::object()).contains("role");
+                if (!is_role_init) {
+                    attach_chat_deltas(reply, first_result.get());
+                }
                 writer->Write(reply);
             }
         } else {
@@ -1744,7 +1756,11 @@ public:
             if (res_json.is_array()) {
                 for (const auto & res : res_json) {
                     auto reply = build_reply_from_json(res, result.get());
-                    attach_chat_deltas(reply, result.get());
+                    bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
+                        res["choices"][0].value("delta", json::object()).contains("role");
+                    if (!is_role_init) {
+                        attach_chat_deltas(reply, result.get());
+                    }
                     writer->Write(reply);
                 }
             } else {
@@ -2382,6 +2398,7 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
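The role-init check introduced in the streaming hunks above can be restated outside C++. The following sketch mirrors the same predicate on plain dictionaries. The sample chunks are made-up stand-ins for what `to_json()` might return for a first result and are not captured server output.

```python
def is_role_init(chunk: dict) -> bool:
    """True when the first choice's delta carries a "role" key.

    Mirrors the C++ condition: "choices" exists, is non-empty, and
    choices[0]["delta"] (defaulting to an empty object) contains "role".
    """
    choices = chunk.get("choices")
    if not choices:
        return False
    return "role" in choices[0].get("delta", {})


# Hypothetical first result: a role-init element followed by the content element.
first_result = [
    {"choices": [{"delta": {"role": "assistant", "content": None}}]},
    {"choices": [{"delta": {"content": "Hello"}}]},
]

# Chat deltas are attached only where the element is not a role-init element,
# so the first token is emitted once.
attach_flags = [not is_role_init(chunk) for chunk in first_result]
print(attach_flags)
```
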
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
sources/
build*/
package/
libgosam3*.so
sam3-cpp
test-models/
test-data/
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.14)
project(gosam3 LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Build ggml as static libraries to avoid runtime .so dependencies
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)

set(SAM3_BUILD_EXAMPLES OFF CACHE BOOL "Disable sam3.cpp examples" FORCE)
set(SAM3_BUILD_TESTS OFF CACHE BOOL "Disable sam3.cpp tests" FORCE)

add_subdirectory(./sources/sam3.cpp)

add_library(gosam3 MODULE gosam3.cpp)
target_link_libraries(gosam3 PRIVATE sam3 ggml)

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    target_link_libraries(gosam3 PRIVATE stdc++fs)
endif()

target_include_directories(gosam3 PUBLIC
    sources/sam3.cpp
    sources/sam3.cpp/ggml/include
)

set_property(TARGET gosam3 PROPERTY CXX_STANDARD 14)
set_target_properties(gosam3 PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
122
backend/go/sam3-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# sam3.cpp
|
||||
SAM3_REPO?=https://github.com/PABannier/sam3.cpp
|
||||
SAM3_VERSION?=01832ef85fcc8eb6488f1d01cd247f07e96ff5a9
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/sam3.cpp:
|
||||
git clone --recursive $(SAM3_REPO) sources/sam3.cpp && \
|
||||
cd sources/sam3.cpp && \
|
||||
git checkout $(SAM3_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgosam3-avx.so libgosam3-avx2.so libgosam3-avx512.so libgosam3-fallback.so
|
||||
else
|
||||
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||
VARIANT_TARGETS = libgosam3-fallback.so
|
||||
endif
|
||||
|
||||
sam3-cpp: main.go gosam3.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o sam3-cpp ./
|
||||
|
||||
package: sam3-cpp
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgosam3*.so sam3-cpp package sources
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgosam3-avx.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx${RESET})
|
||||
SO_TARGET=libgosam3-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-avx2.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx2${RESET})
|
||||
SO_TARGET=libgosam3-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-avx512.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx512${RESET})
|
||||
SO_TARGET=libgosam3-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
libgosam3-fallback.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:fallback${RESET})
|
||||
SO_TARGET=libgosam3-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-custom: CMakeLists.txt gosam3.cpp gosam3.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgosam3.so ./$(SO_TARGET)
|
||||
|
||||
all: sam3-cpp package
|
||||
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
#include "sam3.h"
|
||||
#include "gosam3.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
// Static state
|
||||
static std::shared_ptr<sam3_model> g_model;
|
||||
static sam3_state_ptr g_state;
|
||||
static sam3_result g_result;
|
||||
static std::vector<std::vector<unsigned char>> g_mask_pngs;
|
||||
|
||||
// Callback for stbi_write_png_to_func: appends each encoded chunk to the std::vector passed as context
|
||||
static void png_write_callback(void *context, void *data, int size) {
|
||||
auto *buf = static_cast<std::vector<unsigned char>*>(context);
|
||||
auto *bytes = static_cast<unsigned char*>(data);
|
||||
buf->insert(buf->end(), bytes, bytes + size);
|
||||
}
|
||||
|
||||
// Encode all masks as PNGs after segmentation
|
||||
static void encode_masks_as_png() {
|
||||
g_mask_pngs.clear();
|
||||
g_mask_pngs.resize(g_result.detections.size());
|
||||
|
||||
for (size_t i = 0; i < g_result.detections.size(); i++) {
|
||||
const auto &mask = g_result.detections[i].mask;
|
||||
if (mask.width > 0 && mask.height > 0 && !mask.data.empty()) {
|
||||
stbi_write_png_to_func(png_write_callback, &g_mask_pngs[i],
|
||||
mask.width, mask.height, 1,
|
||||
mask.data.data(), mask.width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int sam3_cpp_load_model(const char *model_path, int threads) {
|
||||
sam3_params params;
|
||||
params.model_path = model_path;
|
||||
params.n_threads = threads;
|
||||
params.use_gpu = true;
|
||||
|
||||
g_model = sam3_load_model(params);
|
||||
if (!g_model) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load model: %s\n", model_path);
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_state = sam3_create_state(*g_model, params);
|
||||
if (!g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to create state\n");
|
||||
g_model.reset();
|
||||
return 2;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[sam3-cpp] Model loaded: %s (threads=%d)\n", model_path, threads);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_encode_image(const char *image_path) {
|
||||
if (!g_model || !g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Model not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
sam3_image img = sam3_load_image(image_path);
|
||||
if (img.data.empty()) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load image: %s\n", image_path);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (!sam3_encode_image(*g_state, *g_model, img)) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to encode image\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||
float *boxes, int n_box_quads,
|
||||
float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
(void)threshold; // currently unused in PVS mode
sam3_pvs_params pvs_params;
|
||||
|
||||
// Parse points: each triple is [x, y, label]
|
||||
for (int i = 0; i < n_point_triples; i++) {
|
||||
float x = points[i * 3];
|
||||
float y = points[i * 3 + 1];
|
||||
float label = points[i * 3 + 2];
|
||||
sam3_point pt = {x, y};
|
||||
if (label > 0.5f) {
|
||||
pvs_params.pos_points.push_back(pt);
|
||||
} else {
|
||||
pvs_params.neg_points.push_back(pt);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boxes: each quad is [x1, y1, x2, y2], use only first box
|
||||
if (n_box_quads > 0) {
|
||||
pvs_params.box = {boxes[0], boxes[1], boxes[2], boxes[3]};
|
||||
pvs_params.use_box = true;
|
||||
}
|
||||
|
||||
g_result = sam3_segment_pvs(*g_state, *g_model, pvs_params);
|
||||
encode_masks_as_png();
|
||||
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// PCS mode requires SAM 3 (full model with text encoder)
|
||||
if (sam3_is_visual_only(*g_model) ||
|
||||
sam3_get_model_type(*g_model) != SAM3_MODEL_SAM3) {
|
||||
fprintf(stderr, "[sam3-cpp] PCS mode requires full SAM 3 model\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sam3_pcs_params pcs_params;
|
||||
pcs_params.text_prompt = text_prompt;
|
||||
pcs_params.score_threshold = threshold > 0 ? threshold : 0.5f;
|
||||
|
||||
g_result = sam3_segment_pcs(*g_state, *g_model, pcs_params);
|
||||
encode_masks_as_png();
|
||||
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
int sam3_cpp_get_n_detections(void) {
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_x(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].box.x0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_y(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].box.y0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_w(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
const auto &box = g_result.detections[i].box;
|
||||
return box.x1 - box.x0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_h(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
const auto &box = g_result.detections[i].box;
|
||||
return box.y1 - box.y0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_score(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].score;
|
||||
}
|
||||
|
||||
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size) {
|
||||
if (i < 0 || i >= static_cast<int>(g_mask_pngs.size())) return 0;
|
||||
|
||||
const auto &png = g_mask_pngs[i];
|
||||
int size = static_cast<int>(png.size());
|
||||
|
||||
if (buf == nullptr) {
|
||||
return size;
|
||||
}
|
||||
|
||||
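// A non-positive buf_size copies nothing (avoids passing a negative length to memcpy)
if (buf_size <= 0) return 0;
int to_copy = size < buf_size ? size : buf_size;
memcpy(buf, png.data(), to_copy);
return to_copy;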
|
||||
}
|
||||
|
||||
void sam3_cpp_free_results(void) {
|
||||
g_result.detections.clear();
|
||||
g_mask_pngs.clear();
|
||||
}
|
||||
|
||||
} // extern "C"
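The PVS prompt layout parsed in `sam3_cpp_segment_pvs` above can be illustrated outside C++. The following Python sketch mirrors how the flat `[x, y, label]` triples are split into positive and negative points and how only the first box is kept. The name `split_pvs_prompts` and the returned dictionary are hypothetical, introduced for illustration only; they are not part of the backend.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]


def split_pvs_prompts(points: Sequence[float], boxes: Sequence[float]) -> Dict[str, object]:
    """Hypothetical helper mirroring the prompt parsing in sam3_cpp_segment_pvs."""
    pos: List[Point] = []
    neg: List[Point] = []
    # Each complete triple is [x, y, label]; a trailing partial triple is ignored,
    # matching the integer division used on the Go side (len(points) / 3).
    for i in range(len(points) // 3):
        x, y, label = points[3 * i : 3 * i + 3]
        (pos if label > 0.5 else neg).append((x, y))
    # Each quad is [x1, y1, x2, y2]; only the first complete box is used.
    box: Optional[Box] = tuple(boxes[:4]) if len(boxes) // 4 > 0 else None
    return {"pos_points": pos, "neg_points": neg, "box": box}
```

With this layout, a caller that wants one foreground click and one background click passes six floats, and any boxes after the first are silently dropped.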
|
||||
143
backend/go/sam3-cpp/gosam3.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type SAM3 struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int) int
|
||||
CppEncodeImage func(imagePath string) int
|
||||
CppSegmentPVS func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int
|
||||
CppSegmentPCS func(textPrompt string, threshold float32) int
|
||||
CppGetNDetections func() int
|
||||
CppGetDetectionX func(i int) float32
|
||||
CppGetDetectionY func(i int) float32
|
||||
CppGetDetectionW func(i int) float32
|
||||
CppGetDetectionH func(i int) float32
|
||||
CppGetDetectionScore func(i int) float32
|
||||
CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int
|
||||
CppFreeResults func()
|
||||
)
|
||||
|
||||
func (s *SAM3) Load(opts *pb.ModelOptions) error {
|
||||
modelFile := opts.ModelFile
|
||||
if modelFile == "" {
|
||||
modelFile = opts.Model
|
||||
}
|
||||
|
||||
var modelPath string
|
||||
if filepath.IsAbs(modelFile) {
|
||||
modelPath = modelFile
|
||||
} else {
|
||||
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||
}
|
||||
|
||||
threads := int(opts.Threads)
|
||||
if threads <= 0 {
|
||||
threads = 4
|
||||
}
|
||||
|
||||
ret := CppLoadModel(modelPath, threads)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("failed to load SAM3 model (error %d): %s", ret, modelPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||
// Decode base64 image and write to temp file
|
||||
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "sam3-*.png")
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.Write(imgData); err != nil {
|
||||
tmpFile.Close()
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Encode image
|
||||
ret := CppEncodeImage(tmpFile.Name())
|
||||
if ret != 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
|
||||
}
|
||||
|
||||
threshold := opts.Threshold
|
||||
if threshold <= 0 {
|
||||
threshold = 0.5
|
||||
}
|
||||
|
||||
// Determine segmentation mode
|
||||
var nDetections int
|
||||
if opts.Prompt != "" {
|
||||
// Text-prompted segmentation (PCS mode, SAM 3 only)
|
||||
nDetections = CppSegmentPCS(opts.Prompt, threshold)
|
||||
} else {
|
||||
// Point/box-prompted segmentation (PVS mode)
|
||||
var pointsPtr uintptr
|
||||
var boxesPtr uintptr
|
||||
nPointTriples := len(opts.Points) / 3
|
||||
nBoxQuads := len(opts.Boxes) / 4
|
||||
|
||||
if nPointTriples > 0 {
|
||||
pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
|
||||
}
|
||||
if nBoxQuads > 0 {
|
||||
boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
|
||||
}
|
||||
|
||||
nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
|
||||
}
|
||||
|
||||
if nDetections < 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("segmentation failed")
|
||||
}
|
||||
|
||||
defer CppFreeResults()
|
||||
|
||||
// Build response
|
||||
detections := make([]*pb.Detection, nDetections)
|
||||
for i := 0; i < nDetections; i++ {
|
||||
det := &pb.Detection{
|
||||
X: CppGetDetectionX(i),
|
||||
Y: CppGetDetectionY(i),
|
||||
Width: CppGetDetectionW(i),
|
||||
Height: CppGetDetectionH(i),
|
||||
Confidence: CppGetDetectionScore(i),
|
||||
ClassName: "segment",
|
||||
}
|
||||
|
||||
// Get mask PNG
|
||||
maskSize := CppGetDetectionMaskPNG(i, 0, 0)
|
||||
if maskSize > 0 {
|
||||
maskBuf := make([]byte, maskSize)
|
||||
CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
|
||||
det.Mask = maskBuf
|
||||
}
|
||||
|
||||
detections[i] = det
|
||||
}
|
||||
|
||||
return pb.DetectResponse{
|
||||
Detections: detections,
|
||||
}, nil
|
||||
}
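The defaults applied in `SAM3.Load` above (model file fallback, relative path resolution, thread count) can be summarised in a few lines. This is a rough Python sketch under stated assumptions: the function names are hypothetical, and `os.path.join` only approximates Go's `filepath.Join`, which additionally cleans the resulting path.

```python
import os.path


def resolve_model_path(model_file: str, model: str, model_path: str) -> str:
    """Hypothetical mirror of the path logic in SAM3.Load."""
    # Fall back to the model name when no explicit model file is given.
    name = model_file or model
    # Absolute paths are used as-is; relative ones are joined onto the models directory.
    return name if os.path.isabs(name) else os.path.join(model_path, name)


def effective_threads(threads: int) -> int:
    """Non-positive thread counts fall back to 4, as in SAM3.Load."""
    return threads if threads > 0 else 4
```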
|
||||
51
backend/go/sam3-cpp/gosam3.h
Normal file
@@ -0,0 +1,51 @@
|
||||
#ifndef GOSAM3_H
|
||||
#define GOSAM3_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Load model from file. Returns 0 on success, non-zero on failure.
|
||||
int sam3_cpp_load_model(const char *model_path, int threads);
|
||||
|
||||
// Encode an image from file path. Must be called before segmentation.
|
||||
// Returns 0 on success.
|
||||
int sam3_cpp_encode_image(const char *image_path);
|
||||
|
||||
// Segment with point/box prompts (PVS mode).
|
||||
// points: flat array of [x, y, label] triples (label > 0.5 = positive, otherwise negative)
// boxes: flat array of [x1, y1, x2, y2] quads (only the first box is currently used)
// threshold: currently unused in PVS mode
// Returns number of detections, or -1 on error.
|
||||
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||
float *boxes, int n_box_quads,
|
||||
float threshold);
|
||||
|
||||
// Segment with text prompt (PCS mode, SAM 3 only).
|
||||
// Returns number of detections, or -1 on error.
|
||||
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold);
|
||||
|
||||
// Access detection results (valid after a segment call).
|
||||
int sam3_cpp_get_n_detections(void);
|
||||
|
||||
// Get bounding box for detection i (as x, y, width, height).
|
||||
float sam3_cpp_get_detection_x(int i);
|
||||
float sam3_cpp_get_detection_y(int i);
|
||||
float sam3_cpp_get_detection_w(int i);
|
||||
float sam3_cpp_get_detection_h(int i);
|
||||
|
||||
// Get confidence score for detection i.
|
||||
float sam3_cpp_get_detection_score(int i);
|
||||
|
||||
// Get mask as PNG-encoded bytes.
|
||||
// If buf is NULL, returns the required buffer size.
|
||||
// Otherwise writes up to buf_size bytes and returns bytes written.
|
||||
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size);
|
||||
|
||||
// Free current detection results.
|
||||
void sam3_cpp_free_results(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // GOSAM3_H
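The mask accessor documented in the header above follows a two-call pattern: query the size with a NULL buffer, allocate, then copy. The Python sketch below models that contract with an in-memory stand-in. `make_fake_getter` and `fetch_mask` are hypothetical names, and the stand-in only imitates the documented behaviour; it does not call the real library.

```python
from typing import Callable, Optional

# Hypothetical stand-in for the C signature:
#   int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size)
MaskGetter = Callable[[int, Optional[bytearray], int], int]


def make_fake_getter(masks: list) -> MaskGetter:
    """Build an in-memory stand-in that follows the documented contract."""
    def getter(i: int, buf: Optional[bytearray], buf_size: int) -> int:
        if i < 0 or i >= len(masks):
            return 0                      # out-of-range index: nothing to return
        png = masks[i]
        if buf is None:
            return len(png)               # size query
        n = max(0, min(len(png), buf_size))
        buf[:n] = png[:n]                 # copy at most buf_size bytes
        return n
    return getter


def fetch_mask(get: MaskGetter, i: int) -> bytes:
    """Two-call pattern: query the size, allocate, then copy."""
    size = get(i, None, 0)
    if size <= 0:
        return b""
    buf = bytearray(size)
    written = get(i, buf, size)
    return bytes(buf[:written])
```

The Go side in `Detect` uses the same sequence: a first call with a zero pointer to obtain the size, then a second call with a buffer of exactly that size.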
|
||||
56
backend/go/sam3-cpp/main.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("SAM3_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgosam3-fallback.so"
|
||||
}
|
||||
|
||||
gosamLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "sam3_cpp_load_model"},
|
||||
{&CppEncodeImage, "sam3_cpp_encode_image"},
|
||||
{&CppSegmentPVS, "sam3_cpp_segment_pvs"},
|
||||
{&CppSegmentPCS, "sam3_cpp_segment_pcs"},
|
||||
{&CppGetNDetections, "sam3_cpp_get_n_detections"},
|
||||
{&CppGetDetectionX, "sam3_cpp_get_detection_x"},
|
||||
{&CppGetDetectionY, "sam3_cpp_get_detection_y"},
|
||||
{&CppGetDetectionW, "sam3_cpp_get_detection_w"},
|
||||
{&CppGetDetectionH, "sam3_cpp_get_detection_h"},
|
||||
{&CppGetDetectionScore, "sam3_cpp_get_detection_score"},
|
||||
{&CppGetDetectionMaskPNG, "sam3_cpp_get_detection_mask_png"},
|
||||
{&CppFreeResults, "sam3_cpp_free_results"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, gosamLib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &SAM3{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
59
backend/go/sam3-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/libgosam3-*.so $CURDIR/package/
|
||||
cp -avf $CURDIR/sam3-cpp $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/sam3-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgosam3-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export SAM3_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/sam3-cpp "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/sam3-cpp "$@"
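The variant selection in `run.sh` above is an ordered override: start from the fallback library and let each more capable variant replace it when the CPU flag is present and the file exists. A minimal Python sketch of that logic follows, assuming the flags have already been split into tokens. `pick_library` and its `exists` parameter are hypothetical, and token membership only approximates the script's `grep` on `/proc/cpuinfo`.

```python
import os
from typing import Callable, Iterable


def pick_library(cpu_flags: Iterable[str], curdir: str,
                 exists: Callable[[str], bool] = os.path.exists) -> str:
    """Hypothetical mirror of run.sh's selection among the libgosam3 variants."""
    flags = set(cpu_flags)
    library = os.path.join(curdir, "libgosam3-fallback.so")
    # Later entries override earlier ones, so the most capable available variant wins.
    for flag, variant in (("avx", "avx"), ("avx2", "avx2"), ("avx512f", "avx512")):
        candidate = os.path.join(curdir, f"libgosam3-{variant}.so")
        if flag in flags and exists(candidate):
            library = candidate
    return library
```

Because the file check is part of each step, a package that ships only some variants still resolves to the best one that is actually present.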
|
||||
50
backend/go/sam3-cpp/test.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
echo "Running sam3-cpp backend tests..."
|
||||
|
||||
# The test requires a SAM model in GGML format.
|
||||
# Uses EdgeTAM Q4_0 (~15MB) for fast CI testing.
|
||||
SAM3_MODEL_DIR="${SAM3_MODEL_DIR:-$CURDIR/test-models}"
|
||||
SAM3_MODEL_FILE="${SAM3_MODEL_FILE:-edgetam_q4_0.ggml}"
|
||||
SAM3_MODEL_URL="${SAM3_MODEL_URL:-https://huggingface.co/PABannier/sam3.cpp/resolve/main/edgetam_q4_0.ggml}"
|
||||
|
||||
# Download model if not present
|
||||
if [ ! -f "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" ]; then
|
||||
echo "Downloading EdgeTAM Q4_0 model for testing..."
|
||||
mkdir -p "$SAM3_MODEL_DIR"
|
||||
curl -L -o "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" "$SAM3_MODEL_URL" --progress-bar
|
||||
echo "Model downloaded."
|
||||
fi
|
||||
|
||||
# Create a test image (64x64 solid red PNG).
# This is a minimal valid PNG for testing the pipeline.
TEST_IMAGE_DIR="$CURDIR/test-data"
mkdir -p "$TEST_IMAGE_DIR"

# Generate the test image with Python if available; otherwise skip image generation.
if command -v python3 &> /dev/null; then
    python3 -c "
import struct, zlib

def create_png(width, height, r, g, b):
    # Raw scanlines: one filter byte (0 = none) followed by RGB pixels
    raw = b''
    for y in range(height):
        raw += b'\x00'  # filter byte
        for x in range(width):
            raw += bytes([r, g, b])

    def chunk(ctype, data):
        # PNG chunk: length, type + data, CRC32 over type + data
        c = ctype + data
        return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)

    # IHDR: 8-bit depth, color type 2 (RGB), default compression/filter/interlace
    ihdr = struct.pack('>IIBBBBB', width, height, 8, 2, 0, 0, 0)
    return b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', zlib.compress(raw)) + chunk(b'IEND', b'')

with open('$TEST_IMAGE_DIR/test.png', 'wb') as f:
    f.write(create_png(64, 64, 255, 0, 0))
"
    echo "Test image created."
else
    echo "python3 not found; skipping test image generation."
fi
|
||||
|
||||
echo "sam3-cpp test setup complete."
|
||||
echo "Model: $SAM3_MODEL_DIR/$SAM3_MODEL_FILE"
|
||||
echo "Note: Full integration tests run via the LocalAI test-extra target."
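The test script above writes a hand-built PNG, so a quick structural check can confirm the file is well formed before it is fed to the backend. The sketch below rebuilds the same image in memory and reads the dimensions back from the IHDR chunk. `png_dimensions` is a hypothetical helper added for illustration and is not part of the test script.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def create_png(width: int, height: int, r: int, g: int, b: int) -> bytes:
    """Same construction as the test script: 8-bit RGB, unfiltered scanlines."""
    raw = b"".join(b"\x00" + bytes([r, g, b]) * width for _ in range(height))

    def chunk(ctype: bytes, data: bytes) -> bytes:
        body = ctype + data
        return struct.pack(">I", len(data)) + body + struct.pack(">I", zlib.crc32(body) & 0xFFFFFFFF)

    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    return PNG_SIGNATURE + chunk(b"IHDR", ihdr) + chunk(b"IDAT", zlib.compress(raw)) + chunk(b"IEND", b"")


def png_dimensions(data: bytes) -> tuple:
    """Read (width, height) from the IHDR chunk, checking signature and CRC."""
    if data[:8] != PNG_SIGNATURE:
        raise ValueError("not a PNG file")
    length, ctype = struct.unpack(">I4s", data[8:16])
    if ctype != b"IHDR" or length != 13:
        raise ValueError("missing IHDR chunk")
    ihdr = data[16:29]
    (crc,) = struct.unpack(">I", data[29:33])
    if crc != zlib.crc32(b"IHDR" + ihdr) & 0xFFFFFFFF:
        raise ValueError("IHDR CRC mismatch")
    width, height = struct.unpack(">II", ihdr[:8])
    return width, height
```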
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
-STABLEDIFFUSION_GGML_VERSION?=8afbeb6ba9702c15d41a38296f2ab1fe5c829fa0
+STABLEDIFFUSION_GGML_VERSION?=e8323cabb0e4511ba18a50b1cb34cf1f87fc71ef
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -125,6 +125,31 @@
|
||||
nvidia-cuda-13: "cuda13-rfdetr"
|
||||
nvidia-cuda-12: "cuda12-rfdetr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
|
||||
- &sam3cpp
|
||||
name: "sam3-cpp"
|
||||
alias: "sam3-cpp"
|
||||
license: mit
|
||||
description: |
|
||||
Segment Anything Model (SAM 3/2/EdgeTAM) in C/C++ using GGML.
|
||||
Supports text-prompted and point/box-prompted image segmentation.
|
||||
urls:
|
||||
- https://github.com/PABannier/sam3.cpp
|
||||
tags:
|
||||
- image-segmentation
|
||||
- object-detection
|
||||
- sam3
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-sam3-cpp"
|
||||
nvidia: "cuda12-sam3-cpp"
|
||||
nvidia-cuda-12: "cuda12-sam3-cpp"
|
||||
nvidia-cuda-13: "cuda13-sam3-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||
intel: "intel-sycl-f32-sam3-cpp"
|
||||
vulkan: "vulkan-sam3-cpp"
|
||||
- &vllm
|
||||
name: "vllm"
|
||||
license: apache-2.0
|
||||
@@ -400,12 +425,15 @@
|
||||
license: MIT
|
||||
name: "faster-whisper"
|
||||
capabilities:
|
||||
default: "cpu-faster-whisper"
|
||||
nvidia: "cuda12-faster-whisper"
|
||||
intel: "intel-faster-whisper"
|
||||
amd: "rocm-faster-whisper"
|
||||
metal: "metal-faster-whisper"
|
||||
nvidia-cuda-13: "cuda13-faster-whisper"
|
||||
nvidia-cuda-12: "cuda12-faster-whisper"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-faster-whisper"
|
||||
- &moonshine
|
||||
description: |
|
||||
Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
|
||||
@@ -438,6 +466,7 @@
|
||||
- whisperx
|
||||
license: BSD-4-Clause
|
||||
name: "whisperx"
|
||||
alias: "whisperx"
|
||||
capabilities:
|
||||
nvidia: "cuda12-whisperx"
|
||||
amd: "rocm-whisperx"
|
||||
@@ -445,6 +474,8 @@
|
||||
default: "cpu-whisperx"
|
||||
nvidia-cuda-13: "cuda13-whisperx"
|
||||
nvidia-cuda-12: "cuda12-whisperx"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-whisperx"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisperx"
|
||||
- &kokoro
|
||||
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||
description: |
|
||||
@@ -468,6 +499,26 @@
|
||||
nvidia-cuda-13: "cuda13-kokoro"
|
||||
nvidia-cuda-12: "cuda12-kokoro"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
|
||||
- &kokoros
|
||||
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||
description: |
|
||||
Kokoros is a pure Rust TTS backend using the Kokoro ONNX model (82M parameters).
|
||||
It provides fast, high-quality text-to-speech with streaming support, built on
|
||||
ONNX Runtime for efficient CPU inference. Supports English, Japanese, Mandarin
|
||||
Chinese, and German.
|
||||
urls:
|
||||
- https://huggingface.co/hexgrad/Kokoro-82M
|
||||
- https://github.com/lucasjinreal/Kokoros
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
- Rust
|
||||
- ONNX
|
||||
license: apache-2.0
|
||||
alias: "kokoros"
|
||||
name: "kokoros"
|
||||
capabilities:
|
||||
default: "cpu-kokoros"
|
||||
- &coqui
|
||||
urls:
|
||||
- https://github.com/idiap/coqui-ai-TTS
|
||||
@@ -1602,6 +1653,89 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-rfdetr
|
||||
## sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "sam3-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-sam3-cpp-development"
|
||||
nvidia: "cuda12-sam3-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-sam3-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-sam3-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||
intel: "intel-sycl-f32-sam3-cpp-development"
|
||||
vulkan: "vulkan-sam3-cpp-development"
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cpu-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cpu-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda12-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda12-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "nvidia-l4t-arm64-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "intel-sycl-f32-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "intel-sycl-f32-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "vulkan-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "vulkan-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-sam3-cpp
|
||||
## Rerankers
|
||||
- !!merge <<: *rerankers
|
||||
name: "rerankers-development"
|
||||
@@ -2042,15 +2176,32 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-kokoro
|
||||
## kokoros (Rust)
|
||||
- !!merge <<: *kokoros
|
||||
name: "kokoros-development"
|
||||
capabilities:
|
||||
default: "cpu-kokoros-development"
|
||||
- !!merge <<: *kokoros
|
||||
name: "cpu-kokoros"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-kokoros"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-kokoros
|
||||
- !!merge <<: *kokoros
|
||||
name: "cpu-kokoros-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-kokoros"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-kokoros
|
||||
## faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "faster-whisper-development"
|
||||
capabilities:
|
||||
default: "cpu-faster-whisper-development"
|
||||
nvidia: "cuda12-faster-whisper-development"
|
||||
intel: "intel-faster-whisper-development"
|
||||
amd: "rocm-faster-whisper-development"
|
||||
metal: "metal-faster-whisper-development"
|
||||
nvidia-cuda-13: "cuda13-faster-whisper-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper-development"
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cuda12-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
|
||||
@@ -2091,6 +2242,36 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cuda12-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "rocm-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cpu-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cpu-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "nvidia-l4t-arm64-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "nvidia-l4t-arm64-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-faster-whisper
|
||||
## moonshine
|
||||
- !!merge <<: *moonshine
|
||||
name: "moonshine-development"
|
||||
@@ -2149,6 +2330,7 @@
|
||||
default: "cpu-whisperx-development"
|
||||
nvidia-cuda-13: "cuda13-whisperx-development"
|
||||
nvidia-cuda-12: "cuda12-whisperx-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-whisperx-development"
|
||||
- !!merge <<: *whisperx
|
||||
name: "cpu-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx"
|
||||
@@ -2199,6 +2381,16 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "nvidia-l4t-arm64-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "nvidia-l4t-arm64-whisperx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-whisperx
|
||||
## coqui
|
||||
|
||||
- !!merge <<: *coqui
|
||||
|
||||
@@ -16,4 +16,14 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
torch
faster-whisper
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu130
torch
faster-whisper
@@ -147,7 +147,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.language and request.language.strip():
|
||||
language = request.language.strip()
|
||||
|
||||
results = self.model.transcribe(audio=audio_path, language=language)
|
||||
context = ""
|
||||
if request.prompt and request.prompt.strip():
|
||||
context = request.prompt.strip()
|
||||
|
||||
results = self.model.transcribe(audio=audio_path, language=language, context=context)
|
||||
|
||||
if not results:
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
|
||||
@@ -8,8 +8,21 @@ else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
# --index-strategy is a uv-only flag; skip it when using pip
|
||||
if [ "x${USE_PIP}" != "xtrue" ]; then
|
||||
if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match"
|
||||
fi
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
3
backend/python/whisperx/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
torch
whisperx @ git+https://github.com/m-bain/whisperX.git
3
backend/python/whisperx/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu130
torch
whisperx @ git+https://github.com/m-bain/whisperX.git
3
backend/rust/kokoros/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
/target/
/proto/
/package/
3074
backend/rust/kokoros/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
26
backend/rust/kokoros/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
[package]
name = "kokoros-grpc"
version = "0.1.0"
edition = "2021"

[[bin]]
name = "kokoros-grpc"
path = "src/main.rs"

[dependencies]
kokoros = { path = "sources/Kokoros/kokoros" }

tonic = "0.13"
prost = "0.13"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
clap = { version = "4", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[build-dependencies]
tonic-build = "0.13"

[features]
default = ["cpu"]
cpu = ["kokoros/cpu"]
25
backend/rust/kokoros/Makefile
Normal file
@@ -0,0 +1,25 @@
CURRENT_DIR=$(abspath ./)

.PHONY: kokoros-grpc
kokoros-grpc:
	mkdir -p $(CURRENT_DIR)/proto
	cp $(CURRENT_DIR)/../../backend.proto $(CURRENT_DIR)/proto/backend.proto
	cd $(CURRENT_DIR) && \
		BACKEND_PROTO_PATH=$(CURRENT_DIR)/proto/backend.proto \
		cargo build --release

.PHONY: package
package:
	bash package.sh

.PHONY: test
test: kokoros-grpc
	cd $(CURRENT_DIR) && cargo test

.PHONY: build
build: kokoros-grpc package

.PHONY: clean
clean:
	cargo clean
	rm -rf package proto
15
backend/rust/kokoros/build.rs
Normal file
@@ -0,0 +1,15 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let proto_path = std::env::var("BACKEND_PROTO_PATH")
        .unwrap_or_else(|_| "proto/backend.proto".to_string());

    let proto_dir = std::path::Path::new(&proto_path)
        .parent()
        .unwrap_or(std::path::Path::new("."));

    tonic_build::configure()
        .build_server(true)
        .build_client(false)
        .compile_protos(&[&proto_path], &[proto_dir])?;

    Ok(())
}
42
backend/rust/kokoros/package.sh
Normal file
@@ -0,0 +1,42 @@
#!/bin/bash
set -e

CURDIR=$(dirname "$(realpath $0)")
mkdir -p $CURDIR/package/lib

# Copy the binary and run script
cp -avf $CURDIR/target/release/kokoros-grpc $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
chmod +x $CURDIR/package/run.sh

# Copy espeak-ng data
if [ -d "/usr/share/espeak-ng-data" ]; then
    cp -rf /usr/share/espeak-ng-data $CURDIR/package/
elif [ -d "/usr/lib/x86_64-linux-gnu/espeak-ng-data" ]; then
    cp -rf /usr/lib/x86_64-linux-gnu/espeak-ng-data $CURDIR/package/
fi

# Bundle all dynamic library dependencies
echo "Bundling dynamic library dependencies..."
ldd $CURDIR/target/release/kokoros-grpc | grep "=>" | awk '{print $3}' | while read lib; do
    if [ -n "$lib" ] && [ -f "$lib" ]; then
        cp -avfL "$lib" $CURDIR/package/lib/
    fi
done

# Copy CA certificates for HTTPS (needed for model auto-download)
if [ -d "/etc/ssl/certs" ]; then
    mkdir -p $CURDIR/package/etc/ssl
    cp -rf /etc/ssl/certs $CURDIR/package/etc/ssl/
fi

# Copy the dynamic linker
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
    cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
fi

echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
23
backend/rust/kokoros/run.sh
Executable file
@@ -0,0 +1,23 @@
#!/bin/bash
set -ex

CURDIR=$(dirname "$(realpath $0)")

export LD_LIBRARY_PATH=$CURDIR/lib:${LD_LIBRARY_PATH:-}

# SSL certificates for model auto-download
if [ -d "$CURDIR/etc/ssl/certs" ]; then
    export SSL_CERT_DIR=$CURDIR/etc/ssl/certs
fi

# espeak-ng data directory
if [ -d "$CURDIR/espeak-ng-data" ]; then
    export ESPEAK_NG_DATA=$CURDIR/espeak-ng-data
fi

# Use bundled ld.so if present (portability)
if [ -f $CURDIR/lib/ld.so ]; then
    exec $CURDIR/lib/ld.so $CURDIR/kokoros-grpc "$@"
fi

exec $CURDIR/kokoros-grpc "$@"
1
backend/rust/kokoros/sources/Kokoros
Submodule
Submodule backend/rust/kokoros/sources/Kokoros added at 7089168f0c
26
backend/rust/kokoros/src/auth.rs
Normal file
@@ -0,0 +1,26 @@
use tonic::{Request, Status};

/// Returns an interceptor function if LOCALAI_GRPC_AUTH_TOKEN is set.
pub fn make_auth_interceptor(
) -> Option<impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone> {
    let token = std::env::var("LOCALAI_GRPC_AUTH_TOKEN").ok()?;
    if token.is_empty() {
        return None;
    }
    let expected = format!("Bearer {}", token);
    Some(
        move |req: Request<()>| -> Result<Request<()>, Status> {
            let meta = req.metadata();
            match meta.get("authorization") {
                Some(val) => {
                    if val.as_bytes() == expected.as_bytes() {
                        Ok(req)
                    } else {
                        Err(Status::unauthenticated("invalid token"))
                    }
                }
                None => Err(Status::unauthenticated("missing authorization")),
            }
        },
    )
}
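Review note on `auth.rs`: the interceptor accepts a request only when the `authorization` metadata value equals `Bearer <token>` byte for byte, and it is disabled when the token is unset or empty. The sketch below isolates that decision using only the standard library, so it can be exercised without tonic. The names `AuthError`, `expected_header`, `constant_time_eq` and `check_authorization` are hypothetical and not part of the patch. The constant-time comparison is a possible hardening step, not something the patch does (it uses a plain `==`).

```rust
/// Hypothetical error type standing in for tonic's `Status::unauthenticated`.
#[derive(Debug, PartialEq)]
pub enum AuthError {
    Missing,
    Invalid,
}

/// Builds the header value the server would expect, or `None` when the token
/// is empty (authentication disabled), mirroring the early return in the patch.
pub fn expected_header(token: &str) -> Option<String> {
    if token.is_empty() {
        None
    } else {
        Some(format!("Bearer {}", token))
    }
}

/// Compares two byte slices without returning early on the first mismatch.
/// The length check itself is not constant-time.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }
    diff == 0
}

/// Decides whether a raw `authorization` value matches the expected one.
pub fn check_authorization(header: Option<&[u8]>, expected: &str) -> Result<(), AuthError> {
    match header {
        None => Err(AuthError::Missing),
        Some(value) if constant_time_eq(value, expected.as_bytes()) => Ok(()),
        Some(_) => Err(AuthError::Invalid),
    }
}
```

In a tonic interceptor, the `header` argument would presumably come from `req.metadata().get("authorization").map(|v| v.as_bytes())`, with the two error variants mapped to `Status::unauthenticated`.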
53
backend/rust/kokoros/src/main.rs
Normal file
@@ -0,0 +1,53 @@
use clap::Parser;
use tonic::transport::Server;

mod auth;
mod service;

pub mod backend {
    tonic::include_proto!("backend");
}

#[derive(Parser, Debug)]
#[command(name = "kokoros-grpc")]
struct Cli {
    /// gRPC listen address (host:port)
    #[arg(long, default_value = "localhost:50051")]
    addr: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .without_time()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    let cli = Cli::parse();
    let addr = cli.addr.parse()?;

    tracing::info!("Starting kokoros gRPC server on {}", addr);

    let mut builder = Server::builder();

    if let Some(interceptor) = auth::make_auth_interceptor() {
        tracing::info!("Bearer token authentication enabled");
        let svc = backend::backend_server::BackendServer::with_interceptor(
            service::KokorosService::default(),
            interceptor,
        );
        builder.add_service(svc).serve(addr).await?;
    } else {
        let svc = backend::backend_server::BackendServer::new(service::KokorosService::default())
            .max_decoding_message_size(50 * 1024 * 1024)
            .max_encoding_message_size(50 * 1024 * 1024);
        builder.add_service(svc).serve(addr).await?;
    }

    Ok(())
}
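Review note on `main.rs`: the `--addr` value is turned into a listen address with `str::parse`. The standard library's `SocketAddr` parser accepts only IP literals with a port, so a host name such as the default `localhost:50051` is rejected unless the caller passes something like `127.0.0.1:50051`. One possible way to accept both forms is sketched below, assuming a fallback to the system resolver is acceptable. The helper name `resolve_listen_addr` is hypothetical and not part of the patch.

```rust
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};

/// Hypothetical helper: turn a `host:port` string into a socket address.
pub fn resolve_listen_addr(addr: &str) -> io::Result<SocketAddr> {
    // Fast path: IP literals such as "127.0.0.1:50051" or "[::1]:50051".
    if let Ok(parsed) = addr.parse::<SocketAddr>() {
        return Ok(parsed);
    }
    // Fallback: let the system resolver handle host names like "localhost",
    // and take the first address it returns.
    addr.to_socket_addrs()?.next().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::NotFound,
            format!("no address found for {}", addr),
        )
    })
}
```

With a helper of this shape, `let addr = resolve_listen_addr(&cli.addr)?;` could stand in for the `parse()` call. Which address family wins for `localhost` depends on the host's resolver configuration.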
652
backend/rust/kokoros/src/service.rs
Normal file
@@ -0,0 +1,652 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use kokoros::tts::koko::TTSKoko;
|
||||
|
||||
use crate::backend;
|
||||
use crate::backend::backend_server::Backend;
|
||||
|
||||
/// Write f32 samples as a standard 44-byte PCM 16-bit WAV file.
|
||||
/// LocalAI's audio pipeline assumes this exact header layout.
|
||||
fn write_pcm16_wav(
|
||||
path: &str,
|
||||
samples: &[f32],
|
||||
sample_rate: u32,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
|
||||
let num_samples = samples.len() as u32;
|
||||
let data_size = num_samples * 2; // 16-bit = 2 bytes per sample
|
||||
let file_size = 36 + data_size;
|
||||
|
||||
let mut f = File::create(path)?;
|
||||
|
||||
// RIFF header
|
||||
f.write_all(b"RIFF")?;
|
||||
f.write_all(&file_size.to_le_bytes())?;
|
||||
f.write_all(b"WAVE")?;
|
||||
|
||||
// fmt chunk — standard 16-byte PCM format
|
||||
f.write_all(b"fmt ")?;
|
||||
f.write_all(&16u32.to_le_bytes())?; // chunk size
|
||||
f.write_all(&1u16.to_le_bytes())?; // audio format = PCM
|
||||
f.write_all(&1u16.to_le_bytes())?; // channels = mono
|
||||
f.write_all(&sample_rate.to_le_bytes())?;
|
||||
f.write_all(&(sample_rate * 2).to_le_bytes())?; // byte rate
|
||||
f.write_all(&2u16.to_le_bytes())?; // block align
|
||||
f.write_all(&16u16.to_le_bytes())?; // bits per sample
|
||||
|
||||
// data chunk
|
||||
f.write_all(b"data")?;
|
||||
f.write_all(&data_size.to_le_bytes())?;
|
||||
|
||||
for &s in samples {
|
||||
let clamped = s.clamp(-1.0, 1.0);
|
||||
let pcm = (clamped * 32767.0) as i16;
|
||||
f.write_all(&pcm.to_le_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct KokorosService {
|
||||
tts: Arc<TokioMutex<Option<TTSKoko>>>,
|
||||
language: Arc<Mutex<String>>,
|
||||
speed: Arc<Mutex<f32>>,
|
||||
}
|
||||
|
||||
impl Default for KokorosService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tts: Arc::new(TokioMutex::new(None)),
|
||||
language: Arc::new(Mutex::new("en-us".to_string())),
|
||||
speed: Arc::new(Mutex::new(1.0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl Backend for KokorosService {
|
||||
async fn health(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::Reply>, Status> {
|
||||
Ok(Response::new(backend::Reply {
|
||||
message: b"OK".to_vec(),
|
||||
..Default::default()
|
||||
}))
|
||||
}
|
||||
|
||||
async fn load_model(
|
||||
&self,
|
||||
req: Request<backend::ModelOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
let opts = req.into_inner();
|
||||
|
||||
// Model path: join ModelPath + Model, or just Model
|
||||
let model_path = if !opts.model_path.is_empty() && !opts.model.is_empty() {
|
||||
format!("{}/{}", opts.model_path, opts.model)
|
||||
} else if !opts.model.is_empty() {
|
||||
opts.model.clone()
|
||||
} else {
|
||||
"checkpoints/kokoro-v1.0.onnx".to_string()
|
||||
};
|
||||
|
||||
// Voices data path from AudioPath, or derive from model dir
|
||||
let voices_path = if !opts.audio_path.is_empty() {
|
||||
opts.audio_path.clone()
|
||||
} else {
|
||||
let model_dir = std::path::Path::new(&model_path)
|
||||
.parent()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| ".".to_string());
|
||||
format!("{}/voices-v1.0.bin", model_dir)
|
||||
};
|
||||
|
||||
// Parse options (key:value pairs)
|
||||
for opt in &opts.options {
|
||||
if let Some((key, value)) = opt.split_once(':') {
|
||||
match key {
|
||||
"lang_code" => *self.language.lock().unwrap() = value.to_string(),
|
||||
"speed" => {
|
||||
if let Ok(s) = value.parse::<f32>() {
|
||||
*self.speed.lock().unwrap() = s;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Loading Kokoros model from: {}", model_path);
|
||||
tracing::info!("Loading voices from: {}", voices_path);
|
||||
tracing::info!("Language: {}", self.language.lock().unwrap());
|
||||
|
||||
let tts = TTSKoko::new(&model_path, &voices_path).await;
|
||||
*self.tts.lock().await = Some(tts);
|
||||
|
||||
tracing::info!("Kokoros TTS model loaded successfully");
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: "Kokoros TTS model loaded".into(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn tts(
|
||||
&self,
|
||||
req: Request<backend::TtsRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
let req = req.into_inner();
|
||||
let tts_guard = self.tts.lock().await;
|
||||
let tts = tts_guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?;
|
||||
|
||||
let voice = if req.voice.is_empty() {
|
||||
"af_heart"
|
||||
} else {
|
||||
&req.voice
|
||||
};
|
||||
let lang = req
|
||||
.language
|
||||
.filter(|l| !l.is_empty())
|
||||
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||
let speed = *self.speed.lock().unwrap();
|
||||
|
||||
tracing::info!(
|
||||
text = req.text,
|
||||
voice = voice,
|
||||
lang = lang.as_str(),
|
||||
dst = req.dst,
|
||||
"TTS request received"
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
match tts.tts_raw_audio(&req.text, &lang, voice, speed, None, None, None, None) {
|
||||
Ok(samples) => {
|
||||
let duration_secs = samples.len() as f64 / 24000.0;
|
||||
tracing::info!(
|
||||
num_samples = samples.len(),
|
||||
audio_duration = format!("{:.2}s", duration_secs),
|
||||
inference_time = format!("{:.2}s", start.elapsed().as_secs_f64()),
|
||||
dst = req.dst,
|
||||
"TTS inference complete"
|
||||
);
|
||||
if let Err(e) = write_pcm16_wav(&req.dst, &samples, 24000) {
|
||||
tracing::error!("Failed to write WAV to {}: {}", req.dst, e);
|
||||
return Ok(Response::new(backend::Result {
|
||||
success: false,
|
||||
message: format!("Failed to write WAV: {}", e),
|
||||
}));
|
||||
}
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: String::new(),
|
||||
}))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("TTS error: {}", e);
|
||||
Ok(Response::new(backend::Result {
|
||||
success: false,
|
||||
message: format!("TTS error: {}", e),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TTSStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||
|
||||
async fn tts_stream(
|
||||
&self,
|
||||
req: Request<backend::TtsRequest>,
|
||||
) -> Result<Response<Self::TTSStreamStream>, Status> {
|
||||
let req = req.into_inner();
|
||||
let tts_guard = self.tts.lock().await;
|
||||
let tts = tts_guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| Status::failed_precondition("Model not loaded"))?
|
||||
.clone();
|
||||
|
||||
let voice = if req.voice.is_empty() {
|
||||
"af_heart".to_string()
|
||||
} else {
|
||||
req.voice
|
||||
};
|
||||
let lang = req
|
||||
.language
|
||||
.filter(|l| !l.is_empty())
|
||||
.unwrap_or_else(|| self.language.lock().unwrap().clone());
|
||||
let speed = *self.speed.lock().unwrap();
|
||||
let text = req.text;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(32);
|
||||
|
||||
// Send sample rate info as first message
|
||||
let tx_clone = tx.clone();
|
||||
let _ = tx_clone
|
||||
.send(Ok(backend::Reply {
|
||||
message: br#"{"sample_rate":24000}"#.to_vec(),
|
||||
..Default::default()
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = tts.tts_raw_audio_streaming(
|
||||
&text,
|
||||
&lang,
|
||||
&voice,
|
||||
speed,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
|audio_chunk: Vec<f32>| -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Convert f32 PCM to 16-bit PCM bytes (what LocalAI expects for streaming)
|
||||
let bytes: Vec<u8> = audio_chunk
|
||||
.iter()
|
||||
.flat_map(|&s| {
|
||||
let clamped = s.clamp(-1.0, 1.0);
|
||||
let i16_val = (clamped * 32767.0) as i16;
|
||||
i16_val.to_le_bytes()
|
||||
})
|
||||
.collect();
|
||||
tx.blocking_send(Ok(backend::Reply {
|
||||
audio: bytes,
|
||||
..Default::default()
|
||||
}))
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
|
||||
},
|
||||
);
|
||||
if let Err(e) = result {
|
||||
tracing::error!("TTSStream error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn status(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::StatusResponse>, Status> {
|
||||
let tts = self.tts.lock().await;
|
||||
let state = if tts.is_some() {
|
||||
backend::status_response::State::Ready as i32
|
||||
} else {
|
||||
backend::status_response::State::Uninitialized as i32
|
||||
};
|
||||
Ok(Response::new(backend::StatusResponse {
|
||||
state,
|
||||
memory: None,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn free(
|
||||
&self,
|
||||
_req: Request<backend::HealthMessage>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
*self.tts.lock().await = None;
|
||||
Ok(Response::new(backend::Result {
|
||||
success: true,
|
||||
message: "Model freed".into(),
|
||||
}))
|
||||
}
|
||||
|
||||
// --- Unimplemented RPCs ---
|
||||
|
||||
async fn predict(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::Reply>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type PredictStreamStream = ReceiverStream<Result<backend::Reply, Status>>;
|
||||
|
||||
async fn predict_stream(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<Self::PredictStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn embedding(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::EmbeddingResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn generate_image(
|
||||
&self,
|
||||
_: Request<backend::GenerateImageRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn generate_video(
|
||||
&self,
|
||||
_: Request<backend::GenerateVideoRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_transcription(
|
||||
&self,
|
||||
_: Request<backend::TranscriptRequest>,
|
||||
) -> Result<Response<backend::TranscriptResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn tokenize_string(
|
||||
&self,
|
||||
_: Request<backend::PredictOptions>,
|
||||
) -> Result<Response<backend::TokenizationResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn detect(
|
||||
&self,
|
||||
_: Request<backend::DetectOptions>,
|
||||
) -> Result<Response<backend::DetectResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_set(
|
||||
&self,
|
||||
_: Request<backend::StoresSetOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_delete(
|
||||
&self,
|
||||
_: Request<backend::StoresDeleteOptions>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_get(
|
||||
&self,
|
||||
_: Request<backend::StoresGetOptions>,
|
||||
) -> Result<Response<backend::StoresGetResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_find(
|
||||
&self,
|
||||
_: Request<backend::StoresFindOptions>,
|
||||
) -> Result<Response<backend::StoresFindResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn rerank(
|
||||
&self,
|
||||
_: Request<backend::RerankRequest>,
|
||||
) -> Result<Response<backend::RerankResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn get_metrics(
|
||||
&self,
|
||||
_: Request<backend::MetricsRequest>,
|
||||
) -> Result<Response<backend::MetricsResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn vad(
|
||||
&self,
|
||||
_: Request<backend::VadRequest>,
|
||||
) -> Result<Response<backend::VadResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_encode(
|
||||
&self,
|
||||
_: Request<backend::AudioEncodeRequest>,
|
||||
) -> Result<Response<backend::AudioEncodeResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_decode(
|
||||
&self,
|
||||
_: Request<backend::AudioDecodeRequest>,
|
||||
) -> Result<Response<backend::AudioDecodeResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn model_metadata(
|
||||
&self,
|
||||
_: Request<backend::ModelOptions>,
|
||||
) -> Result<Response<backend::ModelMetadataResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn start_fine_tune(
|
||||
&self,
|
||||
_: Request<backend::FineTuneRequest>,
|
||||
) -> Result<Response<backend::FineTuneJobResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type FineTuneProgressStream = ReceiverStream<Result<backend::FineTuneProgressUpdate, Status>>;
|
||||
|
||||
async fn fine_tune_progress(
|
||||
&self,
|
||||
_: Request<backend::FineTuneProgressRequest>,
|
||||
) -> Result<Response<Self::FineTuneProgressStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stop_fine_tune(
|
||||
&self,
|
||||
_: Request<backend::FineTuneStopRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn list_checkpoints(
|
||||
&self,
|
||||
_: Request<backend::ListCheckpointsRequest>,
|
||||
) -> Result<Response<backend::ListCheckpointsResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn export_model(
|
||||
&self,
|
||||
_: Request<backend::ExportModelRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn start_quantization(
|
||||
&self,
|
||||
_: Request<backend::QuantizationRequest>,
|
||||
) -> Result<Response<backend::QuantizationJobResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type QuantizationProgressStream =
|
||||
ReceiverStream<Result<backend::QuantizationProgressUpdate, Status>>;
|
||||
|
||||
async fn quantization_progress(
|
||||
&self,
|
||||
_: Request<backend::QuantizationProgressRequest>,
|
||||
) -> Result<Response<Self::QuantizationProgressStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stop_quantization(
|
||||
&self,
|
||||
_: Request<backend::QuantizationStopRequest>,
|
||||
) -> Result<Response<backend::Result>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn wav_header_is_standard_pcm16() {
|
||||
let samples = vec![0.0f32, 0.5, -0.5, 1.0, -1.0];
|
||||
let path = std::env::temp_dir().join("kokoros_test.wav");
|
||||
let path_str = path.to_str().unwrap();
|
||||
|
||||
write_pcm16_wav(path_str, &samples, 24000).unwrap();
|
||||
|
||||
let data = std::fs::read(&path).unwrap();
|
||||
std::fs::remove_file(&path).unwrap();
|
||||
|
||||
// Must be exactly 44-byte header + data
|
||||
assert_eq!(data.len(), 44 + samples.len() * 2);
|
||||
|
||||
// RIFF header
|
||||
assert_eq!(&data[0..4], b"RIFF");
|
||||
        assert_eq!(&data[8..12], b"WAVE");

        // fmt chunk: 16 bytes, format=1 (PCM), channels=1, 16-bit
        assert_eq!(&data[12..16], b"fmt ");
        assert_eq!(u32::from_le_bytes(data[16..20].try_into().unwrap()), 16); // chunk size
        assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM format
        assert_eq!(u16::from_le_bytes(data[22..24].try_into().unwrap()), 1); // mono
        assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
        assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // bits per sample

        // data chunk
        assert_eq!(&data[36..40], b"data");
        assert_eq!(
            u32::from_le_bytes(data[40..44].try_into().unwrap()),
            (samples.len() * 2) as u32
        );

        // Verify sample values: 0.5 -> 16383, -0.5 -> -16383, 1.0 -> 32767, -1.0 -> -32767
        let s1 = i16::from_le_bytes(data[46..48].try_into().unwrap());
        assert_eq!(s1, 16383); // 0.5 * 32767
        let s3 = i16::from_le_bytes(data[50..52].try_into().unwrap());
        assert_eq!(s3, 32767); // 1.0 clamped
        let s4 = i16::from_le_bytes(data[52..54].try_into().unwrap());
        assert_eq!(s4, -32767); // -1.0 clamped
    }

    /// Integration test: runs actual TTS inference and validates the output audio.
    /// Skipped unless KOKOROS_MODEL_PATH is set to a directory containing
    /// kokoro-v1.0.onnx and voices-v1.0.bin.
    #[tokio::test]
    async fn tts_produces_valid_speech() {
        let model_dir = match std::env::var("KOKOROS_MODEL_PATH") {
            Ok(p) => p,
            Err(_) => {
                eprintln!("KOKOROS_MODEL_PATH not set, skipping integration test");
                return;
            }
        };

        let model_path = format!("{}/kokoro-v1.0.onnx", model_dir);
        let voices_path = format!("{}/voices-v1.0.bin", model_dir);

        if !std::path::Path::new(&model_path).exists() {
            eprintln!("Model file not found at {}, skipping", model_path);
            return;
        }

        let tts = TTSKoko::new(&model_path, &voices_path).await;

        let input_text = "Hello world, this is a test of speech synthesis.";
        let out_path = std::env::temp_dir().join("kokoros_integration_test.wav");
        let out_str = out_path.to_str().unwrap();

        let samples = tts
            .tts_raw_audio(input_text, "en-us", "af_heart", 1.0, None, None, None, None)
            .expect("tts_raw_audio failed");

        write_pcm16_wav(out_str, &samples, 24000).unwrap();

        let data = std::fs::read(&out_path).unwrap();
        std::fs::remove_file(&out_path).unwrap();

        // --- WAV header sanity ---
        assert_eq!(&data[0..4], b"RIFF");
        assert_eq!(&data[8..12], b"WAVE");
        assert_eq!(u16::from_le_bytes(data[20..22].try_into().unwrap()), 1); // PCM
        assert_eq!(u32::from_le_bytes(data[24..28].try_into().unwrap()), 24000); // sample rate
        assert_eq!(u16::from_le_bytes(data[34..36].try_into().unwrap()), 16); // 16-bit

        let num_samples = samples.len();
        let duration_secs = num_samples as f64 / 24000.0;

        // --- Duration check ---
        // ~10 words should produce roughly 2-8 seconds of speech
        assert!(
            duration_secs > 1.0,
            "Audio too short: {:.2}s for {} words",
            duration_secs,
            input_text.split_whitespace().count()
        );
        assert!(
            duration_secs < 15.0,
            "Audio too long: {:.2}s for {} words",
            duration_secs,
            input_text.split_whitespace().count()
        );

        // --- Energy check: not silence ---
        let rms = (samples.iter().map(|s| s * s).sum::<f32>() / num_samples as f32).sqrt();
        assert!(
            rms > 0.01,
            "Audio is near-silence: RMS = {:.6}",
            rms
        );

        // --- Not clipped/saturated: should have dynamic range ---
        let max_abs = samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs < 1.0,
            "Audio is fully saturated (max |sample| = {:.4})",
            max_abs
        );
        assert!(
            max_abs > 0.05,
            "Audio has very low amplitude (max |sample| = {:.4})",
            max_abs
        );

        // --- Speech-like spectral check ---
        // Speech should have significant energy variation (not white noise or DC).
        // Check that the zero-crossing rate falls in a speech-like range. A pure tone
        // at f Hz crosses zero about 2f times per second, so the bounds asserted below
        // (50 to 10000 crossings/s) correspond to roughly 25 Hz to 5 kHz.
        let zero_crossings: usize = samples
            .windows(2)
            .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
            .count();
        let crossings_per_sec = zero_crossings as f64 / duration_secs;
        // White noise at 24kHz would have ~12000 crossings/sec.
        // Speech is typically 100-4000 crossings/sec.
        assert!(
            crossings_per_sec < 10000.0,
            "Too many zero crossings ({:.0}/s) — likely noise, not speech",
            crossings_per_sec
        );
        assert!(
            crossings_per_sec > 50.0,
            "Too few zero crossings ({:.0}/s) — likely DC or silence, not speech",
            crossings_per_sec
        );

        eprintln!(
            "Integration test passed: duration={:.2}s, rms={:.4}, max={:.4}, zero_crossings={:.0}/s",
            duration_secs, rms, max_abs, crossings_per_sec
        );
    }
}
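The audio checks in the tests above can be sketched outside the Rust crate as well. The following Go snippet is a minimal illustration, not the backend's implementation: `toPCM16`, `rms`, and `zeroCrossingsPerSec` are hypothetical helper names, and the clamp-then-scale-by-32767 conversion with truncation is an assumption inferred from the asserted sample values (0.5 maps to 16383).

```go
package main

import (
	"fmt"
	"math"
)

// toPCM16 clamps a float sample to [-1, 1] and scales it by 32767.
// The float-to-int conversion truncates toward zero (assumed behavior).
func toPCM16(s float32) int16 {
	if s > 1 {
		s = 1
	} else if s < -1 {
		s = -1
	}
	return int16(s * 32767)
}

// rms returns the root-mean-square amplitude of the samples.
func rms(samples []float32) float64 {
	if len(samples) == 0 {
		return 0
	}
	var sum float64
	for _, s := range samples {
		sum += float64(s) * float64(s)
	}
	return math.Sqrt(sum / float64(len(samples)))
}

// zeroCrossingsPerSec counts sign changes between adjacent samples and
// normalizes the count by the signal duration.
func zeroCrossingsPerSec(samples []float32, sampleRate int) float64 {
	if len(samples) < 2 || sampleRate <= 0 {
		return 0
	}
	crossings := 0
	for i := 1; i < len(samples); i++ {
		if (samples[i-1] >= 0) != (samples[i] >= 0) {
			crossings++
		}
	}
	duration := float64(len(samples)) / float64(sampleRate)
	return float64(crossings) / duration
}

func main() {
	// One second of a 220 Hz sine at 24 kHz, half amplitude.
	const rate = 24000
	samples := make([]float32, rate)
	for i := range samples {
		samples[i] = float32(0.5 * math.Sin(2*math.Pi*220*float64(i)/rate))
	}
	fmt.Println("pcm16 of 0.5:", toPCM16(0.5))
	fmt.Printf("rms=%.3f crossings/s=%.0f\n", rms(samples), zeroCrossingsPerSec(samples, rate))
}
```

A tone at f Hz yields about 2f crossings per second, which is why a rate near the sampling Nyquist limit suggests noise rather than speech.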
@@ -37,6 +37,9 @@ type Application struct {

	// Distributed mode services (nil when not in distributed mode)
	distributed *DistributedServices

	// Upgrade checker (background service for detecting backend upgrades)
	upgradeChecker *UpgradeChecker
}

func newApplication(appConfig *config.ApplicationConfig) *Application {
@@ -79,6 +82,19 @@ func (a *Application) AgentJobService() *agentpool.AgentJobService {
	return a.agentJobService
}

func (a *Application) UpgradeChecker() *UpgradeChecker {
	return a.upgradeChecker
}

// distributedDB returns the PostgreSQL database for distributed coordination,
// or nil in standalone mode.
func (a *Application) distributedDB() *gorm.DB {
	if a.distributed != nil {
		return a.authDB
	}
	return nil
}

func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
	return a.agentPoolService.Load()
}
@@ -335,6 +335,9 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
			if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
				appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
			}
			if settings.AutoUpgradeBackends != nil {
				appConfig.AutoUpgradeBackends = *settings.AutoUpgradeBackends
			}
			if settings.ApiKeys != nil {
				// API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys
				// If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys
@@ -231,6 +231,15 @@ func New(opts ...config.AppOption) (*Application, error) {
		xlog.Error("error registering external backends", "error", err)
	}

	// Start background upgrade checker for backends.
	// In distributed mode, uses PostgreSQL advisory lock so only one frontend
	// instance runs periodic checks (avoids duplicate upgrades across replicas).
	if len(options.BackendGalleries) > 0 {
		uc := NewUpgradeChecker(options, application.ModelLoader(), application.distributedDB())
		application.upgradeChecker = uc
		go uc.Run(options.Context)
	}

	if options.ConfigFile != "" {
		if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
			xlog.Error("error loading config file", "error", err)
198
core/application/upgrade_checker.go
Normal file
@@ -0,0 +1,198 @@
package application

import (
	"context"
	"sync"
	"time"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/services/advisorylock"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
	"gorm.io/gorm"
)

// UpgradeChecker periodically checks for backend upgrades and optionally
// auto-upgrades them. It caches the last check results for API queries.
//
// In standalone mode it runs a simple ticker loop.
// In distributed mode it uses a PostgreSQL advisory lock so that only one
// frontend instance performs periodic checks and auto-upgrades at a time.
type UpgradeChecker struct {
	appConfig *config.ApplicationConfig
	modelLoader *model.ModelLoader
	galleries []config.Gallery
	systemState *system.SystemState
	db *gorm.DB // non-nil in distributed mode

	checkInterval time.Duration
	stop chan struct{}
	done chan struct{}
	triggerCh chan struct{}

	mu sync.RWMutex
	lastUpgrades map[string]gallery.UpgradeInfo
	lastCheckTime time.Time
}

// NewUpgradeChecker creates a new UpgradeChecker service.
// Pass db=nil for standalone mode, or a *gorm.DB for distributed mode
// (uses advisory locks so only one instance runs periodic checks).
func NewUpgradeChecker(appConfig *config.ApplicationConfig, ml *model.ModelLoader, db *gorm.DB) *UpgradeChecker {
	return &UpgradeChecker{
		appConfig: appConfig,
		modelLoader: ml,
		galleries: appConfig.BackendGalleries,
		systemState: appConfig.SystemState,
		db: db,
		checkInterval: 6 * time.Hour,
		stop: make(chan struct{}),
		done: make(chan struct{}),
		triggerCh: make(chan struct{}, 1),
		lastUpgrades: make(map[string]gallery.UpgradeInfo),
	}
}

// Run starts the upgrade checker loop. It waits 30 seconds after startup,
// performs an initial check, then re-checks every 6 hours.
//
// In distributed mode, periodic checks are guarded by a PostgreSQL advisory
// lock so only one frontend instance runs them. On-demand triggers (TriggerCheck)
// and the initial check always run locally for fast API response cache warming.
func (uc *UpgradeChecker) Run(ctx context.Context) {
	defer close(uc.done)

	// Initial delay: don't slow down startup
	select {
	case <-ctx.Done():
		return
	case <-uc.stop:
		return
	case <-time.After(30 * time.Second):
	}

	// First check always runs locally (to warm the cache on this instance)
	uc.runCheck(ctx)

	if uc.db != nil {
		// Distributed mode: use advisory lock for periodic checks.
		// RunLeaderLoop ticks every checkInterval; only the lock holder executes.
		go advisorylock.RunLeaderLoop(ctx, uc.db, advisorylock.KeyBackendUpgradeCheck, uc.checkInterval, func() {
			uc.runCheck(ctx)
		})

		// Still listen for on-demand triggers (from API / settings change)
		// and stop signal — these run on every instance.
		for {
			select {
			case <-ctx.Done():
				return
			case <-uc.stop:
				return
			case <-uc.triggerCh:
				uc.runCheck(ctx)
			}
		}
	} else {
		// Standalone mode: simple ticker loop
		ticker := time.NewTicker(uc.checkInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-uc.stop:
				return
			case <-ticker.C:
				uc.runCheck(ctx)
			case <-uc.triggerCh:
				uc.runCheck(ctx)
			}
		}
	}
}

// Shutdown stops the upgrade checker loop.
func (uc *UpgradeChecker) Shutdown() {
	close(uc.stop)
	<-uc.done
}

// TriggerCheck forces an immediate upgrade check on this instance.
func (uc *UpgradeChecker) TriggerCheck() {
	select {
	case uc.triggerCh <- struct{}{}:
	default:
		// Already triggered, skip
	}
}

// GetAvailableUpgrades returns the cached upgrade check results.
func (uc *UpgradeChecker) GetAvailableUpgrades() map[string]gallery.UpgradeInfo {
	uc.mu.RLock()
	defer uc.mu.RUnlock()

	// Return a copy to avoid races
	result := make(map[string]gallery.UpgradeInfo, len(uc.lastUpgrades))
	for k, v := range uc.lastUpgrades {
		result[k] = v
	}
	return result
}

func (uc *UpgradeChecker) runCheck(ctx context.Context) {
	upgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState)

	uc.mu.Lock()
	uc.lastCheckTime = time.Now()
	if err != nil {
		xlog.Debug("Backend upgrade check failed", "error", err)
		uc.mu.Unlock()
		return
	}
	uc.lastUpgrades = upgrades
	uc.mu.Unlock()

	if len(upgrades) == 0 {
		xlog.Debug("All backends up to date")
		return
	}

	// Log available upgrades
	for name, info := range upgrades {
		if info.AvailableVersion != "" {
			xlog.Info("Backend upgrade available",
				"backend", name,
				"installed", info.InstalledVersion,
				"available", info.AvailableVersion)
		} else {
			xlog.Info("Backend upgrade available (new build)",
				"backend", name)
		}
	}

	// Auto-upgrade if enabled
	if uc.appConfig.AutoUpgradeBackends {
		for name, info := range upgrades {
			xlog.Info("Auto-upgrading backend", "backend", name,
				"from", info.InstalledVersion, "to", info.AvailableVersion)
			if err := gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
				uc.galleries, name, nil); err != nil {
				xlog.Error("Failed to auto-upgrade backend",
					"backend", name, "error", err)
			} else {
				xlog.Info("Backend upgraded successfully", "backend", name,
					"version", info.AvailableVersion)
			}
		}
		// Re-check to update cache after upgrades
		if freshUpgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState); err == nil {
			uc.mu.Lock()
			uc.lastUpgrades = freshUpgrades
			uc.mu.Unlock()
		}
	}
}
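The trigger handling in `UpgradeChecker` relies on a buffered channel of size one plus a non-blocking send, so repeated trigger requests coalesce into a single pending check. The sketch below isolates that pattern under simplified assumptions: `checker`, `newChecker`, `trigger`, and `run` are hypothetical names, and the check body is a counter rather than a real gallery lookup.

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// checker is a stripped-down stand-in for a background service that runs
// a check on a timer and on demand.
type checker struct {
	triggerCh chan struct{}
	runs      atomic.Int64
}

func newChecker() *checker {
	// Capacity 1: at most one trigger can be pending at any time.
	return &checker{triggerCh: make(chan struct{}, 1)}
}

// trigger requests a check without blocking. It reports false when a
// trigger is already pending, in which case the request is coalesced.
func (c *checker) trigger() bool {
	select {
	case c.triggerCh <- struct{}{}:
		return true
	default:
		return false
	}
}

// run executes check on every tick and on every trigger until ctx ends.
func (c *checker) run(ctx context.Context, interval time.Duration, check func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			check()
		case <-c.triggerCh:
			check()
		}
	}
}

func main() {
	c := newChecker()
	fmt.Println("first trigger accepted:", c.trigger())
	fmt.Println("second trigger accepted:", c.trigger())

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		defer close(done)
		c.run(ctx, time.Hour, func() { c.runs.Add(1) })
	}()

	// Wait until the pending trigger has been consumed, then stop.
	for c.runs.Load() == 0 {
		time.Sleep(time.Millisecond)
	}
	cancel()
	<-done
	fmt.Println("checks run:", c.runs.Load())
}
```

The non-blocking send means callers such as API handlers never stall, at the cost of dropping triggers that arrive while one is already queued.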
@@ -5,6 +5,27 @@ import (
	"github.com/mudler/xlog"
)

// SyncPinnedModelsToWatchdog reads pinned status from all model configs and updates the watchdog
func (a *Application) SyncPinnedModelsToWatchdog() {
	cl := a.ModelConfigLoader()
	if cl == nil {
		return
	}
	wd := a.modelLoader.GetWatchDog()
	if wd == nil {
		return
	}
	configs := cl.GetAllModelsConfigs()
	var pinned []string
	for _, cfg := range configs {
		if cfg.IsPinned() {
			pinned = append(pinned, cfg.Name)
		}
	}
	wd.SetPinnedModels(pinned)
	xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
}

func (a *Application) StopWatchdog() error {
	if a.watchdogStop != nil {
		close(a.watchdogStop)
@@ -44,6 +65,9 @@ func (a *Application) startWatchdog() error {
		// Set the watchdog on the model loader
		a.modelLoader.SetWatchDog(wd)

		// Sync pinned models from config to the watchdog
		a.SyncPinnedModelsToWatchdog()

		// Start watchdog goroutine if any periodic checks are enabled
		// LRU eviction doesn't need the Run() loop - it's triggered on model load
		// But memory reclaimer needs the Run() loop for periodic checking
@@ -124,5 +148,8 @@ func (a *Application) RestartWatchdog() error {
		newWD.RestoreState(oldState)
	}

	// Re-sync pinned models after restart
	a.SyncPinnedModelsToWatchdog()

	return nil
}
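The pinned-model sync above boils down to filtering configs by an optional boolean. A minimal sketch of that filtering follows, assuming a simplified `modelConfig` type with a `*bool` field; the type and the `pinnedNames` helper are illustrative stand-ins, not the project's own definitions.

```go
package main

import "fmt"

// modelConfig is a hypothetical, reduced model configuration.
type modelConfig struct {
	Name   string
	Pinned *bool // nil means the option was never set
}

// isPinned treats an unset option the same as false.
func (c modelConfig) isPinned() bool {
	return c.Pinned != nil && *c.Pinned
}

// pinnedNames returns the names of all configs that are explicitly pinned.
func pinnedNames(configs []modelConfig) []string {
	var pinned []string
	for _, cfg := range configs {
		if cfg.isPinned() {
			pinned = append(pinned, cfg.Name)
		}
	}
	return pinned
}

func main() {
	yes, no := true, false
	configs := []modelConfig{
		{Name: "chat-model", Pinned: &yes},
		{Name: "embedding-model"},
		{Name: "tts-model", Pinned: &no},
	}
	fmt.Println(pinnedNames(configs))
}
```

Passing the full list on every sync, rather than incremental updates, keeps the receiving side stateless with respect to earlier calls.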
@@ -13,6 +13,10 @@ import (

func Detection(
	sourceFile string,
	prompt string,
	points []float32,
	boxes []float32,
	threshold float32,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	modelConfig config.ModelConfig,
@@ -35,7 +39,11 @@ func Detection(
	}

	res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
		Src: sourceFile,
		Src: sourceFile,
		Prompt: prompt,
		Points: points,
		Boxes: boxes,
		Threshold: threshold,
	})

	if appConfig.EnableTracing {
@@ -40,10 +40,17 @@ type BackendsUninstall struct {
	BackendsCMDFlags `embed:""`
}

type BackendsUpgrade struct {
	BackendArgs []string `arg:"" optional:"" name:"backends" help:"Backend names to upgrade (empty = upgrade all)"`

	BackendsCMDFlags `embed:""`
}

type BackendsCMD struct {
	List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"`
	Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"`
	Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"`
	Upgrade BackendsUpgrade `cmd:"" help:"Upgrade backends to latest versions"`
}

func (bl *BackendsList) Run(ctx *cliContext.Context) error {
@@ -64,11 +71,27 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error {
	if err != nil {
		return err
	}

	// Check for upgrades
	upgrades, _ := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)

	for _, backend := range backends {
		versionStr := ""
		if backend.Version != "" {
			versionStr = " v" + backend.Version
		}
		if backend.Installed {
			fmt.Printf(" * %s@%s (installed)\n", backend.Gallery.Name, backend.Name)
			if info, ok := upgrades[backend.Name]; ok {
				upgradeStr := info.AvailableVersion
				if upgradeStr == "" {
					upgradeStr = "new build"
				}
				fmt.Printf(" * %s@%s%s (installed, upgrade available: %s)\n", backend.Gallery.Name, backend.Name, versionStr, upgradeStr)
			} else {
				fmt.Printf(" * %s@%s%s (installed)\n", backend.Gallery.Name, backend.Name, versionStr)
			}
		} else {
			fmt.Printf(" - %s@%s\n", backend.Gallery.Name, backend.Name)
			fmt.Printf(" - %s@%s%s\n", backend.Gallery.Name, backend.Name, versionStr)
		}
	}
	return nil
@@ -111,6 +134,79 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
	return nil
}

func (bu *BackendsUpgrade) Run(ctx *cliContext.Context) error {
	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(bu.BackendGalleries), &galleries); err != nil {
		xlog.Error("unable to load galleries", "error", err)
	}

	systemState, err := system.GetSystemState(
		system.WithBackendSystemPath(bu.BackendsSystemPath),
		system.WithBackendPath(bu.BackendsPath),
	)
	if err != nil {
		return err
	}

	upgrades, err := gallery.CheckBackendUpgrades(context.Background(), galleries, systemState)
	if err != nil {
		return fmt.Errorf("failed to check for upgrades: %w", err)
	}

	if len(upgrades) == 0 {
		fmt.Println("All backends are up to date.")
		return nil
	}

	// Filter to specified backends if args given
	toUpgrade := upgrades
	if len(bu.BackendArgs) > 0 {
		toUpgrade = make(map[string]gallery.UpgradeInfo)
		for _, name := range bu.BackendArgs {
			if info, ok := upgrades[name]; ok {
				toUpgrade[name] = info
			} else {
				fmt.Printf("Backend %s: no upgrade available\n", name)
			}
		}
	}

	if len(toUpgrade) == 0 {
		fmt.Println("No upgrades to apply.")
		return nil
	}

	modelLoader := model.NewModelLoader(systemState)
	for name, info := range toUpgrade {
		versionStr := ""
		if info.AvailableVersion != "" {
			versionStr = " to v" + info.AvailableVersion
		}
		fmt.Printf("Upgrading %s%s...\n", name, versionStr)

		progressBar := progressbar.NewOptions(
			1000,
			progressbar.OptionSetDescription(fmt.Sprintf("downloading %s", name)),
			progressbar.OptionShowBytes(false),
			progressbar.OptionClearOnFinish(),
		)
		progressCallback := func(fileName string, current string, total string, percentage float64) {
			v := int(percentage * 10)
			if err := progressBar.Set(v); err != nil {
				xlog.Error("error updating progress bar", "error", err)
			}
		}

		if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback); err != nil {
			fmt.Printf("Failed to upgrade %s: %v\n", name, err)
		} else {
			fmt.Printf("Backend %s upgraded successfully\n", name)
		}
	}

	return nil
}

func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
	for _, backendName := range bu.BackendArgs {
		xlog.Info("uninstalling backend", "backend", backendName)
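The upgrade command above narrows the detected upgrades to the backends named on the command line and treats an empty available version as a "new build". The sketch below reproduces only that selection and labeling logic with hypothetical names (`upgradeInfo`, `selectUpgrades`, `describe`); it does not call the gallery package, and the backend names are made up for the example.

```go
package main

import (
	"fmt"
	"sort"
)

// upgradeInfo is a reduced stand-in for the gallery's upgrade record.
type upgradeInfo struct {
	InstalledVersion string
	AvailableVersion string
}

// selectUpgrades returns the upgrades to apply. With no requested names,
// every available upgrade is selected. Requested names without an
// available upgrade are returned separately so the caller can report them.
func selectUpgrades(available map[string]upgradeInfo, requested []string) (map[string]upgradeInfo, []string) {
	if len(requested) == 0 {
		return available, nil
	}
	selected := make(map[string]upgradeInfo)
	var skipped []string
	for _, name := range requested {
		if info, ok := available[name]; ok {
			selected[name] = info
		} else {
			skipped = append(skipped, name)
		}
	}
	return selected, skipped
}

// describe labels an upgrade: a version string when one is known,
// otherwise "new build".
func describe(info upgradeInfo) string {
	if info.AvailableVersion == "" {
		return "new build"
	}
	return info.AvailableVersion
}

func main() {
	available := map[string]upgradeInfo{
		"backend-a": {InstalledVersion: "1.0", AvailableVersion: "1.1"},
		"backend-b": {},
	}
	selected, skipped := selectUpgrades(available, []string{"backend-b", "backend-c"})

	// Sort names so the output order is stable; map iteration order is not.
	names := make([]string, 0, len(selected))
	for name := range selected {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		fmt.Printf("%s -> %s\n", name, describe(selected[name]))
	}
	fmt.Println("no upgrade available for:", skipped)
}
```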
@@ -47,6 +47,7 @@ type RunCMD struct {
	BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
	BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
	BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
	AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
	PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
	Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
	PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
@@ -62,6 +63,7 @@ type RunCMD struct {
	UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
	APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
	DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
	OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
	DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
	DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
	OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
@@ -295,6 +297,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
		opts = append(opts, config.DisableWebUI)
	}

	if r.OllamaAPIRootEndpoint {
		opts = append(opts, config.EnableOllamaAPIRootEndpoint)
	}

	if r.DisableGalleryEndpoint {
		opts = append(opts, config.DisableGalleryEndpoint)
	}
@@ -485,6 +491,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
		opts = append(opts, config.EnableBackendGalleriesAutoload)
	}

	if r.AutoUpgradeBackends {
		opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
	}

	if r.PreloadBackendOnly {
		_, err := application.New(opts...)
		return err
@@ -40,6 +40,7 @@ type ApplicationConfig struct {
	Federated bool

	DisableWebUI bool
	OllamaAPIRootEndpoint bool
	EnforcePredownloadScans bool
	OpaqueErrors bool
	UseSubtleKeyComparison bool
@@ -56,6 +57,7 @@ type ApplicationConfig struct {
	ExternalGRPCBackends map[string]string

	AutoloadGalleries, AutoloadBackendGalleries bool
	AutoUpgradeBackends bool

	SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
	MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
@@ -263,6 +265,10 @@ var DisableWebUI = func(o *ApplicationConfig) {
	o.DisableWebUI = true
}

var EnableOllamaAPIRootEndpoint = func(o *ApplicationConfig) {
	o.OllamaAPIRootEndpoint = true
}

var DisableRuntimeSettings = func(o *ApplicationConfig) {
	o.DisableRuntimeSettings = true
}
@@ -385,6 +391,10 @@ var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) {
	o.AutoloadBackendGalleries = true
}

func WithAutoUpgradeBackends(v bool) AppOption {
	return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
}

var EnableFederated = func(o *ApplicationConfig) {
	o.Federated = true
}
@@ -857,6 +867,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
	backendGalleries := o.BackendGalleries
	autoloadGalleries := o.AutoloadGalleries
	autoloadBackendGalleries := o.AutoloadBackendGalleries
	autoUpgradeBackends := o.AutoUpgradeBackends
	apiKeys := o.ApiKeys
	agentJobRetentionDays := o.AgentJobRetentionDays

@@ -930,6 +941,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
		BackendGalleries: &backendGalleries,
		AutoloadGalleries: &autoloadGalleries,
		AutoloadBackendGalleries: &autoloadBackendGalleries,
		AutoUpgradeBackends: &autoUpgradeBackends,
		ApiKeys: &apiKeys,
		AgentJobRetentionDays: &agentJobRetentionDays,
		OpenResponsesStoreTTL: &openResponsesStoreTTL,
@@ -1078,6 +1090,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
	if settings.AutoloadBackendGalleries != nil {
		o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
	}
	if settings.AutoUpgradeBackends != nil {
		o.AutoUpgradeBackends = *settings.AutoUpgradeBackends
	}
	if settings.AgentJobRetentionDays != nil {
		o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
	}
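The configuration hunks above use the functional-options style: each option is a function that mutates the config, either as a package-level variable or as a constructor that captures a value. The sketch below shows the pattern in isolation; `appConfig`, `appOption`, `newAppConfig`, and the lowercase option names are illustrative and do not mirror the project's full types.

```go
package main

import "fmt"

// appConfig is a reduced, hypothetical application configuration.
type appConfig struct {
	AutoUpgradeBackends bool
	DisableWebUI        bool
}

// appOption mutates the configuration while it is being built.
type appOption func(*appConfig)

// withAutoUpgradeBackends captures a value, so callers can pass a flag
// straight through.
func withAutoUpgradeBackends(v bool) appOption {
	return func(o *appConfig) { o.AutoUpgradeBackends = v }
}

// disableWebUI is a fixed toggle, so a plain variable is enough.
var disableWebUI appOption = func(o *appConfig) {
	o.DisableWebUI = true
}

// newAppConfig applies options in order; later options win on conflict.
func newAppConfig(opts ...appOption) *appConfig {
	cfg := &appConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

func main() {
	cfg := newAppConfig(withAutoUpgradeBackends(true), disableWebUI)
	fmt.Printf("%+v\n", *cfg)
}
```

A value-capturing constructor is a natural fit for a setting that may later need to be switched off explicitly, which a fixed toggle cannot express.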
@@ -119,6 +119,13 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
			Expect(*rs.AgentJobRetentionDays).To(Equal(30))
		})

		It("should include auto_upgrade_backends", func() {
			appConfig := &ApplicationConfig{AutoUpgradeBackends: true}
			rs := appConfig.ToRuntimeSettings()
			Expect(rs.AutoUpgradeBackends).ToNot(BeNil())
			Expect(*rs.AutoUpgradeBackends).To(BeTrue())
		})

		It("should use default timeouts when not set", func() {
			appConfig := &ApplicationConfig{}

@@ -426,6 +433,14 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
			Expect(appConfig.AutoloadBackendGalleries).To(BeTrue())
		})

		It("should apply auto_upgrade_backends setting", func() {
			appConfig := &ApplicationConfig{}
			v := true
			rs := &RuntimeSettings{AutoUpgradeBackends: &v}
			appConfig.ApplyRuntimeSettings(rs)
			Expect(appConfig.AutoUpgradeBackends).To(BeTrue())
		})

		It("should apply agent settings", func() {
			appConfig := &ApplicationConfig{}

@@ -465,6 +480,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
				Federated: true,
				AutoloadGalleries: true,
				AutoloadBackendGalleries: false,
				AutoUpgradeBackends: true,
				AgentJobRetentionDays: 60,
			}

@@ -496,6 +512,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
			Expect(target.Federated).To(Equal(original.Federated))
			Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries))
			Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries))
			Expect(target.AutoUpgradeBackends).To(Equal(original.AutoUpgradeBackends))
			Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays))
		})
@@ -49,6 +49,22 @@ var DiffusersPipelineOptions = []FieldOption{
	{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
}

var UsecaseOptions = []FieldOption{
	{Value: "chat", Label: "Chat"},
	{Value: "completion", Label: "Completion"},
	{Value: "edit", Label: "Edit"},
	{Value: "embeddings", Label: "Embeddings"},
	{Value: "rerank", Label: "Rerank"},
	{Value: "image", Label: "Image"},
	{Value: "transcript", Label: "Transcript"},
	{Value: "tts", Label: "TTS"},
	{Value: "sound_generation", Label: "Sound Generation"},
	{Value: "tokenize", Label: "Tokenize"},
	{Value: "vad", Label: "VAD"},
	{Value: "video", Label: "Video"},
	{Value: "detection", Label: "Detection"},
}

var DiffusersSchedulerOptions = []FieldOption{
	{Value: "ddim", Label: "DDIM"},
	{Value: "ddpm", Label: "DDPM"},
@@ -47,8 +47,9 @@ func DefaultRegistry() map[string]FieldMetaOverride {
		"known_usecases": {
			Section: "general",
			Label: "Known Use Cases",
			Description: "Capabilities this model supports (e.g. FLAG_CHAT, FLAG_COMPLETION)",
			Description: "Capabilities this model supports",
			Component: "string-list",
			Options: UsecaseOptions,
			Order: 6,
		},

@@ -287,6 +288,15 @@ func DefaultRegistry() map[string]FieldMetaOverride {
			Order: 72,
		},

		// --- TTS ---
		"tts.voice": {
			Section: "tts",
			Label: "Voice",
			Description: "Default voice for TTS output",
			Component: "input",
			Order: 90,
		},

		// --- Diffusers ---
		"diffusers.pipeline_type": {
			Section: "diffusers",
@@ -77,6 +77,8 @@ type ModelConfig struct {

	Description string `yaml:"description,omitempty" json:"description,omitempty"`
	Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
	Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
	Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`

	Options []string `yaml:"options,omitempty" json:"options,omitempty"`
	Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
@@ -548,6 +550,16 @@ func (c *ModelConfig) GetModelTemplate() string {
	return c.modelTemplate
}

// IsDisabled returns true if the model is disabled
func (c *ModelConfig) IsDisabled() bool {
	return c.Disabled != nil && *c.Disabled
}

// IsPinned returns true if the model is pinned (excluded from idle unloading and eviction)
func (c *ModelConfig) IsPinned() bool {
	return c.Pinned != nil && *c.Pinned
}

type ModelConfigUsecase int

const (
@@ -705,7 +717,8 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
	}

	if (u & FLAG_DETECTION) == FLAG_DETECTION {
		if c.Backend != "rfdetr" {
		detectionBackends := []string{"rfdetr", "sam3-cpp"}
		if !slices.Contains(detectionBackends, c.Backend) {
			return false
		}
	}
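The detection check above combines a bit-flag test with a membership test over a list of backend names. A compact sketch of that combination follows. The flag values and the names `usecase`, `flagChat`, `flagDetection`, and `guessUsecases` are assumptions for illustration; only the two backend names come from the hunk. `slices.Contains` needs Go 1.21 or newer.

```go
package main

import (
	"fmt"
	"slices"
)

// usecase is a bit set; each capability occupies one bit.
type usecase int

const (
	flagChat usecase = 1 << iota
	flagDetection
)

// detectionBackends lists the backends assumed to support detection.
var detectionBackends = []string{"rfdetr", "sam3-cpp"}

// guessUsecases reports whether a backend can plausibly serve every
// requested use case. Only the detection rule is modeled here.
func guessUsecases(backend string, requested usecase) bool {
	if requested&flagDetection == flagDetection {
		if !slices.Contains(detectionBackends, backend) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(guessUsecases("sam3-cpp", flagDetection))
	fmt.Println(guessUsecases("llama-cpp", flagDetection))
	fmt.Println(guessUsecases("llama-cpp", flagChat))
}
```

Keeping the supported backends in a slice means adding another detection backend is a one-line change rather than a longer boolean expression.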
@@ -20,6 +20,7 @@ type RuntimeSettings struct {
	// Backend management
	SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
	MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
	AutoUpgradeBackends *bool `json:"auto_upgrade_backends,omitempty"` // Automatically upgrade backends when new versions are detected
	// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
	MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
	MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
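Runtime settings use pointer fields so that a key missing from the JSON file can be told apart from an explicit `false`. The sketch below illustrates that distinction with the `auto_upgrade_backends` key from the struct above; the reduced `runtimeSettings` and `appConfig` types and the `apply` helper are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// runtimeSettings is a reduced stand-in with a single optional setting.
type runtimeSettings struct {
	AutoUpgradeBackends *bool `json:"auto_upgrade_backends,omitempty"`
}

// appConfig holds the effective value.
type appConfig struct {
	AutoUpgradeBackends bool
}

// apply decodes raw JSON and overrides only the settings that are present.
// A missing key leaves the pointer nil, so the current value is kept.
func apply(cfg *appConfig, raw []byte) error {
	var s runtimeSettings
	if err := json.Unmarshal(raw, &s); err != nil {
		return err
	}
	if s.AutoUpgradeBackends != nil {
		cfg.AutoUpgradeBackends = *s.AutoUpgradeBackends
	}
	return nil
}

func main() {
	cfg := &appConfig{AutoUpgradeBackends: true}

	// Key absent: the startup value is kept.
	if err := apply(cfg, []byte(`{}`)); err != nil {
		fmt.Println("error:", err)
	}
	fmt.Println("after empty settings:", cfg.AutoUpgradeBackends)

	// Key present and false: the value is overridden.
	if err := apply(cfg, []byte(`{"auto_upgrade_backends": false}`)); err != nil {
		fmt.Println("error:", err)
	}
	fmt.Println("after explicit false:", cfg.AutoUpgradeBackends)
}
```

With a plain `bool` field both inputs would decode to `false`, and the empty file would silently disable a setting enabled at startup.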
@@ -20,12 +20,19 @@ type BackendMetadata struct {
	GalleryURL string `json:"gallery_url,omitempty"`
	// InstalledAt is the timestamp when the backend was installed
	InstalledAt string `json:"installed_at,omitempty"`
	// Version is the version of the backend at install time
	Version string `json:"version,omitempty"`
	// URI is the original URI used to install the backend
	URI string `json:"uri,omitempty"`
	// Digest is the OCI image digest at install time (for upgrade detection)
	Digest string `json:"digest,omitempty"`
}

type GalleryBackend struct {
	Metadata `json:",inline" yaml:",inline"`
	Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
	URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
	Version string `json:"version,omitempty" yaml:"version,omitempty"`
	Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
	CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
}
@@ -71,6 +78,10 @@ func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool
		return true
	}

	if systemState.CapabilityFilterDisabled() {
		return true
	}

	// Meta backends are compatible if the system capability matches one of the keys
	if m.IsMeta() {
		capability := systemState.Capability(m.CapabilitiesMap)
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/oci"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
cp "github.com/otiai10/copy"
|
||||
@@ -158,6 +159,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
|
||||
Name: name,
|
||||
GalleryURL: backend.Gallery.URL,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
Version: bestBackend.Version,
|
||||
}
|
||||
|
||||
if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
|
||||
@@ -279,6 +281,18 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
Name: name,
|
||||
GalleryURL: config.Gallery.URL,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
Version: config.Version,
|
||||
URI: string(uri),
|
||||
}
|
||||
|
||||
// Record the OCI digest for upgrade detection (non-fatal on failure)
|
||||
if uri.LooksLikeOCI() {
|
||||
digest, digestErr := oci.GetImageDigest(string(uri), "", nil, nil)
|
||||
if digestErr != nil {
|
||||
xlog.Warn("Failed to get OCI image digest for backend", "uri", string(uri), "error", digestErr)
|
||||
} else {
|
||||
metadata.Digest = digest
|
||||
}
|
||||
}
|
||||
|
||||
if config.Alias != "" {
|
||||
@@ -373,11 +387,13 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
 }

 type SystemBackend struct {
-	Name     string
-	RunFile  string
-	IsMeta   bool
-	IsSystem bool
-	Metadata *BackendMetadata
+	Name             string
+	RunFile          string
+	IsMeta           bool
+	IsSystem         bool
+	Metadata         *BackendMetadata
+	UpgradeAvailable bool   `json:"upgrade_available,omitempty"`
+	AvailableVersion string `json:"available_version,omitempty"`
 }

 type SystemBackends map[string]SystemBackend
118
core/gallery/backends_version_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Backend versioning", func() {
|
||||
var tempDir string
|
||||
var systemState *system.SystemState
|
||||
var modelLoader *model.ModelLoader
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "gallery-version-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err = system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
modelLoader = model.NewModelLoader(systemState)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
It("records version in metadata when installing a backend with a version", func() {
|
||||
// Create a fake backend source directory with a run.sh
|
||||
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(srcDir)
|
||||
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backend := &gallery.GalleryBackend{}
|
||||
backend.Name = "test-backend"
|
||||
backend.URI = srcDir
|
||||
backend.Version = "1.2.3"
|
||||
|
||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Read the metadata file and check version
|
||||
metadataPath := filepath.Join(tempDir, "test-backend", "metadata.json")
|
||||
data, err := os.ReadFile(metadataPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var metadata map[string]any
|
||||
err = json.Unmarshal(data, &metadata)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(metadata["version"]).To(Equal("1.2.3"))
|
||||
})
|
||||
|
||||
It("records URI in metadata", func() {
|
||||
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(srcDir)
|
||||
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backend := &gallery.GalleryBackend{}
|
||||
backend.Name = "test-backend-uri"
|
||||
backend.URI = srcDir
|
||||
backend.Version = "2.0.0"
|
||||
|
||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metadataPath := filepath.Join(tempDir, "test-backend-uri", "metadata.json")
|
||||
data, err := os.ReadFile(metadataPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var metadata map[string]any
|
||||
err = json.Unmarshal(data, &metadata)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(metadata["uri"]).To(Equal(srcDir))
|
||||
})
|
||||
|
||||
It("omits version key when version is empty", func() {
|
||||
srcDir, err := os.MkdirTemp("", "gallery-version-src-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(srcDir)
|
||||
err = os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho ok"), 0755)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backend := &gallery.GalleryBackend{}
|
||||
backend.Name = "test-backend-noversion"
|
||||
backend.URI = srcDir
|
||||
// Version intentionally left empty
|
||||
|
||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metadataPath := filepath.Join(tempDir, "test-backend-noversion", "metadata.json")
|
||||
data, err := os.ReadFile(metadataPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var metadata map[string]any
|
||||
err = json.Unmarshal(data, &metadata)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// omitempty should exclude the version key entirely
|
||||
_, hasVersion := metadata["version"]
|
||||
Expect(hasVersion).To(BeFalse())
|
||||
})
|
||||
})
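The last test above relies on `omitempty` dropping an empty `Version` from the serialized metadata. A minimal sketch of that behaviour using only `encoding/json`; `backendMeta` is a hypothetical, trimmed-down mirror of `BackendMetadata`, and the `name` tag in particular is an assumption.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// backendMeta is a hypothetical, trimmed-down mirror of BackendMetadata.
type backendMeta struct {
	Name    string `json:"name"`
	Version string `json:"version,omitempty"`
	URI     string `json:"uri,omitempty"`
}

// hasKey marshals m and reports whether key appears in the resulting JSON object.
func hasKey(m backendMeta, key string) bool {
	data, err := json.Marshal(m)
	if err != nil {
		panic(err)
	}
	var decoded map[string]any
	if err := json.Unmarshal(data, &decoded); err != nil {
		panic(err)
	}
	_, ok := decoded[key]
	return ok
}

func main() {
	fmt.Println("with version:   ", hasKey(backendMeta{Name: "a", Version: "1.2.3"}, "version"))
	fmt.Println("without version:", hasKey(backendMeta{Name: "a"}, "version"))
}
```

Decoding into `map[string]any`, as the tests do, is what distinguishes a missing key from a key holding an empty string.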
|
||||
237
core/gallery/upgrade.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/oci"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
cp "github.com/otiai10/copy"
|
||||
)
|
||||
|
||||
// UpgradeInfo holds details about an available backend upgrade.
|
||||
type UpgradeInfo struct {
|
||||
BackendName string `json:"backend_name"`
|
||||
InstalledVersion string `json:"installed_version"`
|
||||
AvailableVersion string `json:"available_version"`
|
||||
InstalledDigest string `json:"installed_digest,omitempty"`
|
||||
AvailableDigest string `json:"available_digest,omitempty"`
|
||||
}
|
||||
|
||||
// CheckBackendUpgrades compares installed backends against gallery entries
|
||||
// and returns a map of backend names to UpgradeInfo for those that have
|
||||
// newer versions or different OCI digests available.
|
||||
func CheckBackendUpgrades(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState) (map[string]UpgradeInfo, error) {
|
||||
galleryBackends, err := AvailableBackends(galleries, systemState)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list available backends: %w", err)
|
||||
}
|
||||
|
||||
installedBackends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list installed backends: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[string]UpgradeInfo)
|
||||
|
||||
for _, installed := range installedBackends {
|
||||
// Skip system backends — they are managed outside the gallery
|
||||
if installed.IsSystem {
|
||||
continue
|
||||
}
|
||||
if installed.Metadata == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find matching gallery entry by metadata name
|
||||
galleryEntry := FindGalleryElement(galleryBackends, installed.Metadata.Name)
|
||||
if galleryEntry == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
installedVersion := installed.Metadata.Version
|
||||
galleryVersion := galleryEntry.Version
|
||||
|
||||
// If both sides have versions, compare them
|
||||
if galleryVersion != "" && installedVersion != "" {
|
||||
if galleryVersion != installedVersion {
|
||||
result[installed.Metadata.Name] = UpgradeInfo{
|
||||
BackendName: installed.Metadata.Name,
|
||||
InstalledVersion: installedVersion,
|
||||
AvailableVersion: galleryVersion,
|
||||
}
|
||||
}
|
||||
// Versions match — no upgrade needed
|
||||
continue
|
||||
}
|
||||
|
||||
// Gallery has a version but installed doesn't — this happens for backends
|
||||
// installed before version tracking was added. Flag as upgradeable so
|
||||
// users can re-install to pick up version metadata.
|
||||
if galleryVersion != "" && installedVersion == "" {
|
||||
result[installed.Metadata.Name] = UpgradeInfo{
|
||||
BackendName: installed.Metadata.Name,
|
||||
InstalledVersion: "",
|
||||
AvailableVersion: galleryVersion,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Fall back to OCI digest comparison when versions are unavailable
|
||||
if downloader.URI(galleryEntry.URI).LooksLikeOCI() {
|
||||
remoteDigest, err := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to get remote OCI digest for upgrade check", "backend", installed.Metadata.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
// If we have a stored digest, compare; otherwise any remote digest
|
||||
// means we can't confirm we're up to date — flag as upgradeable
|
||||
if installed.Metadata.Digest == "" || remoteDigest != installed.Metadata.Digest {
|
||||
result[installed.Metadata.Name] = UpgradeInfo{
|
||||
BackendName: installed.Metadata.Name,
|
||||
InstalledDigest: installed.Metadata.Digest,
|
||||
AvailableDigest: remoteDigest,
|
||||
}
|
||||
}
|
||||
}
|
||||
// No version info and non-OCI URI — cannot determine, skip
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
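The order of checks in `CheckBackendUpgrades` can be sketched as a pure function, which makes the precedence easier to see. This is an illustration only: `decide` and the `decision` constants are hypothetical names, and the remote digest is passed in rather than fetched from a registry.

```go
package main

import "fmt"

type decision int

const (
	upToDate decision = iota
	upgradeByVersion
	upgradeByDigest
	unknown
)

var names = map[decision]string{
	upToDate:         "up to date",
	upgradeByVersion: "upgrade (version)",
	upgradeByDigest:  "upgrade (digest)",
	unknown:          "cannot determine",
}

// decide mirrors the precedence used above: version comparison first,
// then the "gallery has a version, install does not" case, then the OCI
// digest fallback. remoteDigest is only consulted when isOCI is true.
func decide(installedVersion, galleryVersion string, isOCI bool, installedDigest, remoteDigest string) decision {
	switch {
	case galleryVersion != "" && installedVersion != "":
		if galleryVersion != installedVersion {
			return upgradeByVersion
		}
		return upToDate
	case galleryVersion != "":
		// Installed before version tracking existed.
		return upgradeByVersion
	case isOCI:
		if installedDigest == "" || remoteDigest != installedDigest {
			return upgradeByDigest
		}
		return upToDate
	default:
		return unknown
	}
}

func main() {
	fmt.Println(names[decide("1.0.0", "2.0.0", false, "", "")])
	fmt.Println(names[decide("2.0.0", "2.0.0", false, "", "")])
	fmt.Println(names[decide("", "", true, "sha256:aaa", "sha256:bbb")])
	fmt.Println(names[decide("", "", false, "", "")])
}
```

As in the function above, versions are compared for inequality rather than ordering, so a gallery entry carrying an older version string is also reported as an available upgrade.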
|
||||
|
||||
// UpgradeBackend upgrades a single backend to the latest gallery version using
|
||||
// an atomic swap with backup-based rollback on failure.
|
||||
func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, galleries []config.Gallery, backendName string, downloadStatus func(string, string, string, float64)) error {
|
||||
// Look up the installed backend
|
||||
installedBackends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list installed backends: %w", err)
|
||||
}
|
||||
|
||||
installed, ok := installedBackends.Get(backendName)
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q: %w", backendName, ErrBackendNotFound)
|
||||
}
|
||||
|
||||
if installed.IsSystem {
|
||||
return fmt.Errorf("system backend %q cannot be upgraded via gallery", backendName)
|
||||
}
|
||||
|
||||
// If this is a meta backend, recursively upgrade the concrete backend it points to
|
||||
if installed.Metadata != nil && installed.Metadata.MetaBackendFor != "" {
|
||||
xlog.Info("Meta backend detected, upgrading concrete backend", "meta", backendName, "concrete", installed.Metadata.MetaBackendFor)
|
||||
return UpgradeBackend(ctx, systemState, modelLoader, galleries, installed.Metadata.MetaBackendFor, downloadStatus)
|
||||
}
|
||||
|
||||
// Find the gallery entry
|
||||
galleryBackends, err := AvailableBackends(galleries, systemState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list available backends: %w", err)
|
||||
}
|
||||
|
||||
galleryEntry := FindGalleryElement(galleryBackends, backendName)
|
||||
if galleryEntry == nil {
|
||||
return fmt.Errorf("no gallery entry found for backend %q", backendName)
|
||||
}
|
||||
|
||||
backendPath := filepath.Join(systemState.Backend.BackendsPath, backendName)
|
||||
tmpPath := backendPath + ".upgrade-tmp"
|
||||
backupPath := backendPath + ".backup"
|
||||
|
||||
// Clean up any stale tmp/backup dirs from prior attempts
|
||||
os.RemoveAll(tmpPath)
|
||||
os.RemoveAll(backupPath)
|
||||
|
||||
// Step 1: Download the new backend into the tmp directory
|
||||
if err := os.MkdirAll(tmpPath, 0750); err != nil {
|
||||
return fmt.Errorf("failed to create upgrade tmp dir: %w", err)
|
||||
}
|
||||
|
||||
uri := downloader.URI(galleryEntry.URI)
|
||||
if uri.LooksLikeDir() {
|
||||
if err := cp.Copy(string(uri), tmpPath); err != nil {
|
||||
os.RemoveAll(tmpPath)
|
||||
return fmt.Errorf("failed to copy backend from directory: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := uri.DownloadFileWithContext(ctx, tmpPath, "", 1, 1, downloadStatus); err != nil {
|
||||
os.RemoveAll(tmpPath)
|
||||
return fmt.Errorf("failed to download backend: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Validate — check that run.sh exists in the new content
|
||||
newRunFile := filepath.Join(tmpPath, runFile)
|
||||
if _, err := os.Stat(newRunFile); os.IsNotExist(err) {
|
||||
os.RemoveAll(tmpPath)
|
||||
return fmt.Errorf("upgrade validation failed: run.sh not found in new backend")
|
||||
}
|
||||
|
||||
// Step 3: Atomic swap — rename current to backup, then tmp to current
|
||||
if err := os.Rename(backendPath, backupPath); err != nil {
|
||||
os.RemoveAll(tmpPath)
|
||||
return fmt.Errorf("failed to move current backend to backup: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, backendPath); err != nil {
|
||||
// Restore backup on failure
|
||||
xlog.Error("Failed to move new backend into place, restoring backup", "error", err)
|
||||
if restoreErr := os.Rename(backupPath, backendPath); restoreErr != nil {
|
||||
xlog.Error("Failed to restore backup", "error", restoreErr)
|
||||
}
|
||||
os.RemoveAll(tmpPath)
|
||||
return fmt.Errorf("failed to move new backend into place: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Write updated metadata, preserving alias from old metadata
|
||||
var oldAlias string
|
||||
if installed.Metadata != nil {
|
||||
oldAlias = installed.Metadata.Alias
|
||||
}
|
||||
|
||||
newMetadata := &BackendMetadata{
|
||||
Name: backendName,
|
||||
Version: galleryEntry.Version,
|
||||
URI: galleryEntry.URI,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
Alias: oldAlias,
|
||||
}
|
||||
|
||||
if galleryEntry.Gallery.URL != "" {
|
||||
newMetadata.GalleryURL = galleryEntry.Gallery.URL
|
||||
}
|
||||
|
||||
// Record OCI digest if applicable (non-fatal on failure)
|
||||
if uri.LooksLikeOCI() {
|
||||
digest, digestErr := oci.GetImageDigest(galleryEntry.URI, "", nil, nil)
|
||||
if digestErr != nil {
|
||||
xlog.Warn("Failed to get OCI image digest after upgrade", "uri", galleryEntry.URI, "error", digestErr)
|
||||
} else {
|
||||
newMetadata.Digest = digest
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeBackendMetadata(backendPath, newMetadata); err != nil {
|
||||
// Metadata write failure is not worth rolling back the entire upgrade
|
||||
xlog.Error("Failed to write metadata after upgrade", "error", err)
|
||||
}
|
||||
|
||||
// Step 5: Re-register backends so the model loader picks up any changes
|
||||
if err := RegisterBackends(systemState, modelLoader); err != nil {
|
||||
xlog.Warn("Failed to re-register backends after upgrade", "error", err)
|
||||
}
|
||||
|
||||
// Step 6: Remove backup
|
||||
os.RemoveAll(backupPath)
|
||||
|
||||
xlog.Info("Backend upgraded successfully", "backend", backendName, "version", galleryEntry.Version)
|
||||
return nil
|
||||
}
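Step 3 of `UpgradeBackend` (move the current directory aside, move the staged one into place, restore on failure) can be sketched on its own with the standard library. `swapDir`, `writeBackend`, and `runScenario` are hypothetical helpers written for this illustration; the `.backup` and `.upgrade-tmp` suffixes follow the function above.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// swapDir moves current aside, moves staged into its place, and restores
// the backup if the second rename fails.
func swapDir(current, staged string) error {
	backup := current + ".backup"
	if err := os.RemoveAll(backup); err != nil {
		return fmt.Errorf("remove stale backup: %w", err)
	}
	if err := os.Rename(current, backup); err != nil {
		return fmt.Errorf("move current to backup: %w", err)
	}
	if err := os.Rename(staged, current); err != nil {
		if restoreErr := os.Rename(backup, current); restoreErr != nil {
			return fmt.Errorf("swap failed (%v); restore failed: %w", err, restoreErr)
		}
		return fmt.Errorf("move staged into place: %w", err)
	}
	return os.RemoveAll(backup)
}

// writeBackend creates dir containing a run.sh with the given body.
func writeBackend(dir, body string) error {
	if err := os.MkdirAll(dir, 0o750); err != nil {
		return err
	}
	return os.WriteFile(filepath.Join(dir, "run.sh"), []byte(body), 0o755)
}

// runScenario installs "v1", optionally stages "v2", swaps, and returns the
// run.sh body found afterwards together with the swap error, if any.
func runScenario(stageNew bool) (string, error) {
	root, err := os.MkdirTemp("", "swap-demo-*")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(root)

	current := filepath.Join(root, "my-backend")
	staged := current + ".upgrade-tmp"
	if err := writeBackend(current, "v1"); err != nil {
		return "", err
	}
	if stageNew {
		if err := writeBackend(staged, "v2"); err != nil {
			return "", err
		}
	}
	swapErr := swapDir(current, staged)
	data, err := os.ReadFile(filepath.Join(current, "run.sh"))
	if err != nil {
		return "", err
	}
	return string(data), swapErr
}

func main() {
	for _, stageNew := range []bool{true, false} {
		body, err := runScenario(stageNew)
		fmt.Printf("staged=%v run.sh=%q swapErr=%v\n", stageNew, body, err)
	}
}
```

Each `os.Rename` is a single filesystem operation, but the pair is not: between the two renames the backend path briefly does not exist, so a process that resolves the path in that window would see it missing.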
|
||||
219
core/gallery/upgrade_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var _ = Describe("Upgrade Detection and Execution", func() {
|
||||
var (
|
||||
tempDir string
|
||||
backendsPath string
|
||||
galleryPath string
|
||||
systemState *system.SystemState
|
||||
galleries []config.Gallery
|
||||
)
|
||||
|
||||
// installBackendWithVersion creates a fake installed backend directory with
|
||||
// the given name, version, and optional run.sh content.
|
||||
installBackendWithVersion := func(name, version string, runContent ...string) {
|
||||
dir := filepath.Join(backendsPath, name)
|
||||
Expect(os.MkdirAll(dir, 0750)).To(Succeed())
|
||||
|
||||
content := "#!/bin/sh\necho ok"
|
||||
if len(runContent) > 0 {
|
||||
content = runContent[0]
|
||||
}
|
||||
Expect(os.WriteFile(filepath.Join(dir, "run.sh"), []byte(content), 0755)).To(Succeed())
|
||||
|
||||
metadata := BackendMetadata{
|
||||
Name: name,
|
||||
Version: version,
|
||||
InstalledAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
data, err := json.MarshalIndent(metadata, "", " ")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(dir, "metadata.json"), data, 0644)).To(Succeed())
|
||||
}
|
||||
|
||||
// writeGalleryYAML writes a gallery YAML file with the given backends.
|
||||
writeGalleryYAML := func(backends []GalleryBackend) {
|
||||
data, err := yaml.Marshal(backends)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(os.WriteFile(galleryPath, data, 0644)).To(Succeed())
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "upgrade-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backendsPath = tempDir
|
||||
|
||||
galleryPath = filepath.Join(tempDir, "gallery.yaml")
|
||||
|
||||
// Write a default empty gallery
|
||||
writeGalleryYAML([]GalleryBackend{})
|
||||
|
||||
galleries = []config.Gallery{
|
||||
{
|
||||
Name: "test-gallery",
|
||||
URL: "file://" + galleryPath,
|
||||
},
|
||||
}
|
||||
|
||||
systemState, err = system.GetSystemState(
|
||||
system.WithBackendPath(backendsPath),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
Describe("CheckBackendUpgrades", func() {
|
||||
It("should detect upgrade when gallery version differs from installed version", func() {
|
||||
// Install a backend at v1.0.0
|
||||
installBackendWithVersion("my-backend", "1.0.0")
|
||||
|
||||
// Gallery advertises v2.0.0
|
||||
writeGalleryYAML([]GalleryBackend{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "my-backend",
|
||||
},
|
||||
URI: filepath.Join(tempDir, "some-source"),
|
||||
Version: "2.0.0",
|
||||
},
|
||||
})
|
||||
|
||||
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(upgrades).To(HaveKey("my-backend"))
|
||||
Expect(upgrades["my-backend"].InstalledVersion).To(Equal("1.0.0"))
|
||||
Expect(upgrades["my-backend"].AvailableVersion).To(Equal("2.0.0"))
|
||||
})
|
||||
|
||||
It("should NOT flag upgrade when versions match", func() {
|
||||
installBackendWithVersion("my-backend", "2.0.0")
|
||||
|
||||
writeGalleryYAML([]GalleryBackend{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "my-backend",
|
||||
},
|
||||
URI: filepath.Join(tempDir, "some-source"),
|
||||
Version: "2.0.0",
|
||||
},
|
||||
})
|
||||
|
||||
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(upgrades).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should skip backends without version info and without OCI digest", func() {
|
||||
// Install without version
|
||||
installBackendWithVersion("my-backend", "")
|
||||
|
||||
// Gallery also without version
|
||||
writeGalleryYAML([]GalleryBackend{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "my-backend",
|
||||
},
|
||||
URI: filepath.Join(tempDir, "some-source"),
|
||||
},
|
||||
})
|
||||
|
||||
upgrades, err := CheckBackendUpgrades(context.Background(), galleries, systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(upgrades).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UpgradeBackend", func() {
|
||||
It("should replace backend directory and update metadata", func() {
|
||||
// Install v1
|
||||
installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")
|
||||
|
||||
// Create a source directory with v2 content
|
||||
srcDir := filepath.Join(tempDir, "v2-source")
|
||||
Expect(os.MkdirAll(srcDir, 0750)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte("#!/bin/sh\necho v2"), 0755)).To(Succeed())
|
||||
|
||||
// Gallery points to the v2 source dir
|
||||
writeGalleryYAML([]GalleryBackend{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "my-backend",
|
||||
},
|
||||
URI: srcDir,
|
||||
Version: "2.0.0",
|
||||
},
|
||||
})
|
||||
|
||||
ml := model.NewModelLoader(systemState)
|
||||
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Verify run.sh was updated
|
||||
content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(content)).To(Equal("#!/bin/sh\necho v2"))
|
||||
|
||||
// Verify metadata was updated
|
||||
metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
var meta BackendMetadata
|
||||
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
||||
Expect(meta.Version).To(Equal("2.0.0"))
|
||||
Expect(meta.Name).To(Equal("my-backend"))
|
||||
})
|
||||
|
||||
It("should restore backup on failure", func() {
|
||||
// Install v1
|
||||
installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")
|
||||
|
||||
// Gallery points to a nonexistent path (no run.sh will be found)
|
||||
nonExistentDir := filepath.Join(tempDir, "does-not-exist")
|
||||
writeGalleryYAML([]GalleryBackend{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "my-backend",
|
||||
},
|
||||
URI: nonExistentDir,
|
||||
Version: "2.0.0",
|
||||
},
|
||||
})
|
||||
|
||||
ml := model.NewModelLoader(systemState)
|
||||
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
// Verify v1 is still intact
|
||||
content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(content)).To(Equal("#!/bin/sh\necho v1"))
|
||||
|
||||
// Verify metadata still says v1
|
||||
metaData, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "metadata.json"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
var meta BackendMetadata
|
||||
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
||||
Expect(meta.Version).To(Equal("1.0.0"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -391,6 +391,10 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOllamaRoutes(e, requestExtractor, application)
|
||||
if application.ApplicationConfig().OllamaAPIRootEndpoint {
|
||||
routes.RegisterOllamaRootEndpoint(e)
|
||||
}
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
|
||||
@@ -956,8 +956,7 @@ parameters:
 		It("returns the models list", func() {
 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
-			// A model called "bert" can be present in the model directory depending on the order of the tests
-			Expect(len(models.Models)).To(BeNumerically(">=", 8))
+			Expect(len(models.Models)).To(BeNumerically(">=", 7))
 		})
 		It("can generate completions via ggml", func() {
 			if runtime.GOOS != "linux" {
@@ -979,6 +978,42 @@ parameters:
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not duplicate the first content token in streaming chat completions", Label("llama-gguf", "llama-gguf-stream"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
var contentParts []string
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if len(chunk.Choices) > 0 {
|
||||
delta := chunk.Choices[0].Delta.Content
|
||||
if delta != "" {
|
||||
contentParts = append(contentParts, delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Expect(contentParts).ToNot(BeEmpty(), "Expected streaming content tokens")
|
||||
// The first content token should appear exactly once.
|
||||
// A bug in grpc-server.cpp caused the role-init array element
|
||||
// to get the same ChatDelta stamped, duplicating the first token.
|
||||
if len(contentParts) >= 2 {
|
||||
Expect(contentParts[0]).ToNot(Equal(contentParts[1]),
|
||||
"First content token was duplicated: %v", contentParts[:2])
|
||||
}
|
||||
})
|
||||
|
||||
It("returns logprobs in chat completions when requested", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test only on linux")
|
||||
|
||||
@@ -15,23 +15,31 @@ import (
 	"github.com/mudler/xlog"
 )

+// UpgradeInfoProvider is an interface for querying cached backend upgrade information.
+type UpgradeInfoProvider interface {
+	GetAvailableUpgrades() map[string]gallery.UpgradeInfo
+	TriggerCheck()
+}
+
 type BackendEndpointService struct {
 	galleries         []config.Gallery
 	backendPath       string
 	backendSystemPath string
 	backendApplier    *galleryop.GalleryService
+	upgradeChecker    UpgradeInfoProvider
 }

 type GalleryBackend struct {
 	ID string `json:"id"`
 }

-func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService {
+func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService, upgradeChecker UpgradeInfoProvider) BackendEndpointService {
 	return BackendEndpointService{
 		galleries:         galleries,
 		backendPath:       systemState.Backend.BackendsPath,
 		backendSystemPath: systemState.Backend.BackendsSystemPath,
 		backendApplier:    backendApplier,
+		upgradeChecker:    upgradeChecker,
 	}
 }

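The `UpgradeInfoProvider` interface above only states that results are cached and that a check can be triggered. One way an implementation could look is sketched below, under stated assumptions: `cachedChecker`, `newCachedChecker`, and `runOnce` are hypothetical, `UpgradeInfo` is a trimmed local stand-in for `gallery.UpgradeInfo`, and the refresh is driven synchronously here so the example stays deterministic (a real provider would presumably run it in a background goroutine).

```go
package main

import (
	"fmt"
	"sync"
)

// UpgradeInfo is a trimmed local stand-in for gallery.UpgradeInfo.
type UpgradeInfo struct {
	BackendName      string
	AvailableVersion string
}

// cachedChecker is a hypothetical provider that serves results from a cache
// and refreshes them when a check has been requested.
type cachedChecker struct {
	mu      sync.RWMutex
	cache   map[string]UpgradeInfo
	trigger chan struct{}
	check   func() map[string]UpgradeInfo
}

func newCachedChecker(check func() map[string]UpgradeInfo) *cachedChecker {
	return &cachedChecker{
		cache:   map[string]UpgradeInfo{},
		trigger: make(chan struct{}, 1), // capacity 1 coalesces repeated requests
		check:   check,
	}
}

// GetAvailableUpgrades returns a copy of the cached results.
func (c *cachedChecker) GetAvailableUpgrades() map[string]UpgradeInfo {
	c.mu.RLock()
	defer c.mu.RUnlock()
	out := make(map[string]UpgradeInfo, len(c.cache))
	for k, v := range c.cache {
		out[k] = v
	}
	return out
}

// TriggerCheck requests a refresh without blocking the caller.
func (c *cachedChecker) TriggerCheck() {
	select {
	case c.trigger <- struct{}{}:
	default: // a check is already pending
	}
}

// runOnce waits for one pending request and refreshes the cache.
func (c *cachedChecker) runOnce() {
	<-c.trigger
	fresh := c.check()
	c.mu.Lock()
	c.cache = fresh
	c.mu.Unlock()
}

func main() {
	c := newCachedChecker(func() map[string]UpgradeInfo {
		return map[string]UpgradeInfo{"my-backend": {BackendName: "my-backend", AvailableVersion: "2.0.0"}}
	})
	c.TriggerCheck()
	fmt.Println("right after trigger:", len(c.GetAvailableUpgrades()))
	c.runOnce()
	fmt.Println("after refresh:      ", len(c.GetAvailableUpgrades()))
}
```

With a design like this, a handler that triggers a check and reads the cache immediately afterwards can still return the previous results, which is the behaviour the comment in `CheckUpgradesEndpoint` describes.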
@@ -146,6 +154,62 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpgradesEndpoint returns the cached backend upgrade information
|
||||
// @Summary Get available backend upgrades
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||
// @Router /backends/upgrades [get]
|
||||
func (mgs *BackendEndpointService) GetUpgradesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if mgs.upgradeChecker == nil {
|
||||
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||
}
|
||||
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUpgradesEndpoint forces an immediate upgrade check
|
||||
// @Summary Force backend upgrade check
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]gallery.UpgradeInfo "Response"
|
||||
// @Router /backends/upgrades/check [post]
|
||||
func (mgs *BackendEndpointService) CheckUpgradesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if mgs.upgradeChecker == nil {
|
||||
return c.JSON(200, map[string]gallery.UpgradeInfo{})
|
||||
}
|
||||
mgs.upgradeChecker.TriggerCheck()
|
||||
// Return current cached results (the triggered check runs async)
|
||||
return c.JSON(200, mgs.upgradeChecker.GetAvailableUpgrades())
|
||||
}
|
||||
}
|
||||
|
||||
// UpgradeBackendEndpoint triggers an upgrade for a specific backend
|
||||
// @Summary Upgrade a backend
|
||||
// @Tags backends
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/upgrade/{name} [post]
|
||||
func (mgs *BackendEndpointService) UpgradeBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: backendName,
|
||||
Galleries: mgs.galleries,
|
||||
Upgrade: true,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI
|
||||
// @Summary List all available Backends
|
||||
// @Tags backends
|
||||
|
||||
@@ -180,27 +180,39 @@ func PatchConfigEndpoint(cl *config.ModelConfigLoader, _ *model.ModelLoader, app
 			return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid JSON: " + err.Error()})
 		}

-		existingJSON, err := json.Marshal(modelConfig)
+		// Read the raw YAML from disk rather than serializing the in-memory config.
+		// The in-memory config has SetDefaults() applied, which would persist
+		// runtime-only defaults (top_p, temperature, mirostat, etc.) to the file.
+		configPath := modelConfig.GetModelConfigFile()
+		if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
+			return c.JSON(http.StatusForbidden, map[string]any{"error": "config path not trusted: " + err.Error()})
+		}
+
+		diskYAML, err := os.ReadFile(configPath)
 		if err != nil {
-			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal existing config"})
+			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to read config file: " + err.Error()})
 		}

 		var existingMap map[string]any
-		if err := json.Unmarshal(existingJSON, &existingMap); err != nil {
-			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to parse existing config"})
+		if err := yaml.Unmarshal(diskYAML, &existingMap); err != nil {
+			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to parse existing config: " + err.Error()})
 		}
+		if existingMap == nil {
+			existingMap = map[string]any{}
+		}

 		if err := mergo.Merge(&existingMap, patchMap, mergo.WithOverride); err != nil {
 			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to merge configs: " + err.Error()})
 		}

-		mergedJSON, err := json.Marshal(existingMap)
+		// Marshal once and reuse for both validation and writing
+		yamlData, err := yaml.Marshal(existingMap)
 		if err != nil {
-			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal merged config"})
+			return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal YAML"})
|
||||
}
|
||||
|
||||
var updatedConfig config.ModelConfig
|
||||
if err := json.Unmarshal(mergedJSON, &updatedConfig); err != nil {
|
||||
if err := yaml.Unmarshal(yamlData, &updatedConfig); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "merged config is invalid: " + err.Error()})
|
||||
}
|
||||
|
||||
@@ -212,16 +224,6 @@ func PatchConfigEndpoint(cl *config.ModelConfigLoader, _ *model.ModelLoader, app
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": errMsg})
|
||||
}
|
||||
|
||||
configPath := modelConfig.GetModelConfigFile()
|
||||
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return c.JSON(http.StatusForbidden, map[string]any{"error": "config path not trusted: " + err.Error()})
|
||||
}
|
||||
|
||||
yamlData, err := yaml.Marshal(updatedConfig)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal YAML"})
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to write config file"})
|
||||
}
|
||||
|
||||
@@ -239,5 +239,54 @@ backend: llama-cpp
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(data)).To(ContainSubstring("vllm"))
|
||||
})
|
||||
|
||||
It("should not persist runtime defaults (SetDefaults values) to disk", func() {
|
||||
// Create a minimal pipeline config - no sampling params
|
||||
seedConfig := `name: gpt-realtime
|
||||
pipeline:
|
||||
vad: silero-vad
|
||||
transcription: whisper-base
|
||||
llm: llama3
|
||||
tts: piper
|
||||
`
|
||||
configPath := filepath.Join(tempDir, "gpt-realtime.yaml")
|
||||
Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
// PATCH with a small change to the pipeline
|
||||
body := bytes.NewBufferString(`{"pipeline": {"tts": "vibevoice"}}`)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/gpt-realtime", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// Read the file from disk and verify no spurious defaults leaked
|
||||
data, err := os.ReadFile(configPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
fileContent := string(data)
|
||||
|
||||
// The patched value should be present
|
||||
Expect(fileContent).To(ContainSubstring("vibevoice"))
|
||||
|
||||
// Runtime-only defaults from SetDefaults() should NOT be in the file
|
||||
Expect(fileContent).NotTo(ContainSubstring("top_p"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("top_k"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("temperature"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("mirostat"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("mmap"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("mmlock"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("threads"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("low_vram"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("embeddings"))
|
||||
Expect(fileContent).NotTo(ContainSubstring("f16"))
|
||||
|
||||
// Original fields should still be present
|
||||
Expect(fileContent).To(ContainSubstring("gpt-realtime"))
|
||||
Expect(fileContent).To(ContainSubstring("silero-vad"))
|
||||
Expect(fileContent).To(ContainSubstring("whisper-base"))
|
||||
Expect(fileContent).To(ContainSubstring("llama3"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
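The PATCH hunk above merges the request body into the map parsed from the on-disk YAML, so only keys the user wrote or patched get persisted. The merge step can be sketched with a hand-rolled recursive helper. This is a minimal illustration that assumes nested objects are `map[string]any`; `mergeOverride` is a hypothetical stand-in, not mergo's actual implementation.

```go
package main

import "fmt"

// mergeOverride copies every key of patch into dst.
// When both sides hold a nested map for the same key, it recurses so that
// sibling keys in dst survive; otherwise the patch value replaces the old one.
// Hypothetical helper for illustration only.
func mergeOverride(dst, patch map[string]any) {
	for k, v := range patch {
		patchChild, patchIsMap := v.(map[string]any)
		dstChild, dstIsMap := dst[k].(map[string]any)
		if patchIsMap && dstIsMap {
			mergeOverride(dstChild, patchChild)
			continue
		}
		dst[k] = v
	}
}

func main() {
	// Stand-in for the map parsed from the config file on disk.
	onDisk := map[string]any{
		"name": "gpt-realtime",
		"pipeline": map[string]any{
			"llm": "llama3",
			"tts": "piper",
		},
	}
	// Stand-in for the decoded PATCH body.
	patch := map[string]any{
		"pipeline": map[string]any{"tts": "vibevoice"},
	}

	mergeOverride(onDisk, patch)
	fmt.Println(onDisk)
}
```

Because the starting point is the raw file content rather than the loaded struct, no defaulted sampling fields can appear in the merged map unless the patch adds them.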
@@ -1,6 +1,8 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -37,7 +39,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := backend.Detection(image, ml, appConfig, *cfg)
|
||||
res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -46,12 +48,18 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
Detections: make([]schema.Detection, len(res.Detections)),
|
||||
}
|
||||
for i, detection := range res.Detections {
|
||||
var mask string
|
||||
if len(detection.Mask) > 0 {
|
||||
mask = base64.StdEncoding.EncodeToString(detection.Mask)
|
||||
}
|
||||
response.Detections[i] = schema.Detection{
|
||||
X: detection.X,
|
||||
Y: detection.Y,
|
||||
Width: detection.Width,
|
||||
Height: detection.Height,
|
||||
ClassName: detection.ClassName,
|
||||
X: detection.X,
|
||||
Y: detection.Y,
|
||||
Width: detection.Width,
|
||||
Height: detection.Height,
|
||||
ClassName: detection.ClassName,
|
||||
Confidence: detection.Confidence,
|
||||
Mask: mask,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
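The detection hunk above only base64-encodes a mask when the backend returned one, so detections without a mask carry an empty string. A minimal sketch of that step follows; `encodeMask` is a hypothetical name, and the raw mask is assumed to be an opaque byte slice.

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// encodeMask returns the standard base64 encoding of mask,
// or an empty string when there is no mask data.
// Hypothetical helper for illustration only.
func encodeMask(mask []byte) string {
	if len(mask) == 0 {
		return ""
	}
	return base64.StdEncoding.EncodeToString(mask)
}

func main() {
	fmt.Printf("%q\n", encodeMask([]byte("hi")))
	fmt.Printf("%q\n", encodeMask(nil))
}
```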
@@ -119,48 +119,20 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
// Detect format once and reuse for both typed and map parsing
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
var modelConfig config.ModelConfig
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
isJSON := strings.Contains(contentType, "application/json") ||
|
||||
(!strings.Contains(contentType, "yaml") && len(trimmed) > 0 && trimmed[0] == '{')
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
var modelConfig config.ModelConfig
|
||||
if isJSON {
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse JSON: " + err.Error()})
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: "Failed to parse YAML: " + err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,10 +145,9 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
modelConfig.SetDefaults(appConfig.ToConfigLoaderOptions()...)
|
||||
|
||||
// Validate the configuration
|
||||
// Validate without calling SetDefaults() — runtime defaults should not
|
||||
// be persisted to disk. SetDefaults() is called when loading configs
|
||||
// for inference via LoadModelConfigsFromPath().
|
||||
if valid, _ := modelConfig.Validate(); !valid {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
@@ -195,8 +166,21 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Marshal to YAML for storage
|
||||
yamlData, err := yaml.Marshal(&modelConfig)
|
||||
// Write only the user-provided fields to disk by parsing the original
|
||||
// body into a map (not the typed struct, which includes Go zero values).
|
||||
var bodyMap map[string]any
|
||||
if isJSON {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
} else {
|
||||
_ = yaml.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
|
||||
var yamlData []byte
|
||||
if bodyMap != nil {
|
||||
yamlData, err = yaml.Marshal(bodyMap)
|
||||
} else {
|
||||
yamlData, err = yaml.Marshal(&modelConfig)
|
||||
}
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
|
||||
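The import hunk above decides once whether the body is JSON and reuses that decision for both the typed parse and the map parse. The predicate can be isolated as below; `looksLikeJSON` is a hypothetical name, and the logic mirrors the `isJSON` expression in the diff.

```go
package main

import (
	"fmt"
	"strings"
)

// looksLikeJSON reports whether a request body should be parsed as JSON.
// An explicit JSON content type wins; otherwise, unless the content type
// mentions YAML, a body whose first non-space byte is '{' is treated as JSON.
// Hypothetical helper for illustration only.
func looksLikeJSON(contentType string, body []byte) bool {
	trimmed := strings.TrimSpace(string(body))
	return strings.Contains(contentType, "application/json") ||
		(!strings.Contains(contentType, "yaml") && len(trimmed) > 0 && trimmed[0] == '{')
}

func main() {
	fmt.Println(looksLikeJSON("application/json", []byte(`{"name":"x"}`)))
	fmt.Println(looksLikeJSON("", []byte("  {\"name\":\"x\"}")))
	fmt.Println(looksLikeJSON("application/x-yaml", []byte("{name: x}")))
	fmt.Println(looksLikeJSON("", []byte("name: x")))
}
```

Trimming before indexing also keeps a whitespace-only body from being indexed at position 0.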
144
core/http/endpoints/localai/pin_model.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TogglePinnedModelEndpoint handles pinning or unpinning a model.
|
||||
// Pinned models are excluded from idle unloading, LRU eviction, and memory-pressure eviction.
|
||||
//
|
||||
// @Summary Toggle model pinned status
|
||||
// @Description Pin or unpin a model. Pinned models stay loaded and are excluded from automatic eviction.
|
||||
// @Tags config
|
||||
// @Param name path string true "Model name"
|
||||
// @Param action path string true "Action: 'pin' or 'unpin'"
|
||||
// @Success 200 {object} ModelResponse
|
||||
// @Failure 400 {object} ModelResponse
|
||||
// @Failure 404 {object} ModelResponse
|
||||
// @Failure 500 {object} ModelResponse
|
||||
// @Router /api/models/toggle-pinned/{name}/{action} [put]
|
||||
func TogglePinnedModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, syncPinnedFn func()) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||
modelName = decoded
|
||||
}
|
||||
if modelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
})
|
||||
}
|
||||
|
||||
action := c.Param("action")
|
||||
if action != "pin" && action != "unpin" {
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Action must be 'pin' or 'unpin'",
|
||||
})
|
||||
}
|
||||
|
||||
// Get existing model config
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
})
|
||||
}
|
||||
|
||||
// Get the config file path
|
||||
configPath := modelConfig.GetModelConfigFile()
|
||||
if configPath == "" {
|
||||
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration file not found",
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the path is trusted
|
||||
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return c.JSON(http.StatusForbidden, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Read the existing config file
|
||||
configData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Parse the YAML config as a generic map to preserve all fields
|
||||
var configMap map[string]interface{}
|
||||
if err := yaml.Unmarshal(configData, &configMap); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Update the pinned field
|
||||
pinned := action == "pin"
|
||||
if pinned {
|
||||
configMap["pinned"] = true
|
||||
} else {
|
||||
// Remove the pinned key entirely when unpinning (clean YAML)
|
||||
delete(configMap, "pinned")
|
||||
}
|
||||
|
||||
// Marshal back to YAML
|
||||
updatedData, err := yaml.Marshal(configMap)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to serialize configuration: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Write updated config back to file
|
||||
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Reload model configurations from disk
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Sync pinned models to the watchdog
|
||||
if syncPinnedFn != nil {
|
||||
syncPinnedFn()
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Model '%s' has been %sned successfully.", modelName, action)
|
||||
if pinned {
|
||||
msg += " The model will be excluded from automatic eviction."
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, ModelResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Filename: configPath,
|
||||
})
|
||||
}
|
||||
}
|
||||
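The pin handler above edits the config as a generic map: pinning sets `pinned: true`, and unpinning deletes the key so the YAML stays clean. That map update can be sketched on its own, without the file I/O. `setPinned` is a hypothetical helper; the real handler validates the action before touching the file.

```go
package main

import (
	"errors"
	"fmt"
)

// setPinned applies a "pin" or "unpin" action to a config map.
// Pinning stores pinned=true; unpinning removes the key entirely
// instead of writing pinned=false.
// Hypothetical helper for illustration only.
func setPinned(cfg map[string]any, action string) error {
	switch action {
	case "pin":
		cfg["pinned"] = true
	case "unpin":
		delete(cfg, "pinned")
	default:
		return errors.New("action must be 'pin' or 'unpin'")
	}
	return nil
}

func main() {
	cfg := map[string]any{"name": "llama3"}

	_ = setPinned(cfg, "pin")
	fmt.Println(cfg)

	_ = setPinned(cfg, "unpin")
	fmt.Println(cfg)
}
```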
148
core/http/endpoints/localai/toggle_model.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ToggleStateModelEndpoint handles enabling or disabling a model from being loaded on demand.
|
||||
// When disabled, the model remains in the collection but will not be loaded when requested.
|
||||
//
|
||||
// @Summary Toggle model enabled/disabled status
|
||||
// @Description Enable or disable a model from being loaded on demand. Disabled models remain installed but cannot be loaded.
|
||||
// @Tags config
|
||||
// @Param name path string true "Model name"
|
||||
// @Param action path string true "Action: 'enable' or 'disable'"
|
||||
// @Success 200 {object} ModelResponse
|
||||
// @Failure 400 {object} ModelResponse
|
||||
// @Failure 404 {object} ModelResponse
|
||||
// @Failure 500 {object} ModelResponse
|
||||
// @Router /api/models/{name}/{action} [put]
|
||||
func ToggleStateModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||
modelName = decoded
|
||||
}
|
||||
if modelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
})
|
||||
}
|
||||
|
||||
action := c.Param("action")
|
||||
if action != "enable" && action != "disable" {
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Action must be 'enable' or 'disable'",
|
||||
})
|
||||
}
|
||||
|
||||
// Get existing model config
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
})
|
||||
}
|
||||
|
||||
// Get the config file path
|
||||
configPath := modelConfig.GetModelConfigFile()
|
||||
if configPath == "" {
|
||||
return c.JSON(http.StatusNotFound, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration file not found",
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the path is trusted
|
||||
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return c.JSON(http.StatusForbidden, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Read the existing config file
|
||||
configData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Parse the YAML config as a generic map to preserve all fields
|
||||
var configMap map[string]interface{}
|
||||
if err := yaml.Unmarshal(configData, &configMap); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Update the disabled field
|
||||
disabled := action == "disable"
|
||||
if disabled {
|
||||
configMap["disabled"] = true
|
||||
} else {
|
||||
// Remove the disabled key entirely when enabling (clean YAML)
|
||||
delete(configMap, "disabled")
|
||||
}
|
||||
|
||||
// Marshal back to YAML
|
||||
updatedData, err := yaml.Marshal(configMap)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to serialize configuration: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Write updated config back to file
|
||||
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Reload model configurations from disk
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// If disabling, also shutdown the model if it's currently running
|
||||
if disabled {
|
||||
if err := ml.ShutdownModel(modelName); err != nil {
|
||||
// Log but don't fail - the config was saved successfully
|
||||
fmt.Printf("Warning: Failed to shutdown model '%s' during disable: %v\n", modelName, err)
|
||||
}
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Model '%s' has been %sd successfully.", modelName, action)
|
||||
if disabled {
|
||||
msg += " The model will not be loaded on demand until re-enabled."
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, ModelResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Filename: configPath,
|
||||
})
|
||||
}
|
||||
}
|
||||
153
core/http/endpoints/ollama/chat.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ChatEndpoint handles Ollama-compatible /api/chat requests
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaChatRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
// Apply Ollama options to config
|
||||
applyOllamaOptions(input.Options, cfg)
|
||||
|
||||
// Convert Ollama messages to OpenAI format
|
||||
openAIMessages := ollamaMessagesToOpenAI(input.Messages)
|
||||
|
||||
// Build an OpenAI-compatible request
|
||||
openAIReq := &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
|
||||
},
|
||||
Messages: openAIMessages,
|
||||
Stream: input.IsStream(),
|
||||
Context: input.Context,
|
||||
Cancel: input.Cancel,
|
||||
}
|
||||
|
||||
if input.Options != nil {
|
||||
openAIReq.Temperature = input.Options.Temperature
|
||||
openAIReq.TopP = input.Options.TopP
|
||||
openAIReq.TopK = input.Options.TopK
|
||||
openAIReq.RepeatPenalty = input.Options.RepeatPenalty
|
||||
if input.Options.NumPredict != nil {
|
||||
openAIReq.Maxtokens = input.Options.NumPredict
|
||||
}
|
||||
if len(input.Options.Stop) > 0 {
|
||||
openAIReq.Stop = input.Options.Stop
|
||||
}
|
||||
}
|
||||
|
||||
predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, nil, false)
|
||||
xlog.Debug("Ollama Chat - Prompt (after templating)", "prompt_len", len(predInput))
|
||||
|
||||
if input.IsStream() {
|
||||
return handleOllamaChatStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
|
||||
return handleOllamaChatNonStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
}
|
||||
|
||||
func handleOllamaChatNonStream(c echo.Context, input *schema.OllamaChatRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
startTime := time.Now()
|
||||
var result string
|
||||
|
||||
cb := func(s string, choices *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama chat inference failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{
|
||||
Role: "assistant",
|
||||
Content: result,
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
func handleOllamaChatStream(c echo.Context, input *schema.OllamaChatRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||
chunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{
|
||||
Role: "assistant",
|
||||
Content: token,
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
return writeNDJSON(c, chunk)
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, choices *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama chat stream inference failed", "error", err)
|
||||
errChunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "error",
|
||||
}
|
||||
writeNDJSON(c, errChunk)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send final done message
|
||||
totalDuration := time.Since(startTime)
|
||||
finalChunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{Role: "assistant", Content: ""},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
writeNDJSON(c, finalChunk)
|
||||
|
||||
return nil
|
||||
}
|
||||
67
core/http/endpoints/ollama/embed.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// EmbedEndpoint handles Ollama-compatible /api/embed and /api/embeddings requests
|
||||
func EmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaEmbedRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
inputStrings := input.GetInputStrings()
|
||||
if len(inputStrings) == 0 {
|
||||
return ollamaError(c, 400, "input is required")
|
||||
}
|
||||
|
||||
var allEmbeddings [][]float32
|
||||
promptEvalCount := 0
|
||||
|
||||
for _, s := range inputStrings {
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama embed failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("embedding failed: %v", err))
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
xlog.Error("Ollama embed computation failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("embedding computation failed: %v", err))
|
||||
}
|
||||
|
||||
allEmbeddings = append(allEmbeddings, embeddings)
|
||||
// Rough token count estimate (len/4 heuristic, not an actual tokenizer count)
|
||||
promptEvalCount += len(s) / 4
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaEmbedResponse{
|
||||
Model: input.Model,
|
||||
Embeddings: allEmbeddings,
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: promptEvalCount,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
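The embed handler above reports `PromptEvalCount` as a rough estimate, summing `len(s) / 4` over the inputs instead of asking a tokenizer. A small sketch of that arithmetic follows; `estimateTokens` is a hypothetical name. Note that `len` counts bytes and the division truncates, so short inputs contribute zero.

```go
package main

import "fmt"

// estimateTokens sums len(s)/4 over all inputs.
// This is a byte-length heuristic with integer division,
// not a count produced by a tokenizer.
// Hypothetical helper for illustration only.
func estimateTokens(inputs []string) int {
	total := 0
	for _, s := range inputs {
		total += len(s) / 4
	}
	return total
}

func main() {
	fmt.Println(estimateTokens([]string{"abcdefgh", "abc"}))
}
```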
179
core/http/endpoints/ollama/generate.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// GenerateEndpoint handles Ollama-compatible /api/generate requests
|
||||
func GenerateEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaGenerateRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
// Handle empty prompt — return immediately with "load" reason
|
||||
if input.Prompt == "" {
|
||||
resp := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: "",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if input.IsStream() {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
writeNDJSON(c, resp)
|
||||
return nil
|
||||
}
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
applyOllamaOptions(input.Options, cfg)
|
||||
|
||||
// Build messages from prompt
|
||||
var messages []schema.Message
|
||||
if input.System != "" {
|
||||
messages = append(messages, schema.Message{
|
||||
Role: "system",
|
||||
StringContent: input.System,
|
||||
Content: input.System,
|
||||
})
|
||||
}
|
||||
messages = append(messages, schema.Message{
|
||||
Role: "user",
|
||||
StringContent: input.Prompt,
|
||||
Content: input.Prompt,
|
||||
})
|
||||
|
||||
openAIReq := &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
|
||||
},
|
||||
Messages: messages,
|
||||
Stream: input.IsStream(),
|
||||
Context: input.Ctx,
|
||||
Cancel: input.Cancel,
|
||||
}
|
||||
|
||||
if input.Options != nil {
|
||||
openAIReq.Temperature = input.Options.Temperature
|
||||
openAIReq.TopP = input.Options.TopP
|
||||
openAIReq.TopK = input.Options.TopK
|
||||
openAIReq.RepeatPenalty = input.Options.RepeatPenalty
|
||||
if input.Options.NumPredict != nil {
|
||||
openAIReq.Maxtokens = input.Options.NumPredict
|
||||
}
|
||||
if len(input.Options.Stop) > 0 {
|
||||
openAIReq.Stop = input.Options.Stop
|
||||
}
|
||||
}
|
||||
|
||||
var predInput string
|
||||
if input.Raw {
|
||||
// Raw mode: skip chat template, use prompt directly
|
||||
predInput = input.Prompt
|
||||
} else {
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, nil, false)
|
||||
}
|
||||
xlog.Debug("Ollama Generate - Prompt", "prompt_len", len(predInput), "raw", input.Raw)
|
||||
|
||||
if input.IsStream() {
|
||||
return handleOllamaGenerateStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
|
||||
return handleOllamaGenerateNonStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
}
|
||||
|
||||
func handleOllamaGenerateNonStream(c echo.Context, input *schema.OllamaGenerateRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
startTime := time.Now()
|
||||
var result string
|
||||
|
||||
cb := func(s string, choices *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama generate inference failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: result,
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
func handleOllamaGenerateStream(c echo.Context, input *schema.OllamaGenerateRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||
chunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: token,
|
||||
Done: false,
|
||||
}
|
||||
return writeNDJSON(c, chunk)
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, choices *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama generate stream inference failed", "error", err)
|
||||
errChunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "error",
|
||||
}
|
||||
writeNDJSON(c, errChunk)
|
||||
return nil
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
finalChunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: "",
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
writeNDJSON(c, finalChunk)
|
||||
|
||||
return nil
|
||||
}
|
||||
83
core/http/endpoints/ollama/helpers.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// writeNDJSON writes v as a single JSON line (NDJSON) to the response and flushes it; it returns false if marshalling or writing fails
|
||||
func writeNDJSON(c echo.Context, v any) bool {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal NDJSON", "error", err)
|
||||
return false
|
||||
}
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "%s\n", data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Response().Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
// ollamaError sends an Ollama-compatible JSON error response
|
||||
func ollamaError(c echo.Context, statusCode int, message string) error {
|
||||
return c.JSON(statusCode, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
// applyOllamaOptions applies Ollama options to the model configuration
|
||||
func applyOllamaOptions(opts *schema.OllamaOptions, cfg *config.ModelConfig) {
|
||||
if opts == nil {
|
||||
return
|
||||
}
|
||||
if opts.Temperature != nil {
|
||||
cfg.Temperature = opts.Temperature
|
||||
}
|
||||
if opts.TopP != nil {
|
||||
cfg.TopP = opts.TopP
|
||||
}
|
||||
if opts.TopK != nil {
|
||||
cfg.TopK = opts.TopK
|
||||
}
|
||||
if opts.NumPredict != nil {
|
||||
cfg.Maxtokens = opts.NumPredict
|
||||
}
|
||||
if opts.RepeatPenalty != 0 {
|
||||
cfg.RepeatPenalty = opts.RepeatPenalty
|
||||
}
|
||||
if opts.RepeatLastN != 0 {
|
||||
cfg.RepeatLastN = opts.RepeatLastN
|
||||
}
|
||||
if len(opts.Stop) > 0 {
|
||||
cfg.StopWords = append(cfg.StopWords, opts.Stop...)
|
||||
}
|
||||
if opts.NumCtx > 0 {
|
||||
cfg.ContextSize = &opts.NumCtx
|
||||
}
|
||||
}
|
||||
|
||||
// ollamaMessagesToOpenAI converts Ollama messages to OpenAI-compatible messages
|
||||
func ollamaMessagesToOpenAI(messages []schema.OllamaMessage) []schema.Message {
|
||||
var result []schema.Message
|
||||
for _, msg := range messages {
|
||||
openAIMsg := schema.Message{
|
||||
Role: msg.Role,
|
||||
StringContent: msg.Content,
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
// Convert base64 images to data URIs (the MIME type is always assumed to be image/png)
|
||||
for _, img := range msg.Images {
|
||||
dataURI := fmt.Sprintf("data:image/png;base64,%s", img)
|
||||
openAIMsg.StringImages = append(openAIMsg.StringImages, dataURI)
|
||||
}
|
||||
|
||||
result = append(result, openAIMsg)
|
||||
}
|
||||
return result
|
||||
}
|
||||
142
core/http/endpoints/ollama/models.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
const ollamaCompatVersion = "0.9.0"
|
||||
|
||||
// ListModelsEndpoint handles Ollama-compatible GET /api/tags
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelNames, err := galleryop.ListModels(bcl, ml, nil, galleryop.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
return ollamaError(c, 500, fmt.Sprintf("failed to list models: %v", err))
|
||||
}
|
||||
|
||||
var models []schema.OllamaModelEntry
|
||||
for _, name := range modelNames {
|
||||
ollamaName := name
|
||||
if !strings.Contains(ollamaName, ":") {
|
||||
ollamaName += ":latest"
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(name)))
|
||||
|
||||
entry := schema.OllamaModelEntry{
|
||||
Name: ollamaName,
|
||||
Model: ollamaName,
|
||||
ModifiedAt: time.Now().UTC(),
|
||||
Size: 0,
|
||||
Digest: digest,
|
||||
Details: modelDetailsFromConfig(bcl, name),
|
||||
}
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.OllamaListResponse{Models: models})
|
||||
}
|
||||
}
|
||||
|
||||
// ShowModelEndpoint handles Ollama-compatible POST /api/show
|
||||
func ShowModelEndpoint(bcl *config.ModelConfigLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.OllamaShowRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return ollamaError(c, 400, "invalid request body")
|
||||
}
|
||||
|
||||
name := req.Name
|
||||
if name == "" {
|
||||
name = req.Model
|
||||
}
|
||||
if name == "" {
|
||||
return ollamaError(c, 400, "name is required")
|
||||
}
|
||||
|
||||
// Strip the tag suffix (e.g. ":latest") for config lookup
|
||||
configName := strings.Split(name, ":")[0]
|
||||
|
||||
cfg, exists := bcl.GetModelConfig(configName)
|
||||
if !exists {
|
||||
return ollamaError(c, 404, fmt.Sprintf("model '%s' not found", name))
|
||||
}
|
||||
|
||||
resp := schema.OllamaShowResponse{
|
||||
Modelfile: fmt.Sprintf("FROM %s", cfg.Model),
|
||||
Parameters: "",
|
||||
Template: cfg.TemplateConfig.Chat,
|
||||
Details: modelDetailsFromModelConfig(&cfg),
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ListRunningEndpoint handles Ollama-compatible GET /api/ps
|
||||
func ListRunningEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
loadedModels := ml.ListLoadedModels()
|
||||
|
||||
var models []schema.OllamaPsEntry
|
||||
for _, m := range loadedModels {
|
||||
name := m.ID
|
||||
ollamaName := name
|
||||
if !strings.Contains(ollamaName, ":") {
|
||||
ollamaName += ":latest"
|
||||
}
|
||||
|
||||
entry := schema.OllamaPsEntry{
|
||||
Name: ollamaName,
|
||||
Model: ollamaName,
|
||||
Size: 0,
|
||||
Digest: fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(name))),
|
||||
Details: modelDetailsFromConfig(bcl, name),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
|
||||
SizeVRAM: 0,
|
||||
}
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.OllamaPsResponse{Models: models})
|
||||
}
|
||||
}
|
||||
|
||||
// VersionEndpoint handles Ollama-compatible GET /api/version
|
||||
func VersionEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, schema.OllamaVersionResponse{Version: ollamaCompatVersion})
|
||||
}
|
||||
}
|
||||
|
||||
// HeartbeatEndpoint handles the Ollama root health check
|
||||
func HeartbeatEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.String(200, "Ollama is running")
|
||||
}
|
||||
}
|
||||
|
||||
func modelDetailsFromConfig(bcl *config.ModelConfigLoader, name string) schema.OllamaModelDetails {
|
||||
configName := strings.Split(name, ":")[0]
|
||||
cfg, exists := bcl.GetModelConfig(configName)
|
||||
if !exists {
|
||||
return schema.OllamaModelDetails{}
|
||||
}
|
||||
return modelDetailsFromModelConfig(&cfg)
|
||||
}
|
||||
|
||||
func modelDetailsFromModelConfig(cfg *config.ModelConfig) schema.OllamaModelDetails {
|
||||
return schema.OllamaModelDetails{
|
||||
Format: "gguf",
|
||||
Family: cfg.Backend,
|
||||
}
|
||||
}
|
||||
62
core/http/endpoints/ollama/models_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package ollama_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/ollama"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestOllamaEndpoints(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Ollama Endpoints Suite")
|
||||
}
|
||||
|
||||
var _ = Describe("Ollama endpoint handlers", func() {
|
||||
var e *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
})
|
||||
|
||||
Describe("HeartbeatEndpoint", func() {
|
||||
It("returns 'Ollama is running' on GET /", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.HeartbeatEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(Equal("Ollama is running"))
|
||||
})
|
||||
|
||||
It("returns 200 on HEAD /", func() {
|
||||
req := httptest.NewRequest(http.MethodHead, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.HeartbeatEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("VersionEndpoint", func() {
|
||||
It("returns a JSON object with version field", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.VersionEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(ContainSubstring(`"version"`))
|
||||
Expect(rec.Body.String()).To(MatchRegexp(`\d+\.\d+\.\d+`))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -265,55 +265,52 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
lastEmittedCount = len(partialResults)
|
||||
}
|
||||
} else {
|
||||
// Try JSON tool call parsing for streaming
|
||||
// Check if the result looks like JSON tool calls
|
||||
// Try JSON tool call parsing for streaming.
|
||||
// Only emit NEW tool calls (same guard as XML parser above).
|
||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||
if jsonErr == nil && len(jsonResults) > 0 {
|
||||
// Check if these are tool calls (have "name" and optionally "arguments")
|
||||
for _, jsonObj := range jsonResults {
|
||||
if name, ok := jsonObj["name"].(string); ok && name != "" {
|
||||
// This looks like a tool call
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||
jsonObj := jsonResults[i]
|
||||
name, ok := jsonObj["name"].(string)
|
||||
if !ok || name == "" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
// Emit tool call
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: lastEmittedCount,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
select {
|
||||
case responses <- initialMessage:
|
||||
default:
|
||||
}
|
||||
lastEmittedCount++
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
}
|
||||
lastEmittedCount = len(jsonResults)
|
||||
}
|
||||
}
|
||||
return true
|
||||
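The guard introduced in this hunk (emit only results at index `lastEmittedCount` and beyond, then advance the counter) can be shown in isolation. This is a simplified sketch with hypothetical names, assuming the parser re-parses the accumulated text on every token and therefore returns the full list each time; it uses plain strings in place of parsed tool calls.

```go
package main

import "fmt"

// emitter remembers how many parsed items have already been sent.
type emitter struct {
	emitted int
}

// newItems returns only the items not sent before and advances the counter.
func (e *emitter) newItems(parsed []string) []string {
	if len(parsed) <= e.emitted {
		return nil
	}
	fresh := parsed[e.emitted:]
	e.emitted = len(parsed)
	return fresh
}

func main() {
	e := &emitter{}
	// Each call receives the full list parsed so far.
	fmt.Println(e.newItems([]string{"get_weather"}))
	fmt.Println(e.newItems([]string{"get_weather"}))
	fmt.Println(e.newItems([]string{"get_weather", "get_time"}))
}
```

Without the counter, every token callback would resend every tool call parsed so far, which matches the duplicate-emission problem the new comment in the hunk describes.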
@@ -426,10 +423,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
toolCallID = id
|
||||
}
|
||||
|
||||
if i < lastEmittedCount {
|
||||
// Already emitted during streaming by the incremental
|
||||
// JSON/XML parser — skip to avoid duplicate tool calls.
|
||||
continue
|
||||
}
|
||||
|
||||
// Tool call not yet emitted — send name + args (two chunks).
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
@@ -454,7 +458,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
@@ -1045,6 +1049,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else if deltaContent := functions.ContentFromChatDeltas(chatDeltas); len(chatDeltas) > 0 && deltaContent != "" {
|
||||
// ChatDeltas have content but no tool calls — model answered without using tools.
|
||||
// This happens with thinking models (e.g. Gemma 4) where the Go-side reasoning
|
||||
// extraction misclassifies clean content as reasoning, leaving cbRawResult empty.
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser content (no tool calls)", "content_len", len(deltaContent))
|
||||
textContentToReturn = deltaContent
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
@@ -1067,7 +1078,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
// Use textContentToReturn if available (e.g. from ChatDeltas),
|
||||
// otherwise fall back to cbRawResult for legacy Go-side parsing.
|
||||
questionInput := cbRawResult
|
||||
if textContentToReturn != "" {
|
||||
questionInput = textContentToReturn
|
||||
}
|
||||
qResult, qErr := handleQuestion(config, funcResults, questionInput, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
}
|
||||
|
||||
@@ -143,12 +143,15 @@ func ComputeChoices(
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.).
|
||||
// When the C++ autoparser is active, it clears the raw response
|
||||
// and delivers data via ChatDeltas. If the response is empty but
|
||||
// ChatDeltas contain actionable data, skip the caller retry —
|
||||
// the autoparser already parsed the response successfully.
|
||||
// When the C++ autoparser is active, it may deliver parsed data
|
||||
// via ChatDeltas while also keeping the raw response. If ChatDeltas
|
||||
// contain actionable data (content or tool calls), skip the caller
|
||||
// retry — the autoparser already parsed the response successfully.
|
||||
// Note: we check ChatDeltas regardless of whether Response is empty,
|
||||
// because thinking models (e.g. Gemma 4) produce a non-empty Response
|
||||
// that the Go-side reasoning extraction can misclassify as reasoning-only.
|
||||
skipCallerRetry := false
|
||||
if strings.TrimSpace(prediction.Response) == "" && len(prediction.ChatDeltas) > 0 {
|
||||
if len(prediction.ChatDeltas) > 0 {
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
skipCallerRetry = true
|
||||
|
||||
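The revised condition amounts to a predicate over the deltas: skip the caller-driven retry as soon as any delta carries content or a tool call, and allow it when the deltas hold reasoning only. A minimal standalone sketch follows, using hypothetical stand-in types rather than the real `pb.ChatDelta`.

```go
package main

import "fmt"

// toolCall and chatDelta are simplified stand-ins for the protobuf types used
// in the diff; only the fields the check reads are modelled.
type toolCall struct {
	Name string
}

type chatDelta struct {
	Content          string
	ReasoningContent string
	ToolCalls        []toolCall
}

// hasActionableDelta reports whether any delta carries content or tool calls.
// Reasoning-only deltas do not count, so a retry would still be allowed.
func hasActionableDelta(deltas []chatDelta) bool {
	for _, d := range deltas {
		if d.Content != "" || len(d.ToolCalls) > 0 {
			return true
		}
	}
	return false
}

func main() {
	reasoningOnly := []chatDelta{{ReasoningContent: "thinking..."}}
	withContent := []chatDelta{{ReasoningContent: "thinking..."}, {Content: "answer"}}
	fmt.Println(hasActionableDelta(reasoningOnly), hasActionableDelta(withContent))
}
```

These two cases correspond to the two specs in the accompanying test file: reasoning-only deltas let the retry proceed, while deltas with content keep the first attempt.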
@@ -242,11 +242,13 @@ var _ = Describe("ComputeChoices", func() {
|
||||
})
|
||||
|
||||
Context("chat deltas from latest attempt", func() {
|
||||
It("should return chat deltas from the last attempt only", func() {
|
||||
It("should return chat deltas from the last attempt when retry is allowed", func() {
|
||||
// When the first attempt has only reasoning (no content/tool calls),
|
||||
// the caller-driven retry proceeds and we get deltas from the last attempt.
|
||||
mockInference([]backend.LLMResponse{
|
||||
{
|
||||
Response: "retry-me",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "old"}},
|
||||
ChatDeltas: []*pb.ChatDelta{{ReasoningContent: "thinking..."}},
|
||||
},
|
||||
{
|
||||
Response: "final",
|
||||
@@ -266,6 +268,40 @@ var _ = Describe("ComputeChoices", func() {
|
||||
Expect(deltas).To(HaveLen(1))
|
||||
Expect(deltas[0].Content).To(Equal("new"))
|
||||
})
|
||||
|
||||
It("should keep first attempt deltas when ChatDeltas have content (skip retry)", func() {
|
||||
// When the first attempt has content in ChatDeltas, skipCallerRetry
|
||||
// prevents the retry — the autoparser already parsed successfully.
|
||||
mockInference([]backend.LLMResponse{
|
||||
{
|
||||
Response: "autoparser-content",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "first-content"}},
|
||||
},
|
||||
{
|
||||
Response: "should-not-reach",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "second-content"}},
|
||||
},
|
||||
})
|
||||
|
||||
retryRequested := false
|
||||
_, _, deltas, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
retryRequested = true
|
||||
return true
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(retryRequested).To(BeFalse(),
|
||||
"shouldRetry should not be called when ChatDeltas have content")
|
||||
Expect(deltas).To(HaveLen(1))
|
||||
Expect(deltas[0].Content).To(Equal("first-content"))
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
Context("result choices cleared on retry", func() {
|
||||
|
||||
@@ -167,6 +167,17 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the model is disabled
|
||||
if cfg != nil && cfg.IsDisabled() {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName),
|
||||
Code: http.StatusForbidden,
|
||||
Type: "model_disabled",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
|
||||
@@ -5,7 +5,16 @@
|
||||
"": {
|
||||
"name": "localai-react-ui",
|
||||
"dependencies": {
|
||||
"@codemirror/autocomplete": "^6.18.6",
|
||||
"@codemirror/commands": "^6.8.1",
|
||||
"@codemirror/lang-yaml": "^6.1.2",
|
||||
"@codemirror/language": "^6.11.0",
|
||||
"@codemirror/lint": "^6.8.5",
|
||||
"@codemirror/search": "^6.5.10",
|
||||
"@codemirror/state": "^6.5.2",
|
||||
"@codemirror/view": "^6.36.8",
|
||||
"@fortawesome/fontawesome-free": "^6.7.2",
|
||||
"@lezer/highlight": "^1.2.1",
|
||||
"@modelcontextprotocol/ext-apps": "^1.2.2",
|
||||
"@modelcontextprotocol/sdk": "^1.25.1",
|
||||
"dompurify": "^3.2.5",
|
||||
@@ -14,6 +23,7 @@
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-router-dom": "^7.6.1",
|
||||
"yaml": "^2.8.3",
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.27.0",
|
||||
@@ -66,6 +76,22 @@
|
||||
|
||||
"@babel/types": ["@babel/types@7.29.0", "", { "dependencies": { "@babel/helper-string-parser": "^7.27.1", "@babel/helper-validator-identifier": "^7.28.5" } }, "sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A=="],
|
||||
|
||||
"@codemirror/autocomplete": ["@codemirror/autocomplete@6.20.1", "", { "dependencies": { "@codemirror/language": "^6.0.0", "@codemirror/state": "^6.0.0", "@codemirror/view": "^6.17.0", "@lezer/common": "^1.0.0" } }, "sha512-1cvg3Vz1dSSToCNlJfRA2WSI4ht3K+WplO0UMOgmUYPivCyy2oueZY6Lx7M9wThm7SDUBViRmuT+OG/i8+ON9A=="],
|
||||
|
||||
"@codemirror/commands": ["@codemirror/commands@6.10.3", "", { "dependencies": { "@codemirror/language": "^6.0.0", "@codemirror/state": "^6.6.0", "@codemirror/view": "^6.27.0", "@lezer/common": "^1.1.0" } }, "sha512-JFRiqhKu+bvSkDLI+rUhJwSxQxYb759W5GBezE8Uc8mHLqC9aV/9aTC7yJSqCtB3F00pylrLCwnyS91Ap5ej4Q=="],
|
||||
|
||||
"@codemirror/lang-yaml": ["@codemirror/lang-yaml@6.1.3", "", { "dependencies": { "@codemirror/autocomplete": "^6.0.0", "@codemirror/language": "^6.0.0", "@codemirror/state": "^6.0.0", "@lezer/common": "^1.2.0", "@lezer/highlight": "^1.2.0", "@lezer/lr": "^1.0.0", "@lezer/yaml": "^1.0.0" } }, "sha512-AZ8DJBuXGVHybpBQhmZtgew5//4hv3tdkXnr3vDmOUMJRuB6vn/uuwtmTOTlqEaQFg3hQSVeA90NmvIQyUV6FQ=="],
|
||||
|
||||
"@codemirror/language": ["@codemirror/language@6.12.3", "", { "dependencies": { "@codemirror/state": "^6.0.0", "@codemirror/view": "^6.23.0", "@lezer/common": "^1.5.0", "@lezer/highlight": "^1.0.0", "@lezer/lr": "^1.0.0", "style-mod": "^4.0.0" } }, "sha512-QwCZW6Tt1siP37Jet9Tb02Zs81TQt6qQrZR2H+eGMcFsL1zMrk2/b9CLC7/9ieP1fjIUMgviLWMmgiHoJrj+ZA=="],
|
||||
|
||||
"@codemirror/lint": ["@codemirror/lint@6.9.5", "", { "dependencies": { "@codemirror/state": "^6.0.0", "@codemirror/view": "^6.35.0", "crelt": "^1.0.5" } }, "sha512-GElsbU9G7QT9xXhpUg1zWGmftA/7jamh+7+ydKRuT0ORpWS3wOSP0yT1FOlIZa7mIJjpVPipErsyvVqB9cfTFA=="],
|
||||
|
||||
"@codemirror/search": ["@codemirror/search@6.6.0", "", { "dependencies": { "@codemirror/state": "^6.0.0", "@codemirror/view": "^6.37.0", "crelt": "^1.0.5" } }, "sha512-koFuNXcDvyyotWcgOnZGmY7LZqEOXZaaxD/j6n18TCLx2/9HieZJ5H6hs1g8FiRxBD0DNfs0nXn17g872RmYdw=="],
|
||||
|
||||
"@codemirror/state": ["@codemirror/state@6.6.0", "", { "dependencies": { "@marijn/find-cluster-break": "^1.0.0" } }, "sha512-4nbvra5R5EtiCzr9BTHiTLc+MLXK2QGiAVYMyi8PkQd3SR+6ixar/Q/01Fa21TBIDOZXgeWV4WppsQolSreAPQ=="],
|
||||
|
||||
"@codemirror/view": ["@codemirror/view@6.40.0", "", { "dependencies": { "@codemirror/state": "^6.6.0", "crelt": "^1.0.6", "style-mod": "^4.1.0", "w3c-keyname": "^2.2.4" } }, "sha512-WA0zdU7xfF10+5I3HhUUq3kqOx3KjqmtQ9lqZjfK7jtYk4G72YW9rezcSywpaUMCWOMlq+6E0pO1IWg1TNIhtg=="],
|
||||
|
||||
"@esbuild/aix-ppc64": ["@esbuild/aix-ppc64@0.25.12", "", { "os": "aix", "cpu": "ppc64" }, "sha512-Hhmwd6CInZ3dwpuGTF8fJG6yoWmsToE+vYgD4nytZVxcu1ulHpUQRAB1UJ8+N1Am3Mz4+xOByoQoSZf4D+CpkA=="],
|
||||
|
||||
"@esbuild/android-arm": ["@esbuild/android-arm@0.25.12", "", { "os": "android", "cpu": "arm" }, "sha512-VJ+sKvNA/GE7Ccacc9Cha7bpS8nyzVv0jdVgwNDaR4gDMC/2TTRc33Ip8qrNYUcpkOHUT5OZ0bUcNNVZQ9RLlg=="],
|
||||
@@ -158,6 +184,16 @@
|
||||
|
||||
"@jridgewell/trace-mapping": ["@jridgewell/trace-mapping@0.3.31", "", { "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", "@jridgewell/sourcemap-codec": "^1.4.14" } }, "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw=="],
|
||||
|
||||
"@lezer/common": ["@lezer/common@1.5.1", "", {}, "sha512-6YRVG9vBkaY7p1IVxL4s44n5nUnaNnGM2/AckNgYOnxTG2kWh1vR8BMxPseWPjRNpb5VtXnMpeYAEAADoRV1Iw=="],
|
||||
|
||||
"@lezer/highlight": ["@lezer/highlight@1.2.3", "", { "dependencies": { "@lezer/common": "^1.3.0" } }, "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g=="],
|
||||
|
||||
"@lezer/lr": ["@lezer/lr@1.4.8", "", { "dependencies": { "@lezer/common": "^1.0.0" } }, "sha512-bPWa0Pgx69ylNlMlPvBPryqeLYQjyJjqPx+Aupm5zydLIF3NE+6MMLT8Yi23Bd9cif9VS00aUebn+6fDIGBcDA=="],
|
||||
|
||||
"@lezer/yaml": ["@lezer/yaml@1.0.4", "", { "dependencies": { "@lezer/common": "^1.2.0", "@lezer/highlight": "^1.0.0", "@lezer/lr": "^1.4.0" } }, "sha512-2lrrHqxalACEbxIbsjhqGpSW8kWpUKuY6RHgnSAFZa6qK62wvnPxA8hGOwOoDbwHcOFs5M4o27mjGu+P7TvBmw=="],
|
||||
|
||||
"@marijn/find-cluster-break": ["@marijn/find-cluster-break@1.0.2", "", {}, "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g=="],
|
||||
|
||||
"@modelcontextprotocol/ext-apps": ["@modelcontextprotocol/ext-apps@1.2.2", "", { "peerDependencies": { "@modelcontextprotocol/sdk": "^1.24.0", "react": "^17.0.0 || ^18.0.0 || ^19.0.0", "react-dom": "^17.0.0 || ^18.0.0 || ^19.0.0", "zod": "^3.25.0 || ^4.0.0" }, "optionalPeers": ["react", "react-dom"] }, "sha512-qMnhIKb8tyPesl+kZU76Xz9Bi9putCO+LcgvBJ00fDdIniiLZsnQbAeTKoq+sTiYH1rba2Fvj8NPAFxij+gyxw=="],
|
||||
|
||||
"@modelcontextprotocol/sdk": ["@modelcontextprotocol/sdk@1.27.1", "", { "dependencies": { "@hono/node-server": "^1.19.9", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", "cors": "^2.8.5", "cross-spawn": "^7.0.5", "eventsource": "^3.0.2", "eventsource-parser": "^3.0.0", "express": "^5.2.1", "express-rate-limit": "^8.2.1", "hono": "^4.11.4", "jose": "^6.1.3", "json-schema-typed": "^8.0.2", "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", "zod": "^3.25 || ^4.0", "zod-to-json-schema": "^3.25.1" }, "peerDependencies": { "@cfworker/json-schema": "^4.1.1" }, "optionalPeers": ["@cfworker/json-schema"] }, "sha512-sr6GbP+4edBwFndLbM60gf07z0FQ79gaExpnsjMGePXqFcSSb7t6iscpjk9DhFhwd+mTEQrzNafGP8/iGGFYaA=="],
|
||||
@@ -286,6 +322,8 @@
|
||||
|
||||
"cors": ["cors@2.8.6", "", { "dependencies": { "object-assign": "^4", "vary": "^1" } }, "sha512-tJtZBBHA6vjIAaF6EnIaq6laBBP9aq/Y3ouVJjEfoHbRBcHBAHYcMh/w8LDrk2PvIMMq8gmopa5D4V8RmbrxGw=="],
|
||||
|
||||
"crelt": ["crelt@1.0.6", "", {}, "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g=="],
|
||||
|
||||
"cross-spawn": ["cross-spawn@7.0.6", "", { "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", "which": "^2.0.1" } }, "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA=="],
|
||||
|
||||
"debug": ["debug@4.4.3", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA=="],
|
||||
@@ -572,6 +610,8 @@
|
||||
|
||||
"strip-json-comments": ["strip-json-comments@3.1.1", "", {}, "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig=="],
|
||||
|
||||
"style-mod": ["style-mod@4.1.3", "", {}, "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ=="],
|
||||
|
||||
"supports-color": ["supports-color@7.2.0", "", { "dependencies": { "has-flag": "^4.0.0" } }, "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw=="],
|
||||
|
||||
"tinyglobby": ["tinyglobby@0.2.15", "", { "dependencies": { "fdir": "^6.5.0", "picomatch": "^4.0.3" } }, "sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ=="],
|
||||
@@ -592,6 +632,8 @@
|
||||
|
||||
"vite": ["vite@6.4.1", "", { "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.4.4", "picomatch": "^4.0.2", "postcss": "^8.5.3", "rollup": "^4.34.9", "tinyglobby": "^0.2.13" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", "jiti": ">=1.21.0", "less": "*", "lightningcss": "^1.21.0", "sass": "*", "sass-embedded": "*", "stylus": "*", "sugarss": "*", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "jiti", "less", "lightningcss", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g=="],
|
||||
|
||||
"w3c-keyname": ["w3c-keyname@2.2.8", "", {}, "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ=="],
|
||||
|
||||
"which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="],
|
||||
|
||||
"word-wrap": ["word-wrap@1.2.5", "", {}, "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA=="],
|
||||
@@ -600,6 +642,8 @@
|
||||
|
||||
"yallist": ["yallist@3.1.1", "", {}, "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="],
|
||||
|
||||
"yaml": ["yaml@2.8.3", "", { "bin": { "yaml": "bin.mjs" } }, "sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg=="],
|
||||
|
||||
"yocto-queue": ["yocto-queue@0.1.0", "", {}, "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q=="],
|
||||
|
||||
"zod": ["zod@4.3.6", "", {}, "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg=="],
|
||||
|
||||
191
core/http/react-ui/e2e/model-config.spec.js
Normal file
@@ -0,0 +1,191 @@
import { test, expect } from '@playwright/test'

const MOCK_METADATA = {
  sections: [
    { id: 'general', label: 'General', icon: 'settings', order: 0 },
    { id: 'parameters', label: 'Parameters', icon: 'sliders', order: 20 },
  ],
  fields: [
    { path: 'name', yaml_key: 'name', go_type: 'string', ui_type: 'string', section: 'general', label: 'Model Name', description: 'Unique identifier for this model', component: 'input', order: 0 },
    { path: 'backend', yaml_key: 'backend', go_type: 'string', ui_type: 'string', section: 'general', label: 'Backend', description: 'Inference backend to use', component: 'select', autocomplete_provider: 'backends', order: 10 },
    { path: 'context_size', yaml_key: 'context_size', go_type: '*int', ui_type: 'int', section: 'general', label: 'Context Size', description: 'Maximum context window in tokens', component: 'number', vram_impact: true, order: 20 },
    { path: 'cuda', yaml_key: 'cuda', go_type: 'bool', ui_type: 'bool', section: 'general', label: 'CUDA', description: 'Enable CUDA GPU acceleration', component: 'toggle', order: 30 },
    { path: 'parameters.temperature', yaml_key: 'temperature', go_type: '*float64', ui_type: 'float', section: 'parameters', label: 'Temperature', description: 'Sampling temperature', component: 'slider', min: 0, max: 2, step: 0.1, order: 0 },
    { path: 'parameters.top_p', yaml_key: 'top_p', go_type: '*float64', ui_type: 'float', section: 'parameters', label: 'Top P', description: 'Nucleus sampling threshold', component: 'slider', min: 0, max: 1, step: 0.05, order: 10 },
  ],
}

// Mock raw YAML (what the edit endpoint returns) - only fields actually in the file
const MOCK_YAML = `name: mock-model
backend: mock-backend
parameters:
  model: mock-model.bin
`

const MOCK_AUTOCOMPLETE_BACKENDS = { values: ['mock-backend', 'llama-cpp', 'vllm'] }

test.describe('Model Editor - Interactive Tab', () => {
  test.beforeEach(async ({ page }) => {
    // Mock config metadata
    await page.route('**/api/models/config-metadata*', (route) => {
      route.fulfill({
        contentType: 'application/json',
        body: JSON.stringify(MOCK_METADATA),
      })
    })

    // Mock raw YAML edit endpoint (GET for loading, POST for saving)
    await page.route('**/api/models/edit/mock-model', (route) => {
      if (route.request().method() === 'POST') {
        route.fulfill({
          contentType: 'application/json',
          body: JSON.stringify({ message: 'Configuration file saved' }),
        })
      } else {
        route.fulfill({
          contentType: 'application/json',
          body: JSON.stringify({ config: MOCK_YAML, name: 'mock-model' }),
        })
      }
    })

    // Mock PATCH config-json for interactive save
    await page.route('**/api/models/config-json/mock-model', (route) => {
      if (route.request().method() === 'PATCH') {
        route.fulfill({
          contentType: 'application/json',
          body: JSON.stringify({ success: true, message: "Model 'mock-model' updated successfully" }),
        })
      } else {
        route.fulfill({ contentType: 'application/json', body: '{}' })
      }
    })

    // Mock autocomplete for backends
    await page.route('**/api/models/config-metadata/autocomplete/backends', (route) => {
      route.fulfill({
        contentType: 'application/json',
        body: JSON.stringify(MOCK_AUTOCOMPLETE_BACKENDS),
      })
    })

    await page.goto('/app/model-editor/mock-model')
    // Wait for the page to load
    await expect(page.locator('h1', { hasText: 'Model Editor' })).toBeVisible({ timeout: 10_000 })
  })

  test('page loads and shows model name in header', async ({ page }) => {
    await expect(page.locator('text=mock-model')).toBeVisible()
    await expect(page.locator('h1', { hasText: 'Model Editor' })).toBeVisible()
  })

  test('interactive tab is active by default', async ({ page }) => {
    // The field browser should be visible (interactive tab content)
    await expect(page.locator('input[placeholder="Search fields to add..."]')).toBeVisible()
  })

  test('existing config fields from YAML are populated', async ({ page }) => {
    // The mock YAML has name and backend - they should be active fields
    await expect(page.locator('text=Model Name')).toBeVisible()
    await expect(page.locator('span', { hasText: /^Backend$/ }).first()).toBeVisible()
  })

  test('section sidebar shows sections with active fields', async ({ page }) => {
    const sidebar = page.locator('nav')
    await expect(sidebar.locator('text=General')).toBeVisible()
  })

  test('typing in field browser shows matching fields', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Temperature')
    await expect(page.locator('text=Temperature').first()).toBeVisible()
  })

  test('clicking a field result adds it to the config', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Temperature')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Temperature' }).first().click()
    await expect(page.locator('h3', { hasText: 'Parameters' })).toBeVisible()
  })

  test('toggle field renders a toggle switch', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('CUDA')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'CUDA' }).first().click()
    await expect(page.locator('text=CUDA').first()).toBeVisible()
    const cudaSection = page.locator('div', { has: page.locator('span', { hasText: /^CUDA$/ }) }).first()
    await expect(cudaSection.locator('input[type="checkbox"]')).toHaveCount(1)
  })

  test('number field renders a numeric input', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Context Size')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Context Size' }).first().click()
    await expect(page.locator('input[type="number"]')).toBeVisible()
  })

  test('changing a field value enables the Save button', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Context Size')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Context Size' }).first().click()
    const numberInput = page.locator('input[type="number"]')
    await numberInput.fill('4096')
    await expect(page.locator('button', { hasText: 'Save Changes' })).toBeVisible()
  })

  test('removing a field with X button removes it from the form', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Temperature')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Temperature' }).first().click()
    const paramsHeader = page.locator('h3', { hasText: 'Parameters' })
    await expect(paramsHeader).toBeVisible()
    const paramsSection = paramsHeader.locator('..')
    await paramsSection.locator('button[title="Remove field"]').first().click()
    await expect(paramsHeader).not.toBeVisible()
  })

  test('save sends PATCH and shows success toast', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Context Size')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Context Size' }).first().click()
    const numberInput = page.locator('input[type="number"]')
    await numberInput.fill('8192')
    await page.locator('button', { hasText: 'Save Changes' }).click()
    await expect(page.locator('text=Configuration saved')).toBeVisible({ timeout: 5_000 })
  })

  test('added field is no longer shown in field browser results', async ({ page }) => {
    const searchInput = page.locator('input[placeholder="Search fields to add..."]')
    await searchInput.fill('Temperature')
    const dropdown = searchInput.locator('..').locator('..')
    await dropdown.locator('div', { hasText: 'Temperature' }).first().click()
    await searchInput.fill('Temperature')
    await page.waitForTimeout(200)
    const results = dropdown.locator('div[style*="cursor: pointer"]', { hasText: 'Temperature' })
    await expect(results).toHaveCount(0)
  })

  test('switching to YAML tab shows code editor', async ({ page }) => {
    await page.locator('button', { hasText: 'YAML' }).click()
    // The CodeMirror editor should be visible
    await expect(page.locator('.cm-editor').first()).toBeVisible()
    // The field browser should NOT be visible
    await expect(page.locator('input[placeholder="Search fields to add..."]')).not.toBeVisible()
  })

  test('switching back to Interactive tab restores fields', async ({ page }) => {
    // Go to YAML tab
    await page.locator('button', { hasText: 'YAML' }).click()
    await expect(page.locator('input[placeholder="Search fields to add..."]')).not.toBeVisible()
    // Go back to Interactive tab
    await page.locator('button', { hasText: 'Interactive' }).click()
    await expect(page.locator('input[placeholder="Search fields to add..."]')).toBeVisible()
    await expect(page.locator('text=Model Name')).toBeVisible()
  })
})
@@ -20,7 +20,17 @@
|
||||
"dompurify": "^3.2.5",
|
||||
"@fortawesome/fontawesome-free": "^6.7.2",
|
||||
"@modelcontextprotocol/sdk": "^1.25.1",
|
||||
"@modelcontextprotocol/ext-apps": "^1.2.2"
|
||||
"@modelcontextprotocol/ext-apps": "^1.2.2",
|
||||
"yaml": "^2.8.3",
|
||||
"@codemirror/autocomplete": "^6.18.6",
|
||||
"@codemirror/commands": "^6.8.1",
|
||||
"@codemirror/lang-yaml": "^6.1.2",
|
||||
"@codemirror/language": "^6.11.0",
|
||||
"@codemirror/lint": "^6.8.5",
|
||||
"@codemirror/search": "^6.5.10",
|
||||
"@codemirror/state": "^6.5.2",
|
||||
"@codemirror/view": "^6.36.8",
|
||||
"@lezer/highlight": "^1.2.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitejs/plugin-react": "^4.5.2",
|
||||
|
||||
@@ -904,39 +904,16 @@
|
||||
box-shadow: 0 0 0 2px var(--color-primary-light);
|
||||
}
|
||||
|
||||
/* Code editor (syntax-highlighted textarea overlay) */
|
||||
.code-editor-highlight .hljs {
|
||||
background: transparent;
|
||||
padding: 0;
|
||||
/* CodeMirror editor wrapper */
|
||||
.code-editor-cm .cm-editor {
|
||||
border: 1px solid var(--color-border-default);
|
||||
border-radius: var(--radius-md);
|
||||
}
|
||||
.code-editor-wrapper textarea:focus {
|
||||
.code-editor-cm .cm-editor.cm-focused {
|
||||
border-color: var(--color-border-strong);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
/* highlight.js YAML syntax colours – dark theme */
|
||||
[data-theme="dark"] .hljs-attr { color: #7dd3fc; }
|
||||
[data-theme="dark"] .hljs-string { color: #6ee7b7; }
|
||||
[data-theme="dark"] .hljs-number { color: #fcd34d; }
|
||||
[data-theme="dark"] .hljs-literal { color: #f9a8d4; }
|
||||
[data-theme="dark"] .hljs-keyword { color: #c4b5fd; }
|
||||
[data-theme="dark"] .hljs-comment { color: #64748b; font-style: italic; }
|
||||
[data-theme="dark"] .hljs-meta { color: #94a3b8; }
|
||||
[data-theme="dark"] .hljs-bullet { color: #38bdf8; }
|
||||
[data-theme="dark"] .hljs-section { color: #a78bfa; font-weight: 600; }
|
||||
[data-theme="dark"] .hljs-type { color: #f472b6; }
|
||||
|
||||
/* highlight.js YAML syntax colours – light theme */
|
||||
[data-theme="light"] .hljs-attr { color: #0369a1; }
|
||||
[data-theme="light"] .hljs-string { color: #15803d; }
|
||||
[data-theme="light"] .hljs-number { color: #b45309; }
|
||||
[data-theme="light"] .hljs-literal { color: #be185d; }
|
||||
[data-theme="light"] .hljs-keyword { color: #7c3aed; }
|
||||
[data-theme="light"] .hljs-comment { color: #94a3b8; font-style: italic; }
|
||||
[data-theme="light"] .hljs-meta { color: #64748b; }
|
||||
[data-theme="light"] .hljs-bullet { color: #0284c7; }
|
||||
[data-theme="light"] .hljs-section { color: #6d28d9; font-weight: 600; }
|
||||
[data-theme="light"] .hljs-type { color: #db2777; }
|
||||
|
||||
/* Form groups */
|
||||
.form-group {
|
||||
margin-bottom: var(--spacing-md);
|
||||
@@ -1959,6 +1936,56 @@
|
||||
40% { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* Staging progress indicator (replaces thinking dots during model transfer) */
|
||||
.chat-staging-progress {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
min-width: 200px;
|
||||
max-width: 320px;
|
||||
}
|
||||
.chat-staging-label {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
.chat-staging-label i {
|
||||
color: var(--color-primary);
|
||||
}
|
||||
.chat-staging-detail {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.chat-staging-bar-container {
|
||||
flex: 1;
|
||||
height: 4px;
|
||||
background: var(--color-bg-tertiary);
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.chat-staging-bar {
|
||||
height: 100%;
|
||||
background: var(--color-primary);
|
||||
border-radius: 2px;
|
||||
transition: width 300ms ease;
|
||||
}
|
||||
.chat-staging-pct {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted);
|
||||
min-width: 32px;
|
||||
text-align: right;
|
||||
}
|
||||
.chat-staging-file {
|
||||
font-size: 0.7rem;
|
||||
color: var(--color-text-muted);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* Message completion flash */
|
||||
.chat-message-bubble {
|
||||
transition: border-color 300ms ease;
|
||||
|
||||
138
core/http/react-ui/src/components/AutocompleteInput.jsx
Normal file
@@ -0,0 +1,138 @@
import { useState, useEffect, useRef, useCallback } from 'react'
import { useAutocomplete } from '../hooks/useAutocomplete'

export default function AutocompleteInput({ value, onChange, provider, placeholder = 'Type or select...', style }) {
  const { values, loading } = useAutocomplete(provider)
  const [query, setQuery] = useState('')
  const [open, setOpen] = useState(false)
  const [focusIndex, setFocusIndex] = useState(-1)
  const wrapperRef = useRef(null)
  const listRef = useRef(null)

  useEffect(() => {
    setQuery(value || '')
  }, [value])

  useEffect(() => {
    const handler = (e) => {
      if (wrapperRef.current && !wrapperRef.current.contains(e.target)) setOpen(false)
    }
    document.addEventListener('mousedown', handler)
    return () => document.removeEventListener('mousedown', handler)
  }, [])

  const filtered = values.filter(v =>
    v.toLowerCase().includes(query.toLowerCase())
  )

  const enterTargetIndex = focusIndex >= 0 ? focusIndex
    : filtered.length > 0 ? 0
    : -1

  const commit = useCallback((val) => {
    setQuery(val)
    onChange(val)
    setOpen(false)
    setFocusIndex(-1)
  }, [onChange])

  const handleKeyDown = (e) => {
    if (!open && (e.key === 'ArrowDown' || e.key === 'ArrowUp')) {
      setOpen(true)
      return
    }
    if (!open && e.key === 'Enter') {
      e.preventDefault()
      commit(query)
      return
    }
    if (!open) return

    if (e.key === 'ArrowDown') {
      e.preventDefault()
      setFocusIndex(i => Math.min(i + 1, filtered.length - 1))
    } else if (e.key === 'ArrowUp') {
      e.preventDefault()
      setFocusIndex(i => Math.max(i - 1, 0))
    } else if (e.key === 'Enter') {
      e.preventDefault()
      if (enterTargetIndex >= 0) {
        commit(filtered[enterTargetIndex])
      } else {
        commit(query)
      }
    } else if (e.key === 'Escape') {
      setOpen(false)
      setFocusIndex(-1)
    }
  }

  useEffect(() => {
    if (focusIndex >= 0 && listRef.current) {
      const item = listRef.current.children[focusIndex]
      if (item) item.scrollIntoView({ block: 'nearest' })
    }
  }, [focusIndex])

  return (
    <div ref={wrapperRef} style={{ position: 'relative', ...style }}>
      <input
        className="input"
        aria-haspopup="listbox"
        aria-expanded={open}
        value={query}
        onChange={(e) => {
          setQuery(e.target.value)
          setOpen(true)
          setFocusIndex(-1)
          onChange(e.target.value)
        }}
        onFocus={() => setOpen(true)}
        onKeyDown={handleKeyDown}
        placeholder={loading ? 'Loading...' : placeholder}
        style={{ width: '100%', fontSize: '0.8125rem' }}
      />
      {open && !loading && filtered.length > 0 && (
        <div
          ref={listRef}
          role="listbox"
          style={{
            position: 'absolute', top: '100%', left: 0, right: 0, zIndex: 50,
            maxHeight: 220, overflowY: 'auto', marginTop: 2,
            background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)',
            borderRadius: 'var(--radius-md)', boxShadow: 'var(--shadow-md)',
            animation: 'dropdownIn 120ms ease-out',
          }}
        >
          {filtered.map((v, i) => {
            const isEnterTarget = i === enterTargetIndex
            return (
              <div
                key={v}
                role="option"
                aria-selected={v === value}
                style={{
                  padding: '6px 10px', fontSize: '0.8125rem', cursor: 'pointer',
                  display: 'flex', alignItems: 'center', gap: '6px',
                  color: v === value ? 'var(--color-primary)' : 'var(--color-text-primary)',
                  fontWeight: v === value ? 600 : 400,
                  background: (i === focusIndex || isEnterTarget) ? 'var(--color-bg-tertiary)' : 'transparent',
                }}
                onMouseEnter={() => setFocusIndex(i)}
                onMouseDown={(e) => {
                  e.preventDefault()
                  commit(v)
                }}
              >
                <span style={{ flex: 1, overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>{v}</span>
                {isEnterTarget && (
                  <span style={{ color: 'var(--color-text-muted)', fontSize: '0.75rem', flexShrink: 0 }}>↵</span>
                )}
              </div>
            )
          })}
        </div>
      )}
    </div>
  )
}
|
||||
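Review note: the Enter-key behaviour in `AutocompleteInput.jsx` above (commit the focused option, else the first match, else the raw query) can be checked in isolation as a pure function. The sketch below is illustrative only. `resolveEnterTarget` is a hypothetical name that does not appear in this diff, and the bounds guard on the index is an addition of the sketch, not something the component does.

```javascript
// Hypothetical helper (not in the diff): returns the value that would be
// committed when Enter is pressed while the dropdown is open.
function resolveEnterTarget(values, query, focusIndex) {
  const needle = query.toLowerCase()
  // Same case-insensitive substring filter the component applies
  const filtered = values.filter(v => v.toLowerCase().includes(needle))
  // Focused row wins; otherwise fall back to the first match, if any
  const target = focusIndex >= 0 ? focusIndex : filtered.length > 0 ? 0 : -1
  // Assumption of this sketch: out-of-range indexes fall back to the raw query
  return target >= 0 && target < filtered.length ? filtered[target] : query
}

const backends = ['mock-backend', 'llama-cpp', 'vllm']
console.log(resolveEnterTarget(backends, 'LL', -1))
console.log(resolveEnterTarget(backends, 'LL', 1))
console.log(resolveEnterTarget(backends, 'custom', -1))
```

Keeping this logic pure would also let it be unit-tested without mounting the component, which might complement the Playwright coverage added in this change.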
@@ -1,111 +1,99 @@
|
||||
import { useRef, useEffect, useCallback } from 'react'
|
||||
import hljs from 'highlight.js/lib/core'
|
||||
import yaml from 'highlight.js/lib/languages/yaml'
|
||||
import { useRef, useMemo } from 'react'
|
||||
import { keymap, lineNumbers, highlightActiveLineGutter, highlightActiveLine, drawSelection } from '@codemirror/view'
|
||||
import { EditorView } from '@codemirror/view'
|
||||
import { EditorState } from '@codemirror/state'
|
||||
import { yaml } from '@codemirror/lang-yaml'
|
||||
import { autocompletion } from '@codemirror/autocomplete'
|
||||
import { linter, lintGutter } from '@codemirror/lint'
|
||||
import { defaultKeymap, history, historyKeymap, indentWithTab } from '@codemirror/commands'
|
||||
import { searchKeymap, highlightSelectionMatches } from '@codemirror/search'
|
||||
import { indentOnInput, indentUnit, bracketMatching, foldGutter, foldKeymap } from '@codemirror/language'
|
||||
import YAML from 'yaml'
|
||||
import { useCodeMirror } from '../hooks/useCodeMirror'
|
||||
import { useTheme } from '../contexts/ThemeContext'
|
||||
import { getThemeExtension } from '../utils/cmTheme'
|
||||
import { createYamlCompletionSource } from '../utils/cmYamlComplete'
|
||||
|
||||
hljs.registerLanguage('yaml', yaml)
|
||||
|
||||
export default function CodeEditor({ value, onChange, disabled, minHeight = '500px' }) {
|
||||
const codeRef = useRef(null)
|
||||
const textareaRef = useRef(null)
|
||||
const preRef = useRef(null)
|
||||
|
||||
const highlight = useCallback(() => {
|
||||
if (!codeRef.current) return
|
||||
const result = hljs.highlight(value + '\n', { language: 'yaml', ignoreIllegals: true })
|
||||
codeRef.current.innerHTML = result.value
|
||||
}, [value])
|
||||
|
||||
useEffect(() => {
|
||||
highlight()
|
||||
}, [highlight])
|
||||
|
||||
const handleScroll = () => {
|
||||
if (preRef.current && textareaRef.current) {
|
||||
preRef.current.scrollTop = textareaRef.current.scrollTop
|
||||
preRef.current.scrollLeft = textareaRef.current.scrollLeft
|
||||
function yamlIssueToDiagnostic(issue, cmDoc, severity) {
|
||||
const len = cmDoc.length
|
||||
if (issue.linePos && issue.linePos[0]) {
|
||||
const startLine = Math.min(issue.linePos[0].line, cmDoc.lines)
|
||||
const from = cmDoc.line(startLine).from + issue.linePos[0].col - 1
|
||||
let to = from + 1
|
||||
if (issue.linePos[1]) {
|
||||
const endLine = Math.min(issue.linePos[1].line, cmDoc.lines)
|
||||
to = cmDoc.line(endLine).from + issue.linePos[1].col - 1
|
||||
}
|
||||
return { from: Math.min(from, len), to: Math.min(Math.max(to, from + 1), len), severity, message: issue.message.split('\n')[0] }
|
||||
}
|
||||
|
||||
const handleKeyDown = (e) => {
|
||||
if (e.key === 'Tab') {
|
||||
e.preventDefault()
|
||||
const ta = e.target
|
||||
const start = ta.selectionStart
|
||||
const end = ta.selectionEnd
|
||||
const newValue = value.substring(0, start) + ' ' + value.substring(end)
|
||||
onChange(newValue)
|
||||
requestAnimationFrame(() => {
|
||||
ta.selectionStart = ta.selectionEnd = start + 2
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="code-editor-wrapper" style={{ position: 'relative', minHeight, fontSize: '0.8125rem' }}>
|
||||
<pre
|
||||
ref={preRef}
|
||||
className="code-editor-highlight"
|
||||
aria-hidden="true"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0, left: 0, right: 0, bottom: 0,
|
||||
margin: 0,
|
||||
padding: 'var(--spacing-sm)',
|
||||
overflow: 'auto',
|
||||
pointerEvents: 'none',
|
||||
fontFamily: "'JetBrains Mono', 'Fira Code', monospace",
|
||||
fontSize: 'inherit',
|
||||
lineHeight: 1.5,
|
||||
tabSize: 2,
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordWrap: 'break-word',
|
||||
background: 'var(--color-bg-tertiary)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
border: '1px solid var(--color-border-default)',
|
||||
}}
|
||||
>
|
||||
<code
|
||||
ref={codeRef}
|
||||
className="language-yaml"
|
||||
style={{
|
||||
fontFamily: 'inherit',
|
||||
fontSize: 'inherit',
|
||||
lineHeight: 'inherit',
|
||||
padding: 0,
|
||||
background: 'transparent',
|
||||
}}
|
||||
/>
|
||||
</pre>
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onScroll={handleScroll}
|
||||
onKeyDown={handleKeyDown}
|
||||
disabled={disabled}
|
||||
spellCheck={false}
|
||||
style={{
|
||||
position: 'relative',
|
||||
width: '100%',
|
||||
minHeight,
|
||||
margin: 0,
|
||||
padding: 'var(--spacing-sm)',
|
||||
fontFamily: "'JetBrains Mono', 'Fira Code', monospace",
|
||||
fontSize: 'inherit',
|
||||
lineHeight: 1.5,
|
||||
tabSize: 2,
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordWrap: 'break-word',
|
||||
color: 'transparent',
|
||||
caretColor: 'var(--color-text-primary)',
|
||||
background: 'transparent',
|
||||
border: '1px solid var(--color-border-default)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
outline: 'none',
|
||||
resize: 'vertical',
|
||||
overflow: 'auto',
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
return { from: 0, to: Math.min(1, len), severity, message: issue.message.split('\n')[0] }
|
||||
}
|
||||
|
||||
const yamlLinter = linter(view => {
|
||||
const text = view.state.doc.toString()
|
||||
if (!text.trim()) return []
|
||||
const parsed = YAML.parseDocument(text, { strict: true, prettyErrors: true })
|
||||
const diagnostics = []
|
||||
for (const err of parsed.errors) {
|
||||
diagnostics.push(yamlIssueToDiagnostic(err, view.state.doc, 'error'))
|
||||
}
|
||||
for (const warn of parsed.warnings) {
|
||||
diagnostics.push(yamlIssueToDiagnostic(warn, view.state.doc, 'warning'))
|
||||
}
|
||||
return diagnostics
|
||||
})
|
||||
|
||||
export default function CodeEditor({ value, onChange, disabled, minHeight = '500px', fields }) {
|
||||
const containerRef = useRef(null)
|
||||
const { theme } = useTheme()
|
||||
|
||||
// Static extensions — only recreate when fields change
|
||||
const extensions = useMemo(() => {
|
||||
const exts = [
|
||||
yaml(),
|
||||
lineNumbers(),
|
||||
highlightActiveLineGutter(),
|
||||
highlightActiveLine(),
|
||||
drawSelection(),
|
||||
foldGutter(),
|
||||
indentOnInput(),
|
||||
bracketMatching(),
|
||||
highlightSelectionMatches(),
|
||||
yamlLinter,
|
||||
lintGutter(),
|
||||
history(),
|
||||
indentUnit.of(' '),
|
||||
EditorState.tabSize.of(2),
|
||||
keymap.of([
|
||||
indentWithTab,
|
||||
...defaultKeymap,
|
||||
...historyKeymap,
|
||||
...searchKeymap,
|
||||
...foldKeymap,
|
||||
]),
|
||||
EditorView.theme({
|
||||
'&': { minHeight },
|
||||
'.cm-scroller': { overflow: 'auto' },
|
||||
}),
|
||||
]
|
||||
|
||||
if (fields && fields.length > 0) {
|
||||
exts.push(autocompletion({
|
||||
override: [createYamlCompletionSource(fields)],
|
||||
activateOnTyping: true,
|
||||
}))
|
||||
}
|
||||
|
||||
return exts
|
||||
}, [minHeight, fields])
|
||||
|
||||
// Dynamic extensions — reconfigured via Compartments (preserves undo/cursor/scroll)
|
||||
const dynamicExtensions = useMemo(() => ({
|
||||
theme: getThemeExtension(theme),
|
||||
readOnly: [EditorState.readOnly.of(!!disabled), EditorView.editable.of(!disabled)],
|
||||
}), [theme, disabled])
|
||||
|
||||
useCodeMirror({ containerRef, value, onChange, extensions, dynamicExtensions })
|
||||
|
||||
return <div ref={containerRef} className="code-editor-cm" />
|
||||
}
|
||||
|
||||
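Review note: `yamlIssueToDiagnostic` above turns the 1-based `{ line, col }` positions reported by the YAML parser into 0-based document offsets, clamped to the document length. A minimal sketch of that conversion without CodeMirror is shown here, assuming `\n` line endings. `linePosToOffset` is a hypothetical helper and is not part of this diff.

```javascript
// Hypothetical helper (not in the diff): convert a 1-based { line, col }
// position into a 0-based character offset within `text`.
function linePosToOffset(text, pos) {
  const lines = text.split('\n')
  // Clamp the line number into the document, as the diff does with cmDoc.lines
  const line = Math.min(Math.max(pos.line, 1), lines.length)
  let offset = 0
  for (let i = 0; i < line - 1; i++) {
    offset += lines[i].length + 1 // +1 accounts for the '\n' separator
  }
  // Columns are 1-based; never return an offset past the end of the text
  return Math.min(offset + Math.max(pos.col, 1) - 1, text.length)
}

const doc = 'name: mock-model\nbackend: [oops'
const from = linePosToOffset(doc, { line: 2, col: 10 })
console.log(from, JSON.stringify(doc[from]))
```

In the editor itself, `cmDoc.line(n).from` plays the role of the summed line lengths, so the sketch is only useful for reasoning about clamping and off-by-one behaviour.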
373
core/http/react-ui/src/components/ConfigFieldRenderer.jsx
Normal file
@@ -0,0 +1,373 @@
|
||||
import { useState } from 'react'
|
||||
import SettingRow from './SettingRow'
|
||||
import Toggle from './Toggle'
|
||||
import SearchableSelect from './SearchableSelect'
|
||||
import SearchableModelSelect from './SearchableModelSelect'
|
||||
import AutocompleteInput from './AutocompleteInput'
|
||||
import CodeEditor from './CodeEditor'
|
||||
|
||||
// Map autocomplete provider to SearchableModelSelect capability
|
||||
const PROVIDER_TO_CAPABILITY = {
|
||||
'models:chat': 'FLAG_CHAT',
|
||||
'models:tts': 'FLAG_TTS',
|
||||
'models:transcript': 'FLAG_TRANSCRIPT',
|
||||
'models:vad': 'FLAG_VAD',
|
||||
}
|
||||
|
||||
function coerceValue(raw, uiType) {
|
||||
if (raw === '' || raw === null || raw === undefined) return raw
|
||||
if (uiType === 'int') return parseInt(raw, 10) || 0
|
||||
if (uiType === 'float') return parseFloat(raw) || 0
|
||||
return raw
|
||||
}
|
||||
|
||||
function StringListEditor({ value, onChange, options }) {
|
||||
const items = Array.isArray(value) ? value : []
|
||||
|
||||
const update = (index, val) => {
|
||||
const next = [...items]
|
||||
next[index] = val
|
||||
onChange(next)
|
||||
}
|
||||
const add = () => onChange([...items, ''])
|
||||
const remove = (index) => onChange(items.filter((_, i) => i !== index))
|
||||
|
||||
// When options are available, filter out already-selected values
|
||||
const availableOptions = options
|
||||
? options.filter(o => !items.includes(o.value))
|
||||
: null
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 4, width: '100%' }}>
|
||||
{items.map((item, i) => (
|
||||
<div key={i} style={{ display: 'flex', gap: 4, alignItems: 'center' }}>
|
||||
{options ? (
|
||||
<SearchableSelect
|
||||
value={item}
|
||||
onChange={val => update(i, val)}
|
||||
options={[
|
||||
// Include the current value so it shows as selected
|
||||
...(item ? [options.find(o => o.value === item) || { value: item, label: item }] : []),
|
||||
...availableOptions,
|
||||
]}
|
||||
placeholder="Select..."
|
||||
style={{ flex: 1 }}
|
||||
/>
|
||||
) : (
|
||||
<input className="input" value={item} onChange={e => update(i, e.target.value)}
|
||||
style={{ flex: 1, fontSize: '0.8125rem' }} />
|
||||
)}
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={() => remove(i)}
|
||||
style={{ padding: '2px 6px', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
{(!options || availableOptions.length > 0) && (
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={add}
|
||||
style={{ alignSelf: 'flex-start', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-plus" /> Add
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function MapEditor({ value, onChange }) {
|
||||
const entries = value && typeof value === 'object' && !Array.isArray(value)
|
||||
? Object.entries(value) : []
|
||||
|
||||
const update = (index, key, val) => {
|
||||
const next = [...entries]
|
||||
next[index] = [key, val]
|
||||
onChange(Object.fromEntries(next))
|
||||
}
|
||||
const add = () => onChange({ ...Object.fromEntries(entries), '': '' }) // build from the validated entries so a non-object value is never spread
|
||||
const remove = (index) => {
|
||||
const next = entries.filter((_, i) => i !== index)
|
||||
onChange(Object.fromEntries(next))
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 4, width: '100%' }}>
|
||||
{entries.map(([k, v], i) => (
|
||||
<div key={i} style={{ display: 'flex', gap: 4, alignItems: 'center' }}>
|
||||
<input className="input" value={k} placeholder="key"
|
||||
onChange={e => update(i, e.target.value, v)}
|
||||
style={{ flex: 1, fontSize: '0.8125rem' }} />
|
||||
<input className="input" value={typeof v === 'string' ? v : JSON.stringify(v)} placeholder="value"
|
||||
onChange={e => update(i, k, e.target.value)}
|
||||
style={{ flex: 1, fontSize: '0.8125rem' }} />
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={() => remove(i)}
|
||||
style={{ padding: '2px 6px', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={add}
|
||||
style={{ alignSelf: 'flex-start', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-plus" /> Add
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function JsonEditor({ value, onChange }) {
|
||||
const [text, setText] = useState(() =>
|
||||
typeof value === 'string' ? value : JSON.stringify(value, null, 2) || ''
|
||||
)
|
||||
const [parseError, setParseError] = useState(null)
|
||||
|
||||
const handleChange = (val) => {
|
||||
setText(val)
|
||||
try {
|
||||
const parsed = JSON.parse(val)
|
||||
setParseError(null)
|
||||
onChange(parsed)
|
||||
} catch {
|
||||
setParseError('Invalid JSON')
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ width: '100%' }}>
|
||||
<textarea
|
||||
className="input"
|
||||
value={text}
|
||||
onChange={e => handleChange(e.target.value)}
|
||||
style={{ width: '100%', minHeight: 80, fontFamily: 'monospace', fontSize: '0.8125rem', resize: 'vertical' }}
|
||||
/>
|
||||
{parseError && <div style={{ color: 'var(--color-error)', fontSize: '0.75rem', marginTop: 2 }}>{parseError}</div>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function FieldLabel({ field }) {
|
||||
return (
|
||||
<span style={{ display: 'flex', alignItems: 'center', gap: 6 }}>
|
||||
{field.label}
|
||||
{field.vram_impact && (
|
||||
<span style={{ fontSize: '0.625rem', padding: '1px 4px', borderRadius: 'var(--radius-sm)',
|
||||
background: 'var(--color-warning-light, rgba(245,158,11,0.15))', color: 'var(--color-warning)' }}>
|
||||
VRAM
|
||||
</span>
|
||||
)}
|
||||
{field.advanced && (
|
||||
<span style={{ fontSize: '0.625rem', padding: '1px 4px', borderRadius: 'var(--radius-sm)',
|
||||
background: 'var(--color-bg-tertiary)', color: 'var(--color-text-muted)' }}>
|
||||
Advanced
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export default function ConfigFieldRenderer({ field, value, onChange, onRemove, annotation }) {
|
||||
const handleChange = (raw) => {
|
||||
onChange(coerceValue(raw, field.ui_type))
|
||||
}
|
||||
|
||||
const removeBtn = (
|
||||
<button type="button" onClick={() => onRemove(field.path)}
|
||||
title="Remove field"
|
||||
style={{
|
||||
background: 'none', border: 'none', cursor: 'pointer', padding: '2px 4px',
|
||||
color: 'var(--color-text-muted)', fontSize: '0.75rem',
|
||||
}}>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
)
|
||||
|
||||
const description = (
|
||||
<span style={{ display: 'flex', alignItems: 'center', gap: 4 }}>
|
||||
{field.description || field.path}
|
||||
{removeBtn}
|
||||
</span>
|
||||
)
|
||||
|
||||
const component = field.component
|
||||
|
||||
// Toggle
|
||||
if (component === 'toggle') {
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<Toggle checked={!!value} onChange={handleChange} />
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Model-select
|
||||
if (component === 'model-select') {
|
||||
const cap = PROVIDER_TO_CAPABILITY[field.autocomplete_provider] || undefined
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<SearchableModelSelect
|
||||
value={value || ''}
|
||||
onChange={handleChange}
|
||||
capability={cap}
|
||||
placeholder={field.placeholder || 'Select model...'}
|
||||
style={{ width: 220 }}
|
||||
/>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Select or text input with an autocomplete provider (dynamic suggestions)
|
||||
if ((component === 'select' || component === 'input') && field.autocomplete_provider) {
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<AutocompleteInput
|
||||
value={value || ''}
|
||||
onChange={handleChange}
|
||||
provider={field.autocomplete_provider}
|
||||
placeholder={field.placeholder || 'Type or select...'}
|
||||
style={{ width: 220 }}
|
||||
/>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Select with static options
|
||||
if (component === 'select' && field.options?.length > 0) {
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<SearchableSelect
|
||||
value={value || ''}
|
||||
onChange={handleChange}
|
||||
options={field.options.map(o => ({ value: o.value, label: o.label }))}
|
||||
placeholder={field.placeholder || 'Select...'}
|
||||
style={{ width: 220 }}
|
||||
/>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Slider
|
||||
if (component === 'slider') {
|
||||
const min = field.min ?? 0
|
||||
const max = field.max ?? 1
|
||||
const step = field.step ?? 0.1
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
<input type="range" min={min} max={max} step={step}
|
||||
value={value ?? min}
|
||||
onChange={e => handleChange(parseFloat(e.target.value))}
|
||||
style={{ width: 120 }}
|
||||
/>
|
||||
<span style={{ fontSize: '0.8125rem', minWidth: 40, textAlign: 'right', fontVariantNumeric: 'tabular-nums' }}>
|
||||
{value ?? min}
|
||||
</span>
|
||||
</div>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Number
|
||||
if (component === 'number') {
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<>
|
||||
<input className="input" type="number"
|
||||
value={value ?? ''}
|
||||
onChange={e => handleChange(e.target.value)}
|
||||
min={field.min} max={field.max} step={field.step}
|
||||
placeholder={field.placeholder}
|
||||
style={{ width: 120, fontSize: '0.8125rem' }}
|
||||
/>
|
||||
{annotation}
|
||||
</>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
|
||||
// Textarea
|
||||
if (component === 'textarea') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<textarea className="input" value={value || ''}
|
||||
onChange={e => handleChange(e.target.value)}
|
||||
placeholder={field.placeholder}
|
||||
style={{ width: '100%', minHeight: 80, fontSize: '0.8125rem', resize: 'vertical' }}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Code editor
|
||||
if (component === 'code-editor') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<CodeEditor value={value || ''} onChange={handleChange} minHeight="80px" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// String list
|
||||
if (component === 'string-list') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<StringListEditor value={value} onChange={handleChange} options={field.options?.length > 0 ? field.options : null} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// JSON editor
|
||||
if (component === 'json-editor') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<JsonEditor value={value} onChange={handleChange} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Map editor
|
||||
if (component === 'map-editor') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<MapEditor value={value} onChange={handleChange} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Default: text input
|
||||
return (
|
||||
<SettingRow label={<FieldLabel field={field} />} description={description}>
|
||||
<input className="input" value={value ?? ''}
|
||||
onChange={e => handleChange(e.target.value)}
|
||||
placeholder={field.placeholder}
|
||||
style={{ width: 220, fontSize: '0.8125rem' }}
|
||||
/>
|
||||
</SettingRow>
|
||||
)
|
||||
}
|
||||
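The `MapEditor` in the file above edits an object by turning it into an entry list, replacing one entry, and rebuilding the object with `Object.fromEntries`. The following is a minimal sketch of that round-trip in plain JavaScript, outside React. The helper name `updateEntry` and the sample keys are hypothetical and only illustrate the behavior. One consequence worth knowing: if an edit makes two keys identical, the rebuilt object keeps only the later value.

```javascript
// Hypothetical helper mirroring MapEditor's update(); not part of the component.
function updateEntry(obj, index, key, val) {
  const entries = obj && typeof obj === 'object' && !Array.isArray(obj)
    ? Object.entries(obj)
    : []
  const next = [...entries]
  next[index] = [key, val]
  // Object.fromEntries keeps the last value when keys repeat.
  return Object.fromEntries(next)
}

// Sample map; the keys are made up for the example.
const options = { threads: '4', mmap: 'true' }

// Renaming a key keeps its position and value.
const renamed = updateEntry(options, 0, 'n_threads', '4')

// Renaming 'mmap' to 'threads' collapses the two entries into one.
const collapsed = updateEntry(options, 1, 'threads', 'true')
```

Here `renamed` still has two keys, while `collapsed` ends up with a single `threads` key, so a row can silently disappear from the editor while the user is typing a key that already exists.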
172
core/http/react-ui/src/components/FieldBrowser.jsx
Normal file
@@ -0,0 +1,172 @@
|
||||
import { useState, useEffect, useRef, useMemo } from 'react'
|
||||
|
||||
export default function FieldBrowser({ fields, activeFieldPaths, onAddField }) {
|
||||
const [query, setQuery] = useState('')
|
||||
const [open, setOpen] = useState(false)
|
||||
const [focusIndex, setFocusIndex] = useState(-1)
|
||||
const wrapperRef = useRef(null)
|
||||
const listRef = useRef(null)
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e) => {
|
||||
if (wrapperRef.current && !wrapperRef.current.contains(e.target)) setOpen(false)
|
||||
}
|
||||
document.addEventListener('mousedown', handler)
|
||||
return () => document.removeEventListener('mousedown', handler)
|
||||
}, [])
|
||||
|
||||
const available = useMemo(() =>
|
||||
fields.filter(f => !activeFieldPaths.has(f.path)),
|
||||
[fields, activeFieldPaths]
|
||||
)
|
||||
|
||||
const filtered = useMemo(() => {
|
||||
if (!query) return available.slice(0, 30)
|
||||
const q = query.toLowerCase()
|
||||
return available.filter(f =>
|
||||
f.label.toLowerCase().includes(q) ||
|
||||
f.path.toLowerCase().includes(q) ||
|
||||
(f.description || '').toLowerCase().includes(q) ||
|
||||
f.section.toLowerCase().includes(q)
|
||||
).slice(0, 30)
|
||||
}, [available, query])
|
||||
|
||||
const enterTargetIndex = focusIndex >= 0 ? focusIndex
|
||||
: filtered.length > 0 ? 0
|
||||
: -1
|
||||
|
||||
const handleSelect = (field) => {
|
||||
onAddField(field)
|
||||
setQuery('')
|
||||
setOpen(false)
|
||||
setFocusIndex(-1)
|
||||
}
|
||||
|
||||
const handleKeyDown = (e) => {
|
||||
if (!open && (e.key === 'ArrowDown' || e.key === 'ArrowUp')) {
|
||||
setOpen(true)
|
||||
return
|
||||
}
|
||||
if (!open) return
|
||||
|
||||
if (e.key === 'ArrowDown') {
|
||||
e.preventDefault()
|
||||
setFocusIndex(i => Math.min(i + 1, filtered.length - 1))
|
||||
} else if (e.key === 'ArrowUp') {
|
||||
e.preventDefault()
|
||||
setFocusIndex(i => Math.max(i - 1, 0))
|
||||
} else if (e.key === 'Enter') {
|
||||
e.preventDefault()
|
||||
if (enterTargetIndex >= 0) {
|
||||
handleSelect(filtered[enterTargetIndex])
|
||||
}
|
||||
} else if (e.key === 'Escape') {
|
||||
setOpen(false)
|
||||
setFocusIndex(-1)
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (focusIndex >= 0 && listRef.current) {
|
||||
const item = listRef.current.children[focusIndex]
|
||||
if (item) item.scrollIntoView({ block: 'nearest' })
|
||||
}
|
||||
}, [focusIndex])
|
||||
|
||||
const sectionColors = {
|
||||
general: 'var(--color-primary)',
|
||||
llm: 'var(--color-accent)',
|
||||
parameters: 'var(--color-success)',
|
||||
templates: 'var(--color-warning)',
|
||||
functions: 'var(--color-info, var(--color-primary))',
|
||||
reasoning: 'var(--color-accent)',
|
||||
diffusers: 'var(--color-warning)',
|
||||
tts: 'var(--color-success)',
|
||||
pipeline: 'var(--color-accent)',
|
||||
grpc: 'var(--color-text-muted)',
|
||||
agent: 'var(--color-primary)',
|
||||
mcp: 'var(--color-accent)',
|
||||
other: 'var(--color-text-muted)',
|
||||
}
|
||||
|
||||
return (
|
||||
<div ref={wrapperRef} style={{ position: 'relative', marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ position: 'relative' }}>
|
||||
<i className="fas fa-search" style={{
|
||||
position: 'absolute', left: 10, top: '50%', transform: 'translateY(-50%)',
|
||||
color: 'var(--color-text-muted)', fontSize: '0.75rem', pointerEvents: 'none',
|
||||
}} />
|
||||
<input
|
||||
className="input"
|
||||
value={query}
|
||||
onChange={e => { setQuery(e.target.value); setOpen(true); setFocusIndex(-1) }}
|
||||
onFocus={() => setOpen(true)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Search fields to add..."
|
||||
style={{ width: '100%', paddingLeft: 32, fontSize: '0.8125rem' }}
|
||||
/>
|
||||
</div>
|
||||
{open && (
|
||||
<div
|
||||
ref={listRef}
|
||||
style={{
|
||||
position: 'absolute', top: '100%', left: 0, right: 0, zIndex: 100, marginTop: 4,
|
||||
maxHeight: 320, overflowY: 'auto',
|
||||
background: 'var(--color-bg-secondary)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-md)', boxShadow: 'var(--shadow-md)',
|
||||
animation: 'dropdownIn 120ms ease-out',
|
||||
}}
|
||||
>
|
||||
{filtered.length === 0 ? (
|
||||
<div style={{ padding: '12px 16px', fontSize: '0.8125rem', color: 'var(--color-text-muted)', fontStyle: 'italic' }}>
|
||||
{query ? 'No matching fields' : 'All fields are already configured'}
|
||||
</div>
|
||||
) : (
|
||||
filtered.map((field, i) => {
|
||||
const isEnterTarget = i === enterTargetIndex
|
||||
const isFocused = i === focusIndex || isEnterTarget
|
||||
return (
|
||||
<div
|
||||
key={field.path}
|
||||
style={{
|
||||
padding: '8px 12px', cursor: 'pointer',
|
||||
background: isFocused ? 'var(--color-bg-tertiary)' : 'transparent',
|
||||
borderBottom: '1px solid var(--color-border-subtle)',
|
||||
}}
|
||||
onMouseEnter={() => setFocusIndex(i)}
|
||||
onMouseDown={(e) => {
|
||||
e.preventDefault()
|
||||
handleSelect(field)
|
||||
}}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
<span style={{
|
||||
fontSize: '0.625rem', padding: '1px 6px', borderRadius: 'var(--radius-sm)',
|
||||
background: `color-mix(in srgb, ${sectionColors[field.section] || 'var(--color-text-muted)'} 15%, transparent)`,
|
||||
color: sectionColors[field.section] || 'var(--color-text-muted)',
|
||||
fontWeight: 600, whiteSpace: 'nowrap',
|
||||
}}>
|
||||
{field.section}
|
||||
</span>
|
||||
<span style={{ fontSize: '0.8125rem', fontWeight: 500 }}>{field.label}</span>
|
||||
{isEnterTarget && (
|
||||
<span style={{ marginLeft: 'auto', color: 'var(--color-text-muted)', fontSize: '0.75rem' }}>↵</span>
|
||||
)}
|
||||
</div>
|
||||
{field.description && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2, marginLeft: 0 }}>
|
||||
{field.description}
|
||||
</div>
|
||||
)}
|
||||
<div style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)', marginTop: 1, fontFamily: 'monospace' }}>
|
||||
{field.path}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
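The search behavior of `FieldBrowser` can be read independently of the JSX: already-active fields are hidden, the query is matched case-insensitively against label, path, description, and section, and the list is capped at 30 rows. A minimal sketch of that logic as plain functions follows. The function names, the `MAX_RESULTS` constant, and the sample field paths are assumptions made for the example.

```javascript
// Plain-function version of FieldBrowser's filtering; names are illustrative.
const MAX_RESULTS = 30

function filterFields(fields, activeFieldPaths, query) {
  // Hide fields that are already part of the config.
  const available = fields.filter(f => !activeFieldPaths.has(f.path))
  if (!query) return available.slice(0, MAX_RESULTS)
  const q = query.toLowerCase()
  return available
    .filter(f =>
      f.label.toLowerCase().includes(q) ||
      f.path.toLowerCase().includes(q) ||
      (f.description || '').toLowerCase().includes(q) ||
      f.section.toLowerCase().includes(q))
    .slice(0, MAX_RESULTS)
}

// Enter selects the focused row, or the first row when nothing is focused.
function enterTarget(focusIndex, resultCount) {
  if (focusIndex >= 0) return focusIndex
  return resultCount > 0 ? 0 : -1
}

// Sample metadata; the field paths here are made up for the example.
const fields = [
  { path: 'context_size', label: 'Context Size', section: 'llm', description: 'Prompt window' },
  { path: 'gpu_layers', label: 'GPU Layers', section: 'llm' },
  { path: 'tts.voice', label: 'Voice', section: 'tts' },
]
const active = new Set(['context_size'])

const byQuery = filterFields(fields, active, 'GPU').map(f => f.path)
const bySection = filterFields(fields, active, 'tts').map(f => f.path)
```

Because the section name is part of the match, typing a section such as `tts` lists every unconfigured field in that section, which is a cheap substitute for a separate section filter.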
@@ -27,6 +27,11 @@ export default function OperationsBar() {
|
||||
({op.error})
|
||||
</span>
|
||||
</>
|
||||
) : op.taskType === 'staging' ? (
|
||||
<>
|
||||
<i className="fas fa-cloud-arrow-up" style={{ marginRight: 'var(--spacing-xs)' }} />
|
||||
Staging model: {op.name}{op.nodeName ? ` → ${op.nodeName}` : ''}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{op.isDeletion ? 'Removing' : 'Installing'}{' '}
|
||||
|
||||
61
core/http/react-ui/src/components/TemplateSelector.jsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import MODEL_TEMPLATES from '../utils/modelTemplates'
|
||||
|
||||
export default function TemplateSelector({ onSelect }) {
|
||||
return (
|
||||
<div style={{ padding: '0 var(--spacing-lg) var(--spacing-lg)' }}>
|
||||
<p style={{ fontSize: '0.875rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-lg)' }}>
|
||||
Choose a template to get started. You can add or remove fields in the next step.
|
||||
</p>
|
||||
<div style={{
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(auto-fill, minmax(260px, 1fr))',
|
||||
gap: 'var(--spacing-md)',
|
||||
}}>
|
||||
{MODEL_TEMPLATES.map(t => (
|
||||
<button
|
||||
key={t.id}
|
||||
className="template-card"
|
||||
onClick={() => onSelect(t)}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)', width: '100%' }}>
|
||||
<i className={`fas ${t.icon}`} style={{ fontSize: '1.25rem', color: 'var(--color-primary)', width: 28, textAlign: 'center' }} />
|
||||
<span style={{ fontSize: '1rem', fontWeight: 600, color: 'var(--color-text-primary)' }}>{t.label}</span>
|
||||
</div>
|
||||
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', lineHeight: 1.5, margin: 0 }}>
|
||||
{t.description}
|
||||
</p>
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: '4px', marginTop: 'var(--spacing-xs)' }}>
|
||||
{Object.keys(t.fields).filter(k => k !== 'name').map(k => (
|
||||
<span key={k} className="badge" style={{
|
||||
fontSize: '0.6875rem', background: 'var(--color-bg-tertiary)',
|
||||
color: 'var(--color-text-muted)', padding: '2px 6px',
|
||||
}}>
|
||||
{k}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
<style>{`
|
||||
.template-card {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: var(--spacing-sm);
|
||||
padding: var(--spacing-lg);
|
||||
background: var(--color-bg-secondary);
|
||||
border: 1px solid var(--color-border-default);
|
||||
border-radius: var(--radius-lg);
|
||||
cursor: pointer;
|
||||
text-align: left;
|
||||
transition: all 150ms;
|
||||
}
|
||||
.template-card:hover {
|
||||
border-color: var(--color-primary);
|
||||
background: var(--color-primary-light);
|
||||
}
|
||||
`}</style>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
21
core/http/react-ui/src/hooks/useAgentChat.js
vendored
@@ -1,8 +1,8 @@
|
||||
import { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import { useState, useCallback, useEffect } from 'react'
|
||||
import { generateId } from '../utils/format'
|
||||
import { useDebouncedEffect } from './useDebounce'
|
||||
|
||||
const STORAGE_KEY_PREFIX = 'localai_agent_chats_'
|
||||
const SAVE_DEBOUNCE_MS = 500
|
||||
|
||||
function storageKey(agentName) {
|
||||
return STORAGE_KEY_PREFIX + agentName
|
||||
@@ -67,24 +67,9 @@ export function useAgentChat(agentName) {
|
||||
return conversations[0]?.id
|
||||
})
|
||||
|
||||
const saveTimerRef = useRef(null)
|
||||
|
||||
const activeConversation = conversations.find(c => c.id === activeId) || conversations[0]
|
||||
|
||||
// Debounced save
|
||||
const debouncedSave = useCallback(() => {
|
||||
if (saveTimerRef.current) clearTimeout(saveTimerRef.current)
|
||||
saveTimerRef.current = setTimeout(() => {
|
||||
saveConversations(agentName, conversations, activeId)
|
||||
}, SAVE_DEBOUNCE_MS)
|
||||
}, [agentName, conversations, activeId])
|
||||
|
||||
useEffect(() => {
|
||||
debouncedSave()
|
||||
return () => {
|
||||
if (saveTimerRef.current) clearTimeout(saveTimerRef.current)
|
||||
}
|
||||
}, [conversations, activeId, debouncedSave])
|
||||
useDebouncedEffect(() => saveConversations(agentName, conversations, activeId), [agentName, conversations, activeId])
|
||||
|
||||
// Save immediately on unmount
|
||||
useEffect(() => {
|
||||
|
||||
47
core/http/react-ui/src/hooks/useAutocomplete.js
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
import { useState, useEffect } from 'react'
import { modelsApi } from '../utils/api'

// Module-level cache so each provider is fetched once per page load
const cache = {}

// Shared fetch-with-cache for use outside React hooks (e.g. CodeMirror completions)
export async function fetchCachedAutocomplete(provider) {
  if (cache[provider]) return cache[provider].values
  try {
    const data = await modelsApi.getAutocomplete(provider)
    const vals = data?.values || []
    cache[provider] = { values: vals }
    return vals
  } catch {
    return []
  }
}

export function useAutocomplete(provider) {
  const [values, setValues] = useState(cache[provider]?.values || [])
  const [loading, setLoading] = useState(!cache[provider])

  useEffect(() => {
    if (!provider) {
      setValues([])
      setLoading(false)
      return
    }
    if (cache[provider]) {
      setValues(cache[provider].values)
      setLoading(false)
      return
    }
    // Ignore responses that arrive after the provider changed or the component unmounted
    let cancelled = false
    setLoading(true)
    modelsApi.getAutocomplete(provider)
      .then(data => {
        const vals = data?.values || []
        cache[provider] = { values: vals }
        if (!cancelled) setValues(vals)
      })
      .catch(() => { if (!cancelled) setValues([]) })
      .finally(() => { if (!cancelled) setLoading(false) })
    return () => { cancelled = true }
  }, [provider])

  return { values, loading }
}
|
||||
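The module-level cache above stores values only once a response has arrived, so two callers that ask for the same provider before the first response returns would each issue a request. One possible variation, which is not what this diff does, is to cache the promise itself. The sketch below assumes a generic `fetchValues(provider)` function that resolves to `{ values: [...] }`; `createCachedFetcher` and `fetchValues` are hypothetical names, not part of `modelsApi`.

```javascript
// Sketch of a promise cache; createCachedFetcher and fetchValues are
// hypothetical names, not part of LocalAI's modelsApi.
function createCachedFetcher(fetchValues) {
  const inFlight = new Map()
  return function fetchCached(provider) {
    if (!inFlight.has(provider)) {
      // The executor runs synchronously, so the request starts immediately
      // and a thrown error becomes a rejection.
      const promise = new Promise(resolve => resolve(fetchValues(provider)))
        .then(data => (data && data.values) || [])
        .catch(() => {
          // Drop failed lookups so a later call can retry.
          inFlight.delete(provider)
          return []
        })
      inFlight.set(provider, promise)
    }
    return inFlight.get(provider)
  }
}

// Stub fetcher that counts how often it is invoked.
let calls = 0
const fetchCached = createCachedFetcher(provider => {
  calls += 1
  return Promise.resolve({ values: [provider + '-a', provider + '-b'] })
})

// Two back-to-back lookups share one request and one promise.
const first = fetchCached('backends')
const second = fetchCached('backends')
```

The trade-off is that failures have to be evicted explicitly, as the `catch` branch does here, otherwise an empty result from a transient error would be served for the rest of the page load.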
17
core/http/react-ui/src/hooks/useChat.js
vendored
@@ -1,6 +1,7 @@
|
||||
import { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import { useState, useCallback, useRef } from 'react'
|
||||
import { API_CONFIG } from '../utils/config'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import { useDebouncedEffect } from './useDebounce'
|
||||
|
||||
const thinkingTagRegex = /<thinking>([\s\S]*?)<\/thinking>|<think>([\s\S]*?)<\/think>|<\|channel>thought([\s\S]*?)<channel\|>/g
|
||||
const openThinkTagRegex = /<thinking>|<think>|<\|channel>thought/
|
||||
@@ -33,7 +34,6 @@ function extractThinking(text) {
|
||||
import { generateId } from '../utils/format'
|
||||
|
||||
const CHATS_STORAGE_KEY = 'localai_chats_data'
|
||||
const SAVE_DEBOUNCE_MS = 500
|
||||
|
||||
function loadChats() {
|
||||
try {
|
||||
@@ -123,24 +123,13 @@ export function useChat(initialModel = '') {
|
||||
const [tokensPerSecond, setTokensPerSecond] = useState(null)
|
||||
const [maxTokensPerSecond, setMaxTokensPerSecond] = useState(null)
|
||||
const abortControllerRef = useRef(null)
|
||||
const saveTimerRef = useRef(null)
|
||||
const startTimeRef = useRef(null)
|
||||
const tokenCountRef = useRef(0)
|
||||
const maxTpsRef = useRef(0)
|
||||
|
||||
const activeChat = chats.find(c => c.id === activeChatId) || chats[0]
|
||||
|
||||
// Debounced save
|
||||
const debouncedSave = useCallback(() => {
|
||||
if (saveTimerRef.current) clearTimeout(saveTimerRef.current)
|
||||
saveTimerRef.current = setTimeout(() => {
|
||||
saveChats(chats, activeChatId)
|
||||
}, SAVE_DEBOUNCE_MS)
|
||||
}, [chats, activeChatId])
|
||||
|
||||
useEffect(() => {
|
||||
debouncedSave()
|
||||
}, [chats, activeChatId, debouncedSave])
|
||||
useDebouncedEffect(() => saveChats(chats, activeChatId), [chats, activeChatId])
|
||||
|
||||
const addChat = useCallback((model = '', systemPrompt = '', mcpMode = false) => {
|
||||
const chat = createNewChat(model, systemPrompt, mcpMode)
|
||||
|
||||
79
core/http/react-ui/src/hooks/useCodeMirror.js
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
import { useRef, useEffect } from 'react'
import { EditorView } from '@codemirror/view'
import { EditorState, Compartment } from '@codemirror/state'

export function useCodeMirror({ containerRef, value, onChange, extensions = [], dynamicExtensions = {} }) {
  const viewRef = useRef(null)
  const onChangeRef = useRef(onChange)
  const isExternalUpdate = useRef(false)
  const compartmentsRef = useRef({})

  onChangeRef.current = onChange

  // Create editor on mount (only depends on container and static extensions)
  useEffect(() => {
    if (!containerRef.current) return

    const listener = EditorView.updateListener.of(update => {
      if (update.docChanged && !isExternalUpdate.current) {
        onChangeRef.current(update.state.doc.toString())
      }
    })

    // Create compartments for each dynamic extension key
    const compartments = {}
    const compartmentExts = []
    for (const [key, ext] of Object.entries(dynamicExtensions)) {
      compartments[key] = new Compartment()
      compartmentExts.push(compartments[key].of(ext))
    }
    compartmentsRef.current = compartments

    const state = EditorState.create({
      doc: value,
      extensions: [...extensions, ...compartmentExts, listener],
    })

    const view = new EditorView({ state, parent: containerRef.current })
    viewRef.current = view

    return () => {
      view.destroy()
      viewRef.current = null
    }
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [containerRef, extensions])

  // Reconfigure dynamic extensions without recreating the editor
  useEffect(() => {
    const view = viewRef.current
    if (!view) return
    const effects = []
    for (const [key, ext] of Object.entries(dynamicExtensions)) {
      const compartment = compartmentsRef.current[key]
      if (compartment) {
        effects.push(compartment.reconfigure(ext))
      }
    }
    if (effects.length > 0) {
      view.dispatch({ effects })
    }
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [dynamicExtensions])

  // Sync external value changes into CM6
  useEffect(() => {
    const view = viewRef.current
    if (!view) return
    const current = view.state.doc.toString()
    if (value !== current) {
      isExternalUpdate.current = true
      view.dispatch({
        changes: { from: 0, to: current.length, insert: value },
      })
      isExternalUpdate.current = false
    }
  }, [value])

  return { view: viewRef }
}
|
||||
22
core/http/react-ui/src/hooks/useConfigMetadata.js
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
import { useState, useEffect } from 'react'
import { modelsApi } from '../utils/api'

export function useConfigMetadata() {
  const [metadata, setMetadata] = useState(null)
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState(null)

  useEffect(() => {
    modelsApi.getConfigMetadata('all')
      .then(data => setMetadata(data))
      .catch(err => setError(err.message))
      .finally(() => setLoading(false))
  }, [])

  return {
    sections: metadata?.sections || [],
    fields: metadata?.fields || [],
    loading,
    error,
  }
}
|
||||
40
core/http/react-ui/src/hooks/useDebounce.js
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
import { useRef, useEffect, useCallback } from 'react'

/**
 * Returns a debounced version of the callback. Always calls the latest
 * version of fn (via ref), so callers don't need to memoize it.
 * Timer is cleaned up on unmount.
 */
export function useDebouncedCallback(fn, delay = 500) {
  const timerRef = useRef(null)
  const fnRef = useRef(fn)
  fnRef.current = fn

  useEffect(() => () => {
    if (timerRef.current) clearTimeout(timerRef.current)
  }, [])

  return useCallback((...args) => {
    if (timerRef.current) clearTimeout(timerRef.current)
    timerRef.current = setTimeout(() => fnRef.current(...args), delay)
  }, [delay])
}

/**
 * Runs a debounced effect: when deps change, waits `delay` ms before
 * calling fn. Resets the timer on each deps change. Cleans up on unmount.
 */
export function useDebouncedEffect(fn, deps, delay = 500) {
  const timerRef = useRef(null)
  const fnRef = useRef(fn)
  fnRef.current = fn

  useEffect(() => {
    if (timerRef.current) clearTimeout(timerRef.current)
    timerRef.current = setTimeout(() => fnRef.current(), delay)
    return () => {
      if (timerRef.current) clearTimeout(timerRef.current)
    }
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, deps)
}
|
||||
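Both hooks above wrap the same idea: every call cancels the pending timer and schedules a new one, so only the last call within the delay window runs. A framework-free sketch of that idea follows. The `timers` parameter and the `createFakeTimers` helper are assumptions added here so the behavior can be checked without waiting for real delays; they are not part of `useDebounce.js`.

```javascript
// Real timers, wrapped so they are never called with a foreign `this`.
const realTimers = {
  setTimeout: (cb, ms) => setTimeout(cb, ms),
  clearTimeout: id => clearTimeout(id),
}

// Plain debounce mirroring useDebouncedCallback; `timers` is injectable
// (an assumption of this sketch) to make the behavior testable.
function debounce(fn, delay = 500, timers = realTimers) {
  let timer = null
  const debounced = (...args) => {
    if (timer !== null) timers.clearTimeout(timer)
    timer = timers.setTimeout(() => {
      timer = null
      fn(...args)
    }, delay)
  }
  // Equivalent of the unmount cleanup in the hooks.
  debounced.cancel = () => {
    if (timer !== null) timers.clearTimeout(timer)
    timer = null
  }
  return debounced
}

// Minimal fake timer queue standing in for the browser's timers.
function createFakeTimers() {
  let nextId = 1
  const pending = new Map()
  return {
    setTimeout: cb => { const id = nextId++; pending.set(id, cb); return id },
    clearTimeout: id => { pending.delete(id) },
    // Run everything still scheduled, as if the delay had elapsed.
    flush: () => {
      const callbacks = [...pending.values()]
      pending.clear()
      callbacks.forEach(cb => cb())
    },
  }
}

const fake = createFakeTimers()
const saved = []
const save = debounce(value => saved.push(value), 500, fake)

// Three rapid calls result in a single save with the latest value.
save('draft 1')
save('draft 2')
save('draft 3')
fake.flush()
```

This is the property the chat hooks rely on when they persist conversations: a burst of state changes produces one write rather than one per change.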
53
core/http/react-ui/src/hooks/useVramEstimate.js
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
import { useState, useEffect, useRef, useMemo } from 'react'
import { modelsApi } from '../utils/api'

const DEBOUNCE_MS = 500

export function useVramEstimate({ model, contextSize, gpuLayers }) {
  const [vramDisplay, setVramDisplay] = useState(null)
  const [loading, setLoading] = useState(false)
  const debounceRef = useRef(null)
  const abortRef = useRef(null)

  useEffect(() => {
    if (!model || contextSize === undefined) {
      setVramDisplay(null)
      setLoading(false)
      return
    }

    if (debounceRef.current) clearTimeout(debounceRef.current)
    if (abortRef.current) abortRef.current.abort()

    debounceRef.current = setTimeout(async () => {
      const controller = new AbortController()
      abortRef.current = controller
      setLoading(true)

      try {
        const body = { model }
        if (contextSize != null && contextSize !== '') body.context_size = Number(contextSize)
        if (gpuLayers != null && gpuLayers !== '') body.gpu_layers = Number(gpuLayers)

        const data = await modelsApi.estimateVram(body, { signal: controller.signal })

        if (!controller.signal.aborted) {
          setVramDisplay(data?.vramDisplay || null)
          setLoading(false)
        }
      } catch {
        if (!controller.signal.aborted) {
          setVramDisplay(null)
          setLoading(false)
        }
      }
    }, DEBOUNCE_MS)

    return () => {
      if (debounceRef.current) clearTimeout(debounceRef.current)
      if (abortRef.current) abortRef.current.abort()
    }
  }, [model, contextSize, gpuLayers])

  return useMemo(() => ({ vramDisplay, loading }), [vramDisplay, loading])
}
|
||||
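The request body built inside `useVramEstimate` treats `null`, `undefined`, and the empty string as "not set", while a numeric `0` is still sent. A small sketch that isolates this rule is shown below. The function name `buildEstimateBody` and the sample model name are assumptions made for illustration; the diff builds the body inline.

```javascript
// Hypothetical pure helper mirroring the body construction in useVramEstimate.
function buildEstimateBody({ model, contextSize, gpuLayers }) {
  const body = { model }
  // null, undefined and '' (an empty form field) mean "not set" and are skipped.
  // A value of 0 passes both checks and is sent as a number.
  if (contextSize != null && contextSize !== '') body.context_size = Number(contextSize)
  if (gpuLayers != null && gpuLayers !== '') body.gpu_layers = Number(gpuLayers)
  return body
}

// 'example-model' is a made-up name used only for this example.
console.log(buildEstimateBody({ model: 'example-model', contextSize: '4096', gpuLayers: '' }))
// prints { model: 'example-model', context_size: 4096 }
```

Keeping the body construction pure makes the string-to-number conversion easy to check separately from the debounce and abort handling.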
@@ -1,11 +1,13 @@
import { useState, useEffect, useCallback, useRef } from 'react'
import { useState, useEffect, useCallback } from 'react'
import { useNavigate, useOutletContext } from 'react-router-dom'
import { backendsApi } from '../utils/api'
import { useDebouncedCallback } from '../hooks/useDebounce'
import React from 'react'
import { useOperations } from '../hooks/useOperations'
import LoadingSpinner from '../components/LoadingSpinner'
import { renderMarkdown } from '../utils/markdown'
import ConfirmDialog from '../components/ConfirmDialog'
import Toggle from '../components/Toggle'

export default function Backends() {
  const { addToast } = useOutletContext()
@@ -24,9 +26,12 @@ export default function Backends() {
  const [manualAlias, setManualAlias] = useState('')
  const [expandedRow, setExpandedRow] = useState(null)
  const [confirmDialog, setConfirmDialog] = useState(null)
  const debounceRef = useRef(null)

  const [allBackends, setAllBackends] = useState([])
  const [upgrades, setUpgrades] = useState({})
  const [upgradingAll, setUpgradingAll] = useState(false)
  const [showAllBackends, setShowAllBackends] = useState(false)
  const [showDevelopment, setShowDevelopment] = useState(false)
  const [preferDevLoaded, setPreferDevLoaded] = useState(false)

  const fetchBackends = useCallback(async () => {
    try {
@@ -37,6 +42,11 @@ export default function Backends() {
      const list = Array.isArray(data?.backends) ? data.backends : Array.isArray(data) ? data : []
      setAllBackends(list)
      setInstalledCount(list.filter(b => b.installed).length)
      // On first load, use server preference for development toggle
      if (!preferDevLoaded && data?.preferDevelopmentBackends) {
        setShowDevelopment(true)
        setPreferDevLoaded(true)
      }
    } catch (err) {
      addToast(`Failed to load backends: ${err.message}`, 'error')
    } finally {
@@ -53,28 +63,52 @@ export default function Backends() {
    if (!loading) fetchBackends()
  }, [operations.length])

  // Client-side filtering by tag
  const filteredBackends = filter
    ? allBackends.filter(b => {
  // Fetch available upgrades
  useEffect(() => {
    backendsApi.checkUpgrades()
      .then(data => setUpgrades(data || {}))
      .catch(() => {})
  }, [operations.length])

  // Client-side filtering by meta/development toggles and tag
  const filteredBackends = (() => {
    let result = allBackends

    // Show only meta backends unless "Show all" is toggled
    if (!showAllBackends) {
      result = result.filter(b => b.isMeta)
    }

    // Hide development backends unless toggled on
    if (!showDevelopment) {
      result = result.filter(b => !b.isDevelopment)
    }

    // Apply tag filter
    if (filter) {
      result = result.filter(b => {
        const tags = (b.tags || []).map(t => t.toLowerCase())
        const name = (b.name || '').toLowerCase()
        const desc = (b.description || '').toLowerCase()
        const f = filter.toLowerCase()
        // Match against tags, or name/description containing the filter keyword
        return tags.some(t => t.includes(f)) || name.includes(f) || desc.includes(f)
      })
    : allBackends
    }

    return result
  })()

  // Client-side pagination
  const ITEMS_PER_PAGE = 21
  const totalPages = Math.max(1, Math.ceil(filteredBackends.length / ITEMS_PER_PAGE))
  const backends = filteredBackends.slice((page - 1) * ITEMS_PER_PAGE, page * ITEMS_PER_PAGE)

  const debouncedFetch = useDebouncedCallback(() => fetchBackends())

  const handleSearch = (value) => {
    setSearch(value)
    setPage(1)
    if (debounceRef.current) clearTimeout(debounceRef.current)
    debounceRef.current = setTimeout(() => fetchBackends(), 500)
    debouncedFetch()
  }

  const handleSort = (col) => {
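The new filtering in this hunk runs three steps in order (meta-only unless "Show all", hide development unless toggled, then the keyword match) before the list is paginated. A rough standalone sketch of that pipeline follows. The function names `filterBackends` and `paginate`, the options object, and the sample data are assumptions for illustration; the diff computes these inline in the component.

```javascript
// Hypothetical pure functions (not in the diff) mirroring the filter and pagination steps.
function filterBackends(allBackends, { showAll = false, showDevelopment = false, filter = '' } = {}) {
  let result = allBackends
  // Step 1: only meta backends unless "Show all" is on.
  if (!showAll) result = result.filter(b => b.isMeta)
  // Step 2: hide development backends unless toggled on.
  if (!showDevelopment) result = result.filter(b => !b.isDevelopment)
  // Step 3: case-insensitive keyword match on tags, name, or description.
  if (filter) {
    const f = filter.toLowerCase()
    result = result.filter(b => {
      const tags = (b.tags || []).map(t => t.toLowerCase())
      const name = (b.name || '').toLowerCase()
      const desc = (b.description || '').toLowerCase()
      return tags.some(t => t.includes(f)) || name.includes(f) || desc.includes(f)
    })
  }
  return result
}

function paginate(items, page, perPage = 21) {
  // At least one page is reported even for an empty list.
  const totalPages = Math.max(1, Math.ceil(items.length / perPage))
  return { totalPages, items: items.slice((page - 1) * perPage, page * perPage) }
}

// Made-up sample data for illustration only.
const sample = [
  { name: 'alpha', isMeta: true, tags: ['LLM'] },
  { name: 'alpha-dev', isMeta: true, isDevelopment: true, tags: ['llm'] },
  { name: 'cuda-alpha', isMeta: false, tags: ['llm'] },
]
console.log(filterBackends(sample, { filter: 'llm' }).map(b => b.name))
// prints [ 'alpha' ]
```

With both toggles off, only `alpha` survives: `alpha-dev` is dropped by the development filter and `cuda-alpha` by the meta filter, which is why both toggle handlers in the diff also reset the page to 1.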
@@ -114,6 +148,31 @@ export default function Backends() {
    })
  }

  const handleUpgrade = async (id) => {
    try {
      await backendsApi.upgrade(id)
      addToast(`Upgrading ${id}...`, 'info')
    } catch (err) {
      addToast(`Upgrade failed: ${err.message}`, 'error')
    }
  }

  const handleUpgradeAll = async () => {
    const names = Object.keys(upgrades)
    if (names.length === 0) return
    setUpgradingAll(true)
    try {
      for (const name of names) {
        await backendsApi.upgrade(name)
      }
      addToast(`Upgrading ${names.length} backend${names.length > 1 ? 's' : ''}...`, 'info')
    } catch (err) {
      addToast(`Upgrade failed: ${err.message}`, 'error')
    } finally {
      setUpgradingAll(false)
    }
  }

  const handleManualInstall = async (e) => {
    e.preventDefault()
    if (!manualUri.trim()) { addToast('Please enter a URI', 'warning'); return }
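`handleUpgradeAll` awaits each upgrade request in turn, so the first rejected request ends the loop and the remaining backends are not requested. The sketch below isolates that control flow under stated assumptions: `upgradeSequentially` is a hypothetical name, and `upgradeOne` stands in for `backendsApi.upgrade`, whose real behaviour is not shown in this diff.

```javascript
// Hypothetical sketch (not in the diff): sequential requests that stop at the first failure.
// `upgradeOne` is a stand-in for backendsApi.upgrade.
async function upgradeSequentially(names, upgradeOne) {
  const started = []
  try {
    for (const name of names) {
      await upgradeOne(name) // a rejection here skips every later name
      started.push(name)
    }
    return { started, error: null }
  } catch (error) {
    return { started, error }
  }
}
```

Returning the list of names that were actually started is one way a caller could report partial progress; the handler in the diff shows a single failure toast instead.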
@@ -137,6 +196,9 @@ export default function Backends() {
    return operations.find(op => op.name === backend.name || op.name === backend.id) || null
  }

  const handleToggleAllBackends = () => { setShowAllBackends(v => !v); setPage(1) }
  const handleToggleDev = () => { setShowDevelopment(v => !v); setPage(1) }

  const FILTERS = [
    { key: '', label: 'All', icon: 'fa-layer-group' },
    { key: 'llm', label: 'LLM', icon: 'fa-brain' },
@@ -179,6 +241,14 @@ export default function Backends() {
                <div style={{ color: 'var(--color-text-muted)' }}>Installed</div>
              </a>
            </div>
            {Object.keys(upgrades).length > 0 && (
              <div style={{ textAlign: 'center' }}>
                <div style={{ fontSize: '1.25rem', fontWeight: 700, color: 'var(--color-warning)' }}>
                  {Object.keys(upgrades).length}
                </div>
                <div style={{ color: 'var(--color-text-muted)' }}>Updates</div>
              </div>
            )}
          </div>
          <a className="btn btn-secondary btn-sm" href="https://localai.io/docs/getting-started/manual/" target="_blank" rel="noopener noreferrer">
            <i className="fas fa-book" /> Docs
@@ -186,6 +256,33 @@ export default function Backends() {
        </div>
      </div>

      {/* Upgrade Banner */}
      {Object.keys(upgrades).length > 0 && (
        <div className="card" style={{
          marginBottom: 'var(--spacing-md)',
          display: 'flex', alignItems: 'center', justifyContent: 'space-between',
          padding: 'var(--spacing-sm) var(--spacing-md)',
          background: 'var(--color-warning-bg, #fef3cd)',
          border: '1px solid var(--color-warning, #ffc107)',
          borderRadius: 'var(--radius-md)',
        }}>
          <div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
            <i className="fas fa-arrow-up" style={{ color: 'var(--color-warning, #856404)' }} />
            <span style={{ color: 'var(--color-warning, #856404)', fontWeight: 500, fontSize: '0.875rem' }}>
              {Object.keys(upgrades).length} backend{Object.keys(upgrades).length > 1 ? 's have' : ' has'} updates available
            </span>
          </div>
          <button
            className="btn btn-primary btn-sm"
            onClick={handleUpgradeAll}
            disabled={upgradingAll}
          >
            <i className={`fas ${upgradingAll ? 'fa-spinner fa-spin' : 'fa-arrow-up'}`} style={{ marginRight: 4 }} />
            Upgrade All
          </button>
        </div>
      )}

      {/* Manual Install */}
      <div style={{ marginBottom: 'var(--spacing-md)' }}>
        <button className="btn btn-secondary btn-sm" onClick={() => setShowManualInstall(!showManualInstall)}>
@@ -227,17 +324,30 @@ export default function Backends() {
        </div>
      </div>

      <div className="filter-bar" style={{ marginBottom: 'var(--spacing-md)' }}>
        {FILTERS.map(f => (
          <button
            key={f.key}
            className={`filter-btn ${filter === f.key ? 'active' : ''}`}
            onClick={() => { setFilter(f.key); setPage(1) }}
          >
      <div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)', flexWrap: 'wrap' }}>
        <div className="filter-bar" style={{ margin: 0, flex: 1 }}>
          {FILTERS.map(f => (
            <button
              key={f.key}
              className={`filter-btn ${filter === f.key ? 'active' : ''}`}
              onClick={() => { setFilter(f.key); setPage(1) }}
            >
              <i className={`fas ${f.icon}`} style={{ marginRight: 4 }} />
              {f.label}
            </button>
          ))}
        </div>

        <div style={{ display: 'flex', gap: 'var(--spacing-md)', alignItems: 'center', borderLeft: '1px solid var(--color-border-subtle)', paddingLeft: 'var(--spacing-md)' }}>
          <label style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)', fontSize: '0.75rem', color: 'var(--color-text-secondary)', cursor: 'pointer', userSelect: 'none', whiteSpace: 'nowrap' }}>
            <Toggle checked={showAllBackends} onChange={handleToggleAllBackends} />
            Show all
          </label>
          <label style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)', fontSize: '0.75rem', color: 'var(--color-text-secondary)', cursor: 'pointer', userSelect: 'none', whiteSpace: 'nowrap' }}>
            <Toggle checked={showDevelopment} onChange={handleToggleDev} />
            Development
          </label>
        </div>
      </div>

      {/* Table */}
@@ -300,6 +410,11 @@ export default function Backends() {
                  {/* Name */}
                  <td>
                    <span style={{ fontWeight: 500 }}>{b.name || b.id}</span>
                    {b.version && (
                      <span className="badge" style={{ fontSize: '0.625rem', marginLeft: 4, background: 'var(--color-bg-tertiary)', color: 'var(--color-text-secondary)' }}>
                        v{b.version}
                      </span>
                    )}
                  </td>

                  {/* Description */}
@@ -346,9 +461,17 @@ export default function Backends() {
                        </span>
                      </div>
                    ) : b.installed ? (
                      <span className="badge badge-success">
                        <i className="fas fa-check" style={{ fontSize: '0.5rem', marginRight: 2 }} /> Installed
                      </span>
                      <div style={{ display: 'flex', gap: 4, alignItems: 'center', flexWrap: 'wrap' }}>
                        <span className="badge badge-success">
                          <i className="fas fa-check" style={{ fontSize: '0.5rem', marginRight: 2 }} /> Installed
                        </span>
                        {upgrades[b.name] && (
                          <span className="badge" style={{ fontSize: '0.625rem', background: '#fef3cd', color: '#856404' }}>
                            <i className="fas fa-arrow-up" style={{ fontSize: '0.5rem', marginRight: 2 }} />
                            {upgrades[b.name].available_version ? `v${upgrades[b.name].available_version}` : 'Update'}
                          </span>
                        )}
                      </div>
                    ) : (
                      <span className="badge" style={{ background: 'var(--color-bg-tertiary)', color: 'var(--color-text-muted)' }}>
                        <i className="fas fa-circle" style={{ fontSize: '0.5rem', marginRight: 2 }} /> Not Installed
@@ -361,9 +484,15 @@ export default function Backends() {
                    <div style={{ display: 'flex', gap: 'var(--spacing-xs)', justifyContent: 'flex-end' }} onClick={e => e.stopPropagation()}>
                      {b.installed ? (
                        <>
                          <button className="btn btn-secondary btn-sm" onClick={() => handleInstall(b.name || b.id)} title="Reinstall" disabled={isProcessing}>
                            <i className={`fas ${isProcessing ? 'fa-spinner fa-spin' : 'fa-rotate'}`} />
                          </button>
                          {upgrades[b.name] ? (
                            <button className="btn btn-primary btn-sm" onClick={() => handleUpgrade(b.name || b.id)} title={`Upgrade to ${upgrades[b.name]?.available_version ? 'v' + upgrades[b.name].available_version : 'latest'}`} disabled={isProcessing}>
                              <i className={`fas ${isProcessing ? 'fa-spinner fa-spin' : 'fa-arrow-up'}`} />
                            </button>
                          ) : (
                            <button className="btn btn-secondary btn-sm" onClick={() => handleInstall(b.name || b.id)} title="Reinstall" disabled={isProcessing}>
                              <i className={`fas ${isProcessing ? 'fa-spinner fa-spin' : 'fa-rotate'}`} />
                            </button>
                          )}
                          <button className="btn btn-danger btn-sm" onClick={() => handleDelete(b.name || b.id)} title="Delete" disabled={isProcessing}>
                            <i className="fas fa-trash" />
                          </button>
@@ -13,6 +13,7 @@ import UnifiedMCPDropdown from '../components/UnifiedMCPDropdown'
import { loadClientMCPServers } from '../utils/mcpClientStorage'
import ConfirmDialog from '../components/ConfirmDialog'
import { useAuth } from '../context/AuthContext'
import { useOperations } from '../hooks/useOperations'
import { relativeTime } from '../utils/format'

function getLastMessagePreview(chat) {
@@ -277,6 +278,7 @@ export default function Chat() {
  const { addToast } = useOutletContext()
  const navigate = useNavigate()
  const { isAdmin } = useAuth()
  const { operations } = useOperations()
  const {
    chats, activeChat, activeChatId, isStreaming, streamingChatId, streamingContent,
    streamingReasoning, streamingToolCalls, tokensPerSecond, maxTokensPerSecond,
@@ -284,6 +286,12 @@ export default function Chat() {
    sendMessage, stopGeneration, clearHistory, getContextUsagePercent, addMessage,
  } = useChat(urlModel || '')

  // Detect active staging operation for the current chat's model
  const stagingOp = useMemo(() => {
    if (!isStreaming || !activeChat?.model) return null
    return operations.find(op => op.taskType === 'staging' && op.name === activeChat.model) || null
  }, [operations, isStreaming, activeChat?.model])

  const [input, setInput] = useState('')
  const [files, setFiles] = useState([])
  const [showSettings, setShowSettings] = useState(false)
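The `stagingOp` lookup only reports an operation while a response is streaming, and only when an operation of type `staging` carries the same name as the active chat's model. A small sketch of that selection as a pure function follows. The name `findStagingOp` and the sample operations are assumptions for illustration; the field names `taskType`, `name`, `progress`, and `nodeName` are taken from the diff.

```javascript
// Hypothetical pure helper (not in the diff) mirroring the stagingOp lookup.
function findStagingOp(operations, { isStreaming, model }) {
  // Nothing to show unless a reply is streaming for a known model.
  if (!isStreaming || !model) return null
  return operations.find(op => op.taskType === 'staging' && op.name === model) || null
}

// Made-up sample operations for illustration only.
const ops = [
  { taskType: 'install', name: 'example-model', progress: 10 },
  { taskType: 'staging', name: 'example-model', progress: 42.4, nodeName: 'node-1' },
]
const op = findStagingOp(ops, { isStreaming: true, model: 'example-model' })
console.log(op ? `${Math.round(op.progress)}% to ${op.nodeName}` : 'no staging')
// prints 42% to node-1
```

When the lookup returns `null`, the chat view in the diff falls back to the plain thinking-dots indicator.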
@@ -1187,9 +1195,28 @@ export default function Chat() {
            </div>
            <div className="chat-message-bubble">
              <div className="chat-message-content chat-thinking-indicator">
                <span className="chat-thinking-dots">
                  <span /><span /><span />
                </span>
                {stagingOp ? (
                  <div className="chat-staging-progress">
                    <div className="chat-staging-label">
                      <i className="fas fa-cloud-arrow-up" /> Transferring model{stagingOp.nodeName ? ` to ${stagingOp.nodeName}` : ''}...
                    </div>
                    {stagingOp.progress > 0 && (
                      <div className="chat-staging-detail">
                        <div className="chat-staging-bar-container">
                          <div className="chat-staging-bar" style={{ width: `${stagingOp.progress}%` }} />
                        </div>
                        <span className="chat-staging-pct">{Math.round(stagingOp.progress)}%</span>
                      </div>
                    )}
                    {stagingOp.message && (
                      <div className="chat-staging-file">{stagingOp.message}</div>
                    )}
                  </div>
                ) : (
                  <span className="chat-thinking-dots">
                    <span /><span /><span />
                  </span>
                )}
              </div>
            </div>
          </div>
@@ -204,9 +204,6 @@ export default function ImportModel() {
          </p>
        </div>
        <div style={{ display: 'flex', gap: 'var(--spacing-sm)', flexWrap: 'wrap' }}>
          <button className="btn btn-secondary" onClick={() => navigate('/app/pipeline-editor')}>
            <i className="fas fa-diagram-project" /> Create Pipeline Model
          </button>
          <button className="btn btn-secondary" onClick={() => setIsAdvancedMode(!isAdvancedMode)}>
            <i className={`fas ${isAdvancedMode ? 'fa-magic' : 'fa-code'}`} />
            {isAdvancedMode ? ' Simple Mode' : ' Advanced Mode'}
Some files were not shown because too many files have changed in this diff