Compare commits


2 Commits

Author SHA1 Message Date
Ettore Di Giacinto
478d2adfb7 Trigger CI
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-08-24 16:48:33 +02:00
Ettore Di Giacinto
909fdd1b0e feat(transformers): add support for CPU and MPS
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-08-24 16:47:51 +02:00
138 changed files with 1006 additions and 7426 deletions

View File

@@ -6,10 +6,6 @@ models
backends
examples/chatbot-ui/models
backend/go/image/stablediffusion-ggml/build/
backend/go/*/build
backend/go/*/.cache
backend/go/*/sources
backend/go/*/package
examples/rwkv/models
examples/**/models
Dockerfile*

View File

@@ -2,6 +2,7 @@
name: 'build backend container images'
on:
pull_request:
push:
branches:
- master
@@ -63,6 +64,18 @@ jobs:
backend: "llama-cpp"
dockerfile: "./backend/Dockerfile.llama-cpp"
context: "./"
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-transformers'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'true'
backend: "transformers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "11"
cuda-minor-version: "7"
@@ -111,18 +124,6 @@ jobs:
backend: "diffusers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-chatterbox'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'true'
backend: "chatterbox"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
# CUDA 11 additional backends
- build-type: 'cublas'
cuda-major-version: "11"
@@ -242,7 +243,7 @@ jobs:
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'false'
backend: "diffusers"
backend: "diffusers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
# CUDA 12 additional backends
@@ -775,7 +776,7 @@ jobs:
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-rocm-hipblas-whisper'
tag-suffix: '-gpu-hipblas-whisper'
base-image: "rocm/dev-ubuntu-22.04:6.4.3"
runs-on: 'ubuntu-latest'
skip-drivers: 'false'
@@ -969,41 +970,54 @@ jobs:
backend: "kitten-tts"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
backend-jobs-darwin:
transformers-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
strategy:
matrix:
include:
- backend: "diffusers"
tag-suffix: "-metal-darwin-arm64-diffusers"
build-type: "mps"
- backend: "mlx"
tag-suffix: "-metal-darwin-arm64-mlx"
build-type: "mps"
- backend: "chatterbox"
tag-suffix: "-metal-darwin-arm64-chatterbox"
build-type: "mps"
- backend: "mlx-vlm"
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
build-type: "mps"
- backend: "mlx-audio"
tag-suffix: "-metal-darwin-arm64-mlx-audio"
build-type: "mps"
- backend: "stablediffusion-ggml"
tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
build-type: "metal"
lang: "go"
- backend: "whisper"
tag-suffix: "-metal-darwin-arm64-whisper"
build-type: "metal"
lang: "go"
with:
backend: ${{ matrix.backend }}
build-type: ${{ matrix.build-type }}
backend: "transformers"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: ${{ matrix.tag-suffix }}
lang: ${{ matrix.lang || 'python' }}
use-pip: ${{ matrix.backend == 'diffusers' }}
tag-suffix: "-metal-darwin-arm64-transformers"
use-pip: true
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
diffusers-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "diffusers"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-diffusers"
use-pip: true
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "mlx"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-mlx"
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-vlm-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "mlx-vlm"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}

View File

@@ -16,10 +16,6 @@ on:
description: 'Use pip to install dependencies'
default: false
type: boolean
lang:
description: 'Programming language (e.g. go)'
default: 'python'
type: string
go-version:
description: 'Go version to use'
default: '1.24.x'
@@ -53,26 +49,26 @@ jobs:
uses: actions/checkout@v5
with:
submodules: true
- name: Setup Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: false
# You can test your matrix by printing the current Go version
- name: Display Go version
run: go version
- name: Dependencies
run: |
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
- name: Build ${{ inputs.backend }}-darwin
run: |
make protogen-go
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-python-backend
- name: Upload ${{ inputs.backend }}.tar
uses: actions/upload-artifact@v4
with:
@@ -89,20 +85,20 @@ jobs:
with:
name: ${{ inputs.backend }}-tar
path: .
- name: Install crane
run: |
curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
sudo mv crane /usr/local/bin/
- name: Log in to DockerHub
run: |
echo "${{ secrets.dockerPassword }}" | crane auth login docker.io -u "${{ secrets.dockerUsername }}" --password-stdin
- name: Log in to quay.io
run: |
echo "${{ secrets.quayPassword }}" | crane auth login quay.io -u "${{ secrets.quayUsername }}" --password-stdin
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -116,7 +112,7 @@ jobs:
flavor: |
latest=auto
suffix=${{ inputs.tag-suffix }},onlatest=true
- name: Docker meta
id: quaymeta
uses: docker/metadata-action@v5
@@ -130,13 +126,13 @@ jobs:
flavor: |
latest=auto
suffix=${{ inputs.tag-suffix }},onlatest=true
- name: Push Docker image (DockerHub)
run: |
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do
crane push ${{ inputs.backend }}.tar $tag
done
- name: Push Docker image (Quay)
run: |
for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do

View File

@@ -12,9 +12,7 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }}
has-backends: ${{ steps.set-matrix.outputs.has-backends }}
has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
steps:
- name: Checkout repository
uses: actions/checkout@v5
@@ -58,21 +56,3 @@ jobs:
strategy:
fail-fast: true
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
backend-jobs-darwin:
needs: generate-matrix
uses: ./.github/workflows/backend_build_darwin.yml
if: needs.generate-matrix.outputs.has-backends-darwin == 'true'
with:
backend: ${{ matrix.backend }}
build-type: ${{ matrix.build-type }}
go-version: "1.24.x"
tag-suffix: ${{ matrix.tag-suffix }}
lang: ${{ matrix.lang || 'python' }}
use-pip: ${{ matrix.backend == 'diffusers' }}
runs-on: "macOS-14"
secrets:
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
strategy:
fail-fast: true
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix-darwin) }}

View File

@@ -21,47 +21,3 @@ jobs:
- name: Run GoReleaser
run: |
make dev-dist
launcher-build-darwin:
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for macOS ARM64
run: |
make build-launcher-darwin
ls -liah dist
- name: Upload macOS launcher artifacts
uses: actions/upload-artifact@v4
with:
name: launcher-macos
path: dist/
retention-days: 30
launcher-build-linux:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for Linux
run: |
sudo apt-get update
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
make build-launcher-linux
- name: Upload Linux launcher artifacts
uses: actions/upload-artifact@v4
with:
name: launcher-linux
path: local-ai-launcher-linux.tar.xz
retention-days: 30

View File

@@ -9,4 +9,4 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@v5

View File

@@ -6,8 +6,7 @@ permissions:
contents: write
pull-requests: write
packages: read
issues: write # for Homebrew/actions/post-comment
actions: write # to dispatch publish workflow
jobs:
dependabot:
runs-on: ubuntu-latest

View File

@@ -23,42 +23,4 @@ jobs:
version: v2.11.0
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
launcher-build-darwin:
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for macOS ARM64
run: |
make build-launcher-darwin
- name: Upload DMG to Release
uses: softprops/action-gh-release@v2
with:
files: ./dist/LocalAI.dmg
launcher-build-linux:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for Linux
run: |
sudo apt-get update
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
make build-launcher-linux
- name: Upload Linux launcher artifacts
uses: softprops/action-gh-release@v2
with:
files: ./local-ai-launcher-linux.tar.xz
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -10,7 +10,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v9
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9
with:
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'

.gitignore (vendored)
View File

@@ -24,7 +24,7 @@ go-bert
# LocalAI build binary
LocalAI
/local-ai
local-ai
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder

View File

@@ -8,7 +8,7 @@ source:
enabled: true
name_template: '{{ .ProjectName }}-{{ .Tag }}-source'
builds:
- main: ./cmd/local-ai
-
env:
- CGO_ENABLED=0
ldflags:

View File

@@ -100,10 +100,6 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
ldconfig \
; fi
RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
; fi
RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
# Cuda

View File

@@ -2,7 +2,6 @@ GOCMD=go
GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
LAUNCHER_BINARY_NAME=local-ai-launcher
GORELEASER?=
@@ -91,17 +90,7 @@ build: protogen-go install-go-tools ## Build the project
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
$(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET})
rm -rf $(BINARY_NAME) || true
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai
build-launcher: ## Build the launcher application
$(info ${GREEN}I local-ai launcher build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
rm -rf $(LAUNCHER_BINARY_NAME) || true
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(LAUNCHER_BINARY_NAME) ./cmd/launcher
build-all: build build-launcher ## Build both server and launcher
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
dev-dist:
$(GORELEASER) build --snapshot --clean
@@ -117,8 +106,8 @@ run: ## run local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
test-models/testmodel.ggml:
mkdir -p test-models
mkdir -p test-dir
mkdir test-models
mkdir test-dir
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
@@ -369,9 +358,6 @@ backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
backends/kokoro: docker-build-kokoro docker-save-kokoro build
./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
backends/chatterbox: docker-build-chatterbox docker-save-chatterbox build
./local-ai backends install "ocifile://$(abspath ./backend-images/chatterbox.tar)"
backends/llama-cpp-darwin: build
bash ./scripts/build/llama-cpp-darwin.sh
./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
@@ -379,9 +365,6 @@ backends/llama-cpp-darwin: build
build-darwin-python-backend: build
bash ./scripts/build/python-darwin.sh
build-darwin-go-backend: build
bash ./scripts/build/golang-darwin.sh
backends/mlx:
BACKEND=mlx $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)"
@@ -394,14 +377,6 @@ backends/mlx-vlm:
BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"
backends/mlx-audio:
BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
backends/stablediffusion-ggml-darwin:
BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"
backend-images:
mkdir -p backend-images
@@ -496,7 +471,7 @@ docker-build-bark:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .
docker-build-chatterbox:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox ./backend
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .
docker-build-exllama2:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .
@@ -532,19 +507,3 @@ docs-clean:
.PHONY: docs
docs: docs/static/gallery.html
cd docs && hugo serve
########################################################
## Platform-specific builds
########################################################
## fyne cross-platform build
build-launcher-darwin: build-launcher
go run github.com/tiagomelo/macos-dmg-creator/cmd/createdmg@latest \
--appName "LocalAI" \
--appBinaryPath "$(LAUNCHER_BINARY_NAME)" \
--bundleIdentifier "com.localai.launcher" \
--iconPath "core/http/static/logo.png" \
--outputDir "dist/"
build-launcher-linux:
cd cmd/launcher && go run fyne.io/tools/cmd/fyne@latest package -os linux -icon ../../core/http/static/logo.png --executable $(LAUNCHER_BINARY_NAME)-linux && mv launcher.tar.xz ../../$(LAUNCHER_BINARY_NAME)-linux.tar.xz

View File

@@ -43,7 +43,7 @@
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
>
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🥽 Demo](https://demo.localai.io) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot)
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
@@ -110,12 +110,6 @@ curl https://localai.io/install.sh | sh
For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).
### macOS Download:
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a>
Or run with docker:
### CPU only image:
@@ -239,60 +233,6 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
- 🔊 Voice activity detection (Silero-VAD support)
- 🌍 Integrated WebUI!
## 🧩 Supported Backends & Acceleration
LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
### Text Generation & Language Models
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **llama.cpp** | LLM inference in C/C++ | CUDA 11/12, ROCm, Intel SYCL, Vulkan, Metal, CPU |
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12, ROCm, Intel |
| **transformers** | HuggingFace transformers framework | CUDA 11/12, ROCm, Intel, CPU |
| **exllama2** | GPTQ inference library | CUDA 12 |
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
### Audio & Speech Processing
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12, ROCm, Intel SYCL, Vulkan, CPU |
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12, ROCm, Intel, CPU |
| **bark** | Text-to-audio generation | CUDA 12, ROCm, Intel |
| **bark-cpp** | C++ implementation of Bark | CUDA, Metal, CPU |
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12, ROCm, Intel, CPU |
| **kokoro** | Lightweight TTS model | CUDA 12, ROCm, Intel, CPU |
| **chatterbox** | Production-grade TTS | CUDA 11/12, CPU |
| **piper** | Fast neural TTS system | CPU |
| **kitten-tts** | Kitten TTS models | CPU |
| **silero-vad** | Voice Activity Detection | CPU |
### Image & Video Generation
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12, Intel SYCL, Vulkan, CPU |
| **diffusers** | HuggingFace diffusion models | CUDA 11/12, ROCm, Intel, Metal, CPU |
### Specialized AI Tasks
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **rfdetr** | Real-time object detection | CUDA 12, Intel, CPU |
| **rerankers** | Document reranking API | CUDA 11/12, ROCm, Intel, CPU |
| **local-store** | Vector database | CPU |
| **huggingface** | HuggingFace API integration | API-based |
### Hardware Acceleration Matrix
| Acceleration Type | Supported Backends | Hardware Support |
|-------------------|-------------------|------------------|
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark | AMD Graphics |
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
| **NVIDIA Jetson** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI |
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
### 🔗 Community and integrations
@@ -307,9 +247,6 @@ WebUIs:
Model galleries
- https://github.com/go-skynet/model-gallery
Voice:
- https://github.com/richiejp/VoxInput
Other:
- Helm chart https://github.com/go-skynet/helm-charts
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin

View File

@@ -2,10 +2,10 @@ context_size: 4096
f16: true
backend: llama-cpp
mmap: true
mmproj: minicpm-v-4_5-mmproj-f16.gguf
mmproj: minicpm-v-2_6-mmproj-f16.gguf
name: gpt-4o
parameters:
model: minicpm-v-4_5-Q4_K_M.gguf
model: minicpm-v-2_6-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>
@@ -42,9 +42,9 @@ template:
<|im_start|>assistant
download_files:
- filename: minicpm-v-4_5-Q4_K_M.gguf
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
- filename: minicpm-v-2_6-Q4_K_M.gguf
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd

View File

@@ -2,10 +2,10 @@ context_size: 4096
backend: llama-cpp
f16: true
mmap: true
mmproj: minicpm-v-4_5-mmproj-f16.gguf
mmproj: minicpm-v-2_6-mmproj-f16.gguf
name: gpt-4o
parameters:
model: minicpm-v-4_5-Q4_K_M.gguf
model: minicpm-v-2_6-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>
@@ -42,9 +42,9 @@ template:
<|im_start|>assistant
download_files:
- filename: minicpm-v-4_5-Q4_K_M.gguf
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
- filename: minicpm-v-2_6-Q4_K_M.gguf
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd

View File

@@ -2,10 +2,10 @@ context_size: 4096
backend: llama-cpp
f16: true
mmap: true
mmproj: minicpm-v-4_5-mmproj-f16.gguf
mmproj: minicpm-v-2_6-mmproj-f16.gguf
name: gpt-4o
parameters:
model: minicpm-v-4_5-Q4_K_M.gguf
model: minicpm-v-2_6-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>
@@ -43,9 +43,9 @@ template:
download_files:
- filename: minicpm-v-4_5-Q4_K_M.gguf
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
- filename: minicpm-v-2_6-Q4_K_M.gguf
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd

View File

@@ -1,213 +0,0 @@
# LocalAI Backend Architecture
This directory contains the core backend infrastructure for LocalAI, including the gRPC protocol definition, multi-language Dockerfiles, and language-specific backend implementations.
## Overview
LocalAI uses a unified gRPC-based architecture that allows different programming languages to implement AI backends while maintaining consistent interfaces and capabilities. The backend system supports multiple hardware acceleration targets and provides a standardized way to integrate various AI models and frameworks.
## Architecture Components
### 1. Protocol Definition (`backend.proto`)
The `backend.proto` file defines the gRPC service interface that all backends must implement. This ensures consistency across different language implementations and provides a contract for communication between LocalAI core and backend services.
#### Core Services
- **Text Generation**: `Predict`, `PredictStream` for LLM inference
- **Embeddings**: `Embedding` for text vectorization
- **Image Generation**: `GenerateImage` for stable diffusion and image models
- **Audio Processing**: `AudioTranscription`, `TTS`, `SoundGeneration`
- **Video Generation**: `GenerateVideo` for video synthesis
- **Object Detection**: `Detect` for computer vision tasks
- **Vector Storage**: `StoresSet`, `StoresGet`, `StoresFind` for RAG operations
- **Reranking**: `Rerank` for document relevance scoring
- **Voice Activity Detection**: `VAD` for audio segmentation
#### Key Message Types
- **`PredictOptions`**: Comprehensive configuration for text generation
- **`ModelOptions`**: Model loading and configuration parameters
- **`Result`**: Standardized response format
- **`StatusResponse`**: Backend health and memory usage information
### 2. Multi-Language Dockerfiles
The backend system provides language-specific Dockerfiles that handle the build environment and dependencies for different programming languages:
- `Dockerfile.python`
- `Dockerfile.golang`
- `Dockerfile.llama-cpp`
### 3. Language-Specific Implementations
#### Python Backends (`python/`)
- **transformers**: Hugging Face Transformers framework
- **vllm**: High-performance LLM inference
- **mlx**: Apple Silicon optimization
- **diffusers**: Stable Diffusion models
- **Audio**: bark, coqui, faster-whisper, kitten-tts
- **Vision**: mlx-vlm, rfdetr
- **Specialized**: rerankers, chatterbox, kokoro
#### Go Backends (`go/`)
- **whisper**: OpenAI Whisper speech recognition in Go, backed by the GGML C++ implementation (whisper.cpp)
- **stablediffusion-ggml**: Stable Diffusion in Go, backed by the GGML C++ implementation
- **huggingface**: Hugging Face model integration
- **piper**: Text-to-speech synthesis in Go with C bindings to rhasspy/piper
- **bark-cpp**: Bark TTS models in Go with C++ bindings
- **local-store**: Vector storage backend
#### C++ Backends (`cpp/`)
- **llama-cpp**: Llama.cpp integration
- **grpc**: GRPC utilities and helpers
## Hardware Acceleration Support
### CUDA (NVIDIA)
- **Versions**: CUDA 11.x, 12.x
- **Features**: cuBLAS, cuDNN, TensorRT optimization
- **Targets**: x86_64, ARM64 (Jetson)
### ROCm (AMD)
- **Features**: HIP, rocBLAS, MIOpen
- **Targets**: AMD GPUs with ROCm support
### Intel
- **Features**: oneAPI, Intel Extension for PyTorch
- **Targets**: Intel GPUs, XPUs, CPUs
### Vulkan
- **Features**: Cross-platform GPU acceleration
- **Targets**: Windows, Linux, Android, macOS
### Apple Silicon
- **Features**: MLX framework, Metal Performance Shaders
- **Targets**: M1/M2/M3 Macs
## Backend Registry (`index.yaml`)
The `index.yaml` file serves as a central registry for all available backends, providing:
- **Metadata**: Name, description, license, icons
- **Capabilities**: Hardware targets and optimization profiles
- **Tags**: Categorization for discovery
- **URLs**: Source code and documentation links
## Building Backends
### Prerequisites
- Docker with multi-architecture support
- Appropriate hardware drivers (CUDA, ROCm, etc.)
- Build tools (make, cmake, compilers)
### Build Commands
Examples of build commands with Docker:
```bash
# Build Python backend
docker build -f backend/Dockerfile.python \
--build-arg BACKEND=transformers \
--build-arg BUILD_TYPE=cublas12 \
--build-arg CUDA_MAJOR_VERSION=12 \
--build-arg CUDA_MINOR_VERSION=0 \
-t localai-backend-transformers .
# Build Go backend
docker build -f backend/Dockerfile.golang \
--build-arg BACKEND=whisper \
--build-arg BUILD_TYPE=cpu \
-t localai-backend-whisper .
# Build C++ backend
docker build -f backend/Dockerfile.llama-cpp \
--build-arg BACKEND=llama-cpp \
--build-arg BUILD_TYPE=cublas12 \
-t localai-backend-llama-cpp .
```
For ARM64/macOS builds, Docker cannot be used; use the Makefile in the respective backend directory instead.
### Build Types
- **`cpu`**: CPU-only optimization
- **`cublas11`**: CUDA 11.x with cuBLAS
- **`cublas12`**: CUDA 12.x with cuBLAS
- **`hipblas`**: ROCm with rocBLAS
- **`intel`**: Intel oneAPI optimization
- **`vulkan`**: Vulkan-based acceleration
- **`metal`**: Apple Metal optimization
## Backend Development
### Creating a New Backend
1. **Choose Language**: Select Python, Go, or C++ based on requirements
2. **Implement Interface**: Implement the gRPC service defined in `backend.proto`
3. **Add Dependencies**: Create appropriate requirements files
4. **Configure Build**: Set up Dockerfile and build scripts
5. **Register Backend**: Add entry to `index.yaml`
6. **Test Integration**: Verify gRPC communication and functionality
### Backend Structure
```
backend-name/
├── backend.py/go/cpp # Main implementation
├── requirements.txt # Dependencies
├── Dockerfile # Build configuration
├── install.sh # Installation script
├── run.sh # Execution script
├── test.sh # Test script
└── README.md # Backend documentation
```
### Required gRPC Methods
At minimum, backends must implement:
- `Health()` - Service health check
- `LoadModel()` - Model loading and initialization
- `Predict()` - Main inference endpoint
- `Status()` - Backend status and metrics
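For illustration, a minimal Python servicer covering these methods might look like the sketch below. The `backend_pb2`/`backend_pb2_grpc` module names and the message fields (`Reply`, `Result`, `ModelOptions.Model`, `PredictOptions.Prompt`) are assumptions based on stubs generated from `backend.proto`, not a verbatim copy of any shipped backend.
```python
# Illustrative sketch of a minimal LocalAI-style backend servicer.
# Assumptions: backend_pb2 / backend_pb2_grpc are the stubs generated from
# backend.proto with grpcio-tools, and the message/field names match it.
from concurrent import futures

import grpc
import backend_pb2
import backend_pb2_grpc


class BackendServicer(backend_pb2_grpc.BackendServicer):
    def Health(self, request, context):
        # Health check: report that the process is alive.
        return backend_pb2.Reply(message=b"OK")

    def LoadModel(self, request, context):
        # request carries ModelOptions; a real backend loads weights here.
        self.model = request.Model
        return backend_pb2.Result(success=True, message="model loaded")

    def Predict(self, request, context):
        # request carries PredictOptions; a real backend runs inference here.
        text = f"echo: {request.Prompt}"
        return backend_pb2.Reply(message=text.encode("utf-8"))

    # Status() is omitted for brevity; it should return backend metrics.


def serve(address: str = "localhost:50051") -> None:
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
```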
## Integration with LocalAI Core
Backends communicate with LocalAI core through gRPC:
1. **Service Discovery**: Core discovers available backends
2. **Model Loading**: Core requests model loading via `LoadModel`
3. **Inference**: Core sends requests via `Predict` or specialized endpoints
4. **Streaming**: Core handles streaming responses for real-time generation
5. **Monitoring**: Core tracks backend health and performance
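To make the call sequence concrete, the sketch below walks a client through a health check, `LoadModel`, and `Predict` against a running backend. The core itself drives this flow from Go; the Python stub and field names here are assumptions derived from `backend.proto` and serve only to illustrate the protocol.
```python
# Sketch of the gRPC call sequence a client performs against a backend.
# The real LocalAI core is written in Go; stub and field names here are
# assumptions derived from backend.proto, used purely for illustration.
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)

    # Health check before routing requests to the backend.
    stub.Health(backend_pb2.HealthMessage())

    # Model loading: ask the backend to initialise a model.
    stub.LoadModel(backend_pb2.ModelOptions(Model="model.gguf"))

    # Inference: a single Predict call; PredictStream covers streaming.
    reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Hello"))
    print(reply.message.decode("utf-8"))
```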
## Performance Optimization
### Memory Management
- **Model Caching**: Efficient model loading and caching
- **Batch Processing**: Optimize for multiple concurrent requests
- **Memory Pinning**: GPU memory optimization for CUDA/ROCm
### Hardware Utilization
- **Multi-GPU**: Support for tensor parallelism
- **Mixed Precision**: FP16/BF16 for memory efficiency
- **Kernel Fusion**: Optimized CUDA/ROCm kernels
## Troubleshooting
### Common Issues
1. **GRPC Connection**: Verify backend service is running and accessible
2. **Model Loading**: Check model paths and dependencies
3. **Hardware Detection**: Ensure appropriate drivers and libraries
4. **Memory Issues**: Monitor GPU memory usage and model sizes
## Contributing
When contributing to the backend system:
1. **Follow Protocol**: Implement the exact gRPC interface
2. **Add Tests**: Include comprehensive test coverage
3. **Document**: Provide clear usage examples
4. **Optimize**: Consider performance and resource usage
5. **Validate**: Test across different hardware targets

View File

@@ -242,7 +242,7 @@ message ModelOptions {
string Type = 49;
string FlashAttention = 56;
bool FlashAttention = 56;
bool NoKVOffload = 57;
string ModelPath = 59;
@@ -276,7 +276,6 @@ message TranscriptRequest {
string language = 3;
uint32 threads = 4;
bool translate = 5;
bool diarize = 6;
}
message TranscriptResult {
@@ -306,24 +305,22 @@ message GenerateImageRequest {
// Diffusers
string EnableParameters = 10;
int32 CLIPSkip = 11;
// Reference images for models that support them (e.g., Flux Kontext)
repeated string ref_images = 12;
}
message GenerateVideoRequest {
string prompt = 1;
string negative_prompt = 2; // Negative prompt for video generation
string start_image = 3; // Path or base64 encoded image for the start frame
string end_image = 4; // Path or base64 encoded image for the end frame
int32 width = 5;
int32 height = 6;
int32 num_frames = 7; // Number of frames to generate
int32 fps = 8; // Frames per second
int32 seed = 9;
float cfg_scale = 10; // Classifier-free guidance scale
int32 step = 11; // Number of inference steps
string dst = 12; // Output path for the generated video
string start_image = 2; // Path or base64 encoded image for the start frame
string end_image = 3; // Path or base64 encoded image for the end frame
int32 width = 4;
int32 height = 5;
int32 num_frames = 6; // Number of frames to generate
int32 fps = 7; // Frames per second
int32 seed = 8;
float cfg_scale = 9; // Classifier-free guidance scale
string dst = 10; // Output path for the generated video
}
message TTSRequest {

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=3edd87cd055a45d885fa914d879d36d33ecfc3e1
LLAMA_VERSION?=710dfc465a68f7443b87d9f792cffba00ed739fe
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -304,15 +304,7 @@ static void params_parse(const backend::ModelOptions* request,
}
params.use_mlock = request->mlock();
params.use_mmap = request->mmap();
if (request->flashattention() == "on" || request->flashattention() == "enabled") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
} else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (request->flashattention() == "auto") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
}
params.flash_attn = request->flashattention();
params.no_kv_offload = request->nokvoffload();
params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
@@ -701,7 +693,7 @@ public:
*/
// for the shape of input/content, see tokenize_input_prompts()
json prompt = body.at("embeddings");
json prompt = body.at("prompt");
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
@@ -712,7 +704,6 @@ public:
}
}
int embd_normalize = 2; // default to Euclidean/L2 norm
// create and queue the task
json responses = json::array();
bool error = false;
@@ -726,8 +717,9 @@ public:
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
task.params.embd_normalize = embd_normalize;
// OAI-compat
task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING;
tasks.push_back(std::move(task));
}
@@ -743,8 +735,9 @@ public:
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
error = true;
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
}, [&]() {
// NOTE: we should try to check when the writer is closed here
return false;
});
@@ -754,36 +747,12 @@ public:
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
// Process the responses and extract embeddings
for (const auto & response_elem : responses) {
// Check if the response has an "embedding" field
if (response_elem.contains("embedding")) {
json embedding_data = json_value(response_elem, "embedding", json::array());
if (embedding_data.is_array() && !embedding_data.empty()) {
for (const auto & embedding_vector : embedding_data) {
if (embedding_vector.is_array()) {
for (const auto & embedding_value : embedding_vector) {
embeddingResult->add_embeddings(embedding_value.get<float>());
}
}
}
}
} else {
// Check if the response itself contains the embedding data directly
if (response_elem.is_array()) {
for (const auto & embedding_value : response_elem) {
embeddingResult->add_embeddings(embedding_value.get<float>());
}
}
}
std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>());
// loop the vector and set the embeddings results
for (int i = 0; i < embeddings.size(); i++) {
embeddingResult->add_embeddings(embeddings[i]);
}
return grpc::Status::OK;
}

View File

@@ -1,6 +0,0 @@
package/
sources/
.cache/
build/
libgosd.so
stablediffusion-ggml

View File

@@ -5,11 +5,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory(./sources/stablediffusion-ggml.cpp)
add_library(gosd MODULE gosd.cpp)
target_link_libraries(gosd PRIVATE stable-diffusion ggml)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
target_link_libraries(gosd PRIVATE stdc++fs)
endif()
target_link_libraries(gosd PRIVATE stable-diffusion ggml stdc++fs)
target_include_directories(gosd PUBLIC
stable-diffusion.cpp

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=0ebe6fe118f125665939b27c89f34ed38716bff8
STABLEDIFFUSION_GGML_VERSION?=5900ef6605c6fbf7934239f795c13c97bc993853
CMAKE_ARGS+=-DGGML_MAX_NAME=128
@@ -29,6 +29,8 @@ else ifeq ($(BUILD_TYPE),clblas)
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DSD_HIPBLAS=ON -DGGML_HIPBLAS=ON
# If it's OSX, DO NOT embed the metal library - -DGGML_METAL_EMBED_LIBRARY=ON requires further investigation
# But if it's OSX without metal, disable it here
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DSD_VULKAN=ON -DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)
@@ -72,10 +74,10 @@ libgosd.so: sources/stablediffusion-ggml.cpp CMakeLists.txt gosd.cpp gosd.h
stablediffusion-ggml: main.go gosd.go libgosd.so
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml ./
package: stablediffusion-ggml
package:
bash package.sh
build: package
build: stablediffusion-ggml package
clean:
rm -rf libgosd.so build stablediffusion-ggml package sources
rm -rf libgosd.o build stablediffusion-ggml

View File

@@ -4,11 +4,17 @@
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include <filesystem>
#include "gosd.h"
// #include "preprocessing.hpp"
#include "flux.hpp"
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
@@ -23,7 +29,7 @@
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
"default",
"euler_a",
"euler",
"heun",
"dpm2",
@@ -35,27 +41,19 @@ const char* sample_method_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"euler_a",
};
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedulers[] = {
const char* schedule_str[] = {
"default",
"discrete",
"karras",
"exponential",
"ays",
"gits",
"smoothstep",
};
static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch");
sd_ctx_t* sd_c;
// Moved from the context (load time) to generation time params
scheduler_t scheduler = scheduler_t::DEFAULT;
sample_method_t sample_method;
@@ -107,7 +105,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *clip_g_path = "";
const char *t5xxl_path = "";
const char *vae_path = "";
const char *scheduler_str = "";
const char *scheduler = "";
const char *sampler = "";
char *lora_dir = model_path;
bool lora_dir_allocated = false;
@@ -135,7 +133,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
vae_path = optval;
}
if (!strcmp(optname, "scheduler")) {
scheduler_str = optval;
scheduler = optval;
}
if (!strcmp(optname, "sampler")) {
sampler = optval;
@@ -168,17 +166,26 @@ int load_model(const char *model, char *model_path, char* options[], int threads
}
if (sample_method_found == -1) {
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT;
sample_method_found = EULER_A;
}
sample_method = (sample_method_t)sample_method_found;
int schedule_found = -1;
for (int d = 0; d < SCHEDULE_COUNT; d++) {
if (!strcmp(scheduler_str, schedulers[d])) {
scheduler = (scheduler_t)d;
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
if (!strcmp(scheduler, schedule_str[d])) {
schedule_found = d;
fprintf (stderr, "Found scheduler: %s\n", scheduler);
}
}
if (schedule_found == -1) {
fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
schedule_found = DEFAULT;
}
schedule_t schedule = (schedule_t)schedule_found;
fprintf (stderr, "Creating context\n");
sd_ctx_params_t ctx_params;
sd_ctx_params_init(&ctx_params);
@@ -192,10 +199,13 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.control_net_path = "";
ctx_params.lora_model_dir = lora_dir;
ctx_params.embedding_dir = "";
ctx_params.stacked_id_embed_dir = "";
ctx_params.vae_decode_only = false;
ctx_params.vae_tiling = false;
ctx_params.free_params_immediately = false;
ctx_params.n_threads = threads;
ctx_params.rng_type = STD_DEFAULT_RNG;
ctx_params.schedule = schedule;
sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
if (sd_ctx == NULL) {
@@ -218,49 +228,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
return 0;
}
void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled) {
params->enabled = enabled;
}
void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y) {
params->tile_size_x = tile_size_x;
params->tile_size_y = tile_size_y;
}
void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y) {
params->rel_size_x = rel_size_x;
params->rel_size_y = rel_size_y;
}
void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap) {
params->target_overlap = target_overlap;
}
sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params) {
return &params->vae_tiling_params;
}
sd_img_gen_params_t* sd_img_gen_params_new(void) {
sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc(sizeof(sd_img_gen_params_t));
sd_img_gen_params_init(params);
return params;
}
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
params->prompt = prompt;
params->negative_prompt = negative_prompt;
}
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
params->width = width;
params->height = height;
}
void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
params->seed = seed;
}
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
sd_image_t* results;
@@ -268,15 +236,20 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
fprintf (stderr, "Generating image\n");
p->sample_params.guidance.txt_cfg = cfg_scale;
p->sample_params.guidance.slg.layers = skip_layers.data();
p->sample_params.guidance.slg.layer_count = skip_layers.size();
p->sample_params.sample_method = sample_method;
p->sample_params.sample_steps = steps;
p->sample_params.scheduler = scheduler;
sd_img_gen_params_t p;
sd_img_gen_params_init(&p);
int width = p->width;
int height = p->height;
p.prompt = text;
p.negative_prompt = negativeText;
p.guidance.txt_cfg = cfg_scale;
p.guidance.slg.layers = skip_layers.data();
p.guidance.slg.layer_count = skip_layers.size();
p.width = width;
p.height = height;
p.sample_method = sample_method;
p.sample_steps = steps;
p.seed = seed;
p.input_id_images_path = "";
// Handle input image for img2img
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
@@ -325,13 +298,13 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
input_image_buffer = resized_image_buffer;
}
p->init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
p->strength = strength;
p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
p.strength = strength;
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
} else {
// No input image, use empty image for text-to-image
p->init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
p->strength = 0.0f;
p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
p.strength = 0.0f;
}
// Handle mask image for inpainting
@@ -371,12 +344,12 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
mask_image_buffer = resized_mask_buffer;
}
p->mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
fprintf(stderr, "Using inpainting with mask\n");
} else {
// No mask image, create default full mask
default_mask_image_vec.resize(width * height, 255);
p->mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
}
// Handle reference images
@@ -434,15 +407,13 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
}
if (!ref_images_vec.empty()) {
p->ref_images = ref_images_vec.data();
p->ref_images_count = ref_images_vec.size();
p.ref_images = ref_images_vec.data();
p.ref_images_count = ref_images_vec.size();
fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
}
}
results = generate_image(sd_c, p);
std::free(p);
results = generate_image(sd_c, &p);
if (results == NULL) {
fprintf (stderr, "NO results\n");

View File

@@ -22,18 +22,7 @@ type SDGGML struct {
var (
LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
TilingParamsSetEnabled func(params uintptr, enabled bool)
TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
TilingParamsSetRelSizes func(params uintptr, relSizeX float32, relSizeY float32)
TilingParamsSetTargetOverlap func(params uintptr, targetOverlap float32)
ImgGenParamsNew func() uintptr
ImgGenParamsSetPrompts func(params uintptr, prompt string, negativePrompt string)
ImgGenParamsSetDimensions func(params uintptr, width int, height int)
ImgGenParamsSetSeed func(params uintptr, seed int64)
ImgGenParamsGetVaeTilingParams func(params uintptr) uintptr
GenImage func(text, negativeText string, width, height, steps int, seed int64, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
)
// Copied from Purego internal/strings
@@ -131,15 +120,7 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
// Default strength for img2img (0.75 is a good default)
strength := float32(0.75)
// free'd by GenImage
p := ImgGenParamsNew()
ImgGenParamsSetPrompts(p, t, negative)
ImgGenParamsSetDimensions(p, int(opts.Width), int(opts.Height))
ImgGenParamsSetSeed(p, int64(opts.Seed))
vaep := ImgGenParamsGetVaeTilingParams(p)
TilingParamsSetEnabled(vaep, false)
ret := GenImage(p, int(opts.Step), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
ret := GenImage(t, negative, int(opts.Width), int(opts.Height), int(opts.Step), int64(opts.Seed), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
if ret != 0 {
return fmt.Errorf("inference failed")
}

View File

@@ -1,23 +1,8 @@
#include <cstdint>
#include "stable-diffusion.h"
#ifdef __cplusplus
extern "C" {
#endif
void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled);
void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y);
void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y);
void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap);
sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params);
sd_img_gen_params_t* sd_img_gen_params_new(void);
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt);
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height);
void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);
int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
#ifdef __cplusplus
}
#endif

View File

@@ -11,35 +11,14 @@ var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
func main() {
gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
panic(err)
}
libFuncs := []LibFuncs{
{&LoadModel, "load_model"},
{&GenImage, "gen_image"},
{&TilingParamsSetEnabled, "sd_tiling_params_set_enabled"},
{&TilingParamsSetTileSizes, "sd_tiling_params_set_tile_sizes"},
{&TilingParamsSetRelSizes, "sd_tiling_params_set_rel_sizes"},
{&TilingParamsSetTargetOverlap, "sd_tiling_params_set_target_overlap"},
{&ImgGenParamsNew, "sd_img_gen_params_new"},
{&ImgGenParamsSetPrompts, "sd_img_gen_params_set_prompts"},
{&ImgGenParamsSetDimensions, "sd_img_gen_params_set_dimensions"},
{&ImgGenParamsSetSeed, "sd_img_gen_params_set_seed"},
{&ImgGenParamsGetVaeTilingParams, "sd_img_gen_params_get_vae_tiling_params"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
}
purego.RegisterLibFunc(&LoadModel, gosd, "load_model")
purego.RegisterLibFunc(&GenImage, gosd, "gen_image")
flag.Parse()

View File

@@ -10,9 +10,9 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib
cp -avf $CURDIR/libgosd.so $CURDIR/package/
cp -avf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/
cp -avrf $CURDIR/libgosd.so $CURDIR/package/
cp -avrf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -43,8 +43,6 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
echo "Detected Darwin"
else
echo "Error: Could not detect architecture"
exit 1

View File

@@ -1,7 +0,0 @@
.cache/
sources/
build/
package/
whisper
libgowhisper.so

View File

@@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.12)
project(gowhisper LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
add_subdirectory(./sources/whisper.cpp)
add_library(gowhisper MODULE gowhisper.cpp)
target_link_libraries(gowhisper PRIVATE whisper ggml)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
target_link_libraries(gosd PRIVATE stdc++fs)
endif()
set_property(TARGET gowhisper PROPERTY CXX_STANDARD 17)
set_target_properties(gowhisper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

View File

@@ -1,53 +1,110 @@
CMAKE_ARGS?=
BUILD_TYPE?=
GOCMD=go
NATIVE?=false
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)
BUILD_TYPE?=
CMAKE_ARGS?=
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=edea8a9c3cf0eb7676dcdb604991eb2f95c3d984
WHISPER_CPP_VERSION?=fc45bb86251f774ef817e89878bb4c2636c8a58f
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
export WHISPER_CMAKE_ARGS?=-DBUILD_SHARED_LIBS=OFF
export WHISPER_DIR=$(abspath ./sources/whisper.cpp)
export WHISPER_INCLUDE_PATH=$(WHISPER_DIR)/include:$(WHISPER_DIR)/ggml/include
export WHISPER_LIBRARY_PATH=$(WHISPER_DIR)/build/src/:$(WHISPER_DIR)/build/ggml/src
CGO_LDFLAGS_WHISPER?=
CGO_LDFLAGS_WHISPER+=-lggml
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
CUDA_LIBPATH?=/usr/local/cuda/lib64/
ONEAPI_VERSION?=2025.2
# IF native is false, we add -DGGML_NATIVE=OFF to CMAKE_ARGS
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
WHISPER_CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH) -L$(CUDA_LIBPATH)/stubs/ -lcuda
CMAKE_ARGS+=-DGGML_CUDA=ON
CGO_LDFLAGS_WHISPER+=-lcufft -lggml-cuda
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-cuda/
# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblas)
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
ROCM_HOME ?= /opt/rocm
ROCM_PATH ?= /opt/rocm
LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib
export STABLE_BUILD_TYPE=
export CXX=$(ROCM_HOME)/llvm/bin/clang++
export CC=$(ROCM_HOME)/llvm/bin/clang
# GPU_TARGETS ?= gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102
# AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
CMAKE_ARGS+=-DGGML_HIP=ON
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib -L$(CURRENT_MAKEFILE_DIR)/sources/whisper.cpp/build/ggml/src/ggml-hip/ -lggml-hip
# CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DGGML_VULKAN=ON
CMAKE_ARGS+=-DGGML_VULKAN=1
CGO_LDFLAGS_WHISPER+=-lggml-vulkan -lvulkan
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-vulkan/
else ifeq ($(OS),Darwin)
ifeq ($(BUILD_TYPE),)
BUILD_TYPE=metal
endif
ifneq ($(BUILD_TYPE),metal)
CMAKE_ARGS+=-DGGML_METAL=OFF
CGO_LDFLAGS_WHISPER+=-lggml-blas
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-blas
else
CMAKE_ARGS+=-DGGML_METAL=ON
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
CMAKE_ARGS+=-DGGML_METAL_USE_BF16=ON
CMAKE_ARGS+=-DGGML_OPENMP=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_EXAMPLES=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_TESTS=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_SERVER=OFF
CGO_LDFLAGS += -framework Accelerate
CGO_LDFLAGS_WHISPER+=-lggml-metal -lggml-blas
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-metal/:$(WHISPER_DIR)/build/ggml/src/ggml-blas
endif
TARGET+=--target ggml-metal
endif
ifeq ($(BUILD_TYPE),sycl_f16)
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
export CC=icx
export CXX=icpx
CGO_LDFLAGS_WHISPER += -fsycl -L${DNNLROOT}/lib -rpath ${ONEAPI_ROOT}/${ONEAPI_VERSION}/lib -ldnnl ${MKLROOT}/lib/intel64/libmkl_sycl.a -fiopenmp -fopenmp-targets=spir64 -lOpenCL -lggml-sycl
CGO_LDFLAGS_WHISPER += $(shell pkg-config --libs mkl-static-lp64-gomp)
CGO_CXXFLAGS_WHISPER += -fiopenmp -fopenmp-targets=spir64
CGO_CXXFLAGS_WHISPER += $(shell pkg-config --cflags mkl-static-lp64-gomp )
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-sycl/
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
-DGGML_SYCL_F16=ON
-DCMAKE_CXX_FLAGS="-fsycl"
endif
ifeq ($(BUILD_TYPE),sycl_f32)
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
ifeq ($(BUILD_TYPE),sycl_f16)
CMAKE_ARGS+=-DGGML_SYCL_F16=ON
endif
ifneq ($(OS),Darwin)
CGO_LDFLAGS_WHISPER+=-lgomp
endif
## whisper
sources/whisper.cpp:
mkdir -p sources/whisper.cpp
cd sources/whisper.cpp && \
@@ -57,21 +114,18 @@ sources/whisper.cpp:
git checkout $(WHISPER_CPP_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch
libgowhisper.so: sources/whisper.cpp CMakeLists.txt gowhisper.cpp gowhisper.h
mkdir -p build && \
cd build && \
cmake .. $(CMAKE_ARGS) && \
cmake --build . --config Release -j$(JOBS) && \
cd .. && \
mv build/libgowhisper.so ./
sources/whisper.cpp/build/src/libwhisper.a: sources/whisper.cpp
cd sources/whisper.cpp && cmake $(CMAKE_ARGS) $(WHISPER_CMAKE_ARGS) . -B ./build
cd sources/whisper.cpp/build && cmake --build . --config Release
whisper: main.go gowhisper.go libgowhisper.so
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./
whisper: sources/whisper.cpp sources/whisper.cpp/build/src/libwhisper.a
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(CURDIR)/sources/whisper.cpp/bindings/go
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="${WHISPER_INCLUDE_PATH}" LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" LD_LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" \
CGO_CXXFLAGS="$(CGO_CXXFLAGS_WHISPER)" \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o whisper ./
package: whisper
package:
bash package.sh
build: package
clean:
rm -rf libgowhisper.o build whisper
build: whisper package

View File

@@ -1,154 +0,0 @@
#include "gowhisper.h"
#include "ggml-backend.h"
#include "whisper.h"
#include <vector>
static struct whisper_vad_context *vctx;
static struct whisper_context *ctx;
static std::vector<float> flat_segs;
static void ggml_log_cb(enum ggml_log_level level, const char *log,
void *data) {
const char *level_str;
if (!log) {
return;
}
switch (level) {
case GGML_LOG_LEVEL_DEBUG:
level_str = "DEBUG";
break;
case GGML_LOG_LEVEL_INFO:
level_str = "INFO";
break;
case GGML_LOG_LEVEL_WARN:
level_str = "WARN";
break;
case GGML_LOG_LEVEL_ERROR:
level_str = "ERROR";
break;
default: /* Potential future-proofing */
level_str = "?????";
break;
}
fprintf(stderr, "[%-5s] ", level_str);
fputs(log, stderr);
fflush(stderr);
}
int load_model(const char *const model_path) {
whisper_log_set(ggml_log_cb, nullptr);
ggml_backend_load_all();
struct whisper_context_params cparams = whisper_context_default_params();
ctx = whisper_init_from_file_with_params(model_path, cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: Also failed to init model as transcriber\n");
return 1;
}
return 0;
}
int load_model_vad(const char *const model_path) {
whisper_log_set(ggml_log_cb, nullptr);
ggml_backend_load_all();
struct whisper_vad_context_params vcparams =
whisper_vad_default_context_params();
// XXX: Overridden to false in upstream due to performance?
// vcparams.use_gpu = true;
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
if (vctx == nullptr) {
fprintf(stderr, "error: Failed to init model as VAD\n");
return 1;
}
return 0;
}
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
size_t *segs_out_len) {
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
fprintf(stderr, "error: failed to detect speech\n");
return 1;
}
struct whisper_vad_params params = whisper_vad_default_params();
struct whisper_vad_segments *segs =
whisper_vad_segments_from_probs(vctx, params);
size_t segn = whisper_vad_segments_n_segments(segs);
// fprintf(stderr, "Got segments %zd\n", segn);
flat_segs.clear();
for (int i = 0; i < segn; i++) {
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
}
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
// flat_segs.size());
*segs_out = flat_segs.data();
*segs_out_len = flat_segs.size();
// fprintf(stderr, "freeing segs\n");
whisper_vad_free_segments(segs);
// fprintf(stderr, "returning\n");
return 0;
}
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
whisper_full_params wparams =
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.n_threads = threads;
if (*lang != '\0')
wparams.language = lang;
else {
wparams.language = nullptr;
}
wparams.translate = translate;
wparams.debug_mode = true;
wparams.print_progress = true;
wparams.tdrz_enable = tdrz;
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
fprintf(stderr, "error: transcription failed\n");
return 1;
}
*segs_out_len = whisper_full_n_segments(ctx);
return 0;
}
const char *get_segment_text(int i) {
return whisper_full_get_segment_text(ctx, i);
}
int64_t get_segment_t0(int i) { return whisper_full_get_segment_t0(ctx, i); }
int64_t get_segment_t1(int i) { return whisper_full_get_segment_t1(ctx, i); }
int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }
int32_t get_token_id(int i, int j) {
return whisper_full_get_token_id(ctx, i, j);
}
bool get_segment_speaker_turn_next(int i) {
return whisper_full_get_segment_speaker_turn_next(ctx, i);
}

View File

@@ -1,161 +0,0 @@
package main
import (
"fmt"
"os"
"path/filepath"
"strings"
"unsafe"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
var (
CppLoadModel func(modelPath string) int
CppLoadModelVAD func(modelPath string) int
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
CppGetSegmentText func(i int) string
CppGetSegmentStart func(i int) int64
CppGetSegmentEnd func(i int) int64
CppNTokens func(i int) int
CppGetTokenID func(i int, j int) int
CppGetSegmentSpeakerTurnNext func(i int) bool
)
type Whisper struct {
base.SingleThread
}
func (w *Whisper) Load(opts *pb.ModelOptions) error {
vadOnly := false
for _, oo := range opts.Options {
if oo == "vad_only" {
vadOnly = true
} else {
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
}
}
if vadOnly {
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
return fmt.Errorf("Failed to load Whisper VAD model")
}
return nil
}
if ret := CppLoadModel(opts.ModelFile); ret != 0 {
return fmt.Errorf("Failed to load Whisper transcription model")
}
return nil
}
func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
audio := req.Audio
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
}
// Happens when CPP vector has not had any elements pushed to it
if segsPtr == 0 {
return pb.VADResponse{
Segments: []*pb.VADSegment{},
}, nil
}
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen)
vadSegments := []*pb.VADSegment{}
for i := range len(segs) >> 1 {
s := segs[2*i] / 100
t := segs[2*i+1] / 100
vadSegments = append(vadSegments, &pb.VADSegment{
Start: s,
End: t,
})
}
return pb.VADResponse{
Segments: vadSegments,
}, nil
}
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}
data := buf.AsFloat32Buffer().Data
segsLen := uintptr(0xdeadbeef)
segsLenPtr := unsafe.Pointer(&segsLen)
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
}
segments := []*pb.TranscriptSegment{}
text := ""
for i := range int(segsLen) {
s := CppGetSegmentStart(i)
t := CppGetSegmentEnd(i)
txt := strings.Clone(CppGetSegmentText(i))
tokens := make([]int32, CppNTokens(i))
if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
txt += " [SPEAKER_TURN]"
}
for j := range tokens {
tokens[j] = int32(CppGetTokenID(i, j))
}
segment := &pb.TranscriptSegment{
Id: int32(i),
Text: txt,
Start: s, End: t,
Tokens: tokens,
}
segments = append(segments, segment)
text += " " + strings.TrimSpace(txt)
}
return pb.TranscriptResult{
Segments: segments,
Text: strings.TrimSpace(text),
}, nil
}

View File

@@ -1,17 +0,0 @@
#include <cstddef>
#include <cstdint>
extern "C" {
int load_model(const char *const model_path);
int load_model_vad(const char *const model_path);
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
size_t *segs_out_len);
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
const char *get_segment_text(int i);
int64_t get_segment_t0(int i);
int64_t get_segment_t1(int i);
int n_tokens(int i);
int32_t get_token_id(int i, int j);
bool get_segment_speaker_turn_next(int i);
}

View File

@@ -1,10 +1,10 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
@@ -12,34 +12,7 @@ var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
func main() {
gosd, err := purego.Dlopen("./libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
panic(err)
}
libFuncs := []LibFuncs{
{&CppLoadModel, "load_model"},
{&CppLoadModelVAD, "load_model_vad"},
{&CppVAD, "vad"},
{&CppTranscribe, "transcribe"},
{&CppGetSegmentText, "get_segment_text"},
{&CppGetSegmentStart, "get_segment_t0"},
{&CppGetSegmentEnd, "get_segment_t1"},
{&CppNTokens, "n_tokens"},
{&CppGetTokenID, "get_token_id"},
{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
}
flag.Parse()
if err := grpc.StartServer(*addr, &Whisper{}); err != nil {

View File

@@ -10,8 +10,8 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib
cp -avf $CURDIR/whisper $CURDIR/libgowhisper.so $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/
cp -avrf $CURDIR/whisper $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -42,13 +42,11 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
echo "Detected Darwin"
else
echo "Error: Could not detect architecture"
exit 1
fi
echo "Packaging completed successfully"
echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
ls -liah $CURDIR/package/lib/

View File

@@ -0,0 +1,105 @@
package main
// This is a wrapper to satisfy the gRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"os"
"path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
type Whisper struct {
base.SingleThread
whisper whisper.Model
}
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.ModelFile)
sd.whisper = w
return err
}
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}
data := buf.AsFloat32Buffer().Data
// Process samples
context, err := sd.whisper.NewContext()
if err != nil {
return pb.TranscriptResult{}, err
}
context.SetThreads(uint(opts.Threads))
if opts.Language != "" {
context.SetLanguage(opts.Language)
} else {
context.SetLanguage("auto")
}
if opts.Translate {
context.SetTranslate(true)
}
if err := context.Process(data, nil, nil, nil); err != nil {
return pb.TranscriptResult{}, err
}
segments := []*pb.TranscriptSegment{}
text := ""
for {
s, err := context.NextSegment()
if err != nil {
break
}
var tokens []int32
for _, t := range s.Tokens {
tokens = append(tokens, int32(t.Id))
}
segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
segments = append(segments, segment)
text += s.Text
}
return pb.TranscriptResult{
Segments: segments,
Text: text,
}, nil
}

View File

@@ -45,7 +45,6 @@
default: "cpu-whisper"
nvidia: "cuda12-whisper"
intel: "intel-sycl-f16-whisper"
metal: "metal-whisper"
amd: "rocm-whisper"
vulkan: "vulkan-whisper"
nvidia-l4t: "nvidia-l4t-arm64-whisper"
@@ -72,7 +71,7 @@
# amd: "rocm-stablediffusion-ggml"
vulkan: "vulkan-stablediffusion-ggml"
nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml"
metal: "metal-stablediffusion-ggml"
# metal: "metal-stablediffusion-ggml"
# darwin-x86: "darwin-x86-stablediffusion-ggml"
- &rfdetr
name: "rfdetr"
@@ -148,7 +147,7 @@
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm"
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
urls:
- https://github.com/Blaizzy/mlx-vlm
- https://github.com/ml-explore/mlx-vlm
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm
license: MIT
@@ -160,23 +159,6 @@
- vision-language
- LLM
- MLX
- &mlx-audio
name: "mlx-audio"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio"
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
urls:
- https://github.com/Blaizzy/mlx-audio
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-mlx-audio
license: MIT
description: |
Run Audio Models with MLX
tags:
- audio-to-text
- audio-generation
- text-to-audio
- LLM
- MLX
- &rerankers
name: "rerankers"
alias: "rerankers"
@@ -201,6 +183,8 @@
nvidia: "cuda12-transformers"
intel: "intel-transformers"
amd: "rocm-transformers"
metal: "metal-transformers"
default: "cpu-transformers"
- &diffusers
name: "diffusers"
icon: https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg
@@ -350,8 +334,6 @@
alias: "chatterbox"
capabilities:
nvidia: "cuda12-chatterbox"
metal: "metal-chatterbox"
default: "cpu-chatterbox"
- &piper
name: "piper"
uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
@@ -435,11 +417,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
- !!merge <<: *mlx-audio
name: "mlx-audio-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-audio
- !!merge <<: *kitten-tts
name: "kitten-tts-development"
uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
@@ -582,16 +559,6 @@
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisper"
mirrors:
- localai/localai-backends:latest-cpu-whisper
- !!merge <<: *whispercpp
name: "metal-whisper"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
name: "metal-whisper-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
name: "cpu-whisper-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisper"
@@ -678,16 +645,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-stablediffusion-ggml"
mirrors:
- localai/localai-backends:master-cpu-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "metal-stablediffusion-ggml"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-stablediffusion-ggml"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "metal-stablediffusion-ggml-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-stablediffusion-ggml"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "vulkan-stablediffusion-ggml"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-stablediffusion-ggml"
@@ -896,6 +853,28 @@
nvidia: "cuda12-transformers-development"
intel: "intel-transformers-development"
amd: "rocm-transformers-development"
default: "cpu-transformers-development"
metal: "metal-transformers-development"
- !!merge <<: *transformers
name: "cpu-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-transformers"
mirrors:
- localai/localai-backends:latest-cpu-transformers
- !!merge <<: *transformers
name: "cpu-transformers-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-transformers"
mirrors:
- localai/localai-backends:master-cpu-transformers
- !!merge <<: *transformers
name: "metal-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-transformers"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-transformers
- !!merge <<: *transformers
name: "metal-transformers-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-transformers"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-transformers
- !!merge <<: *transformers
name: "cuda12-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-transformers"
@@ -1225,28 +1204,6 @@
name: "chatterbox-development"
capabilities:
nvidia: "cuda12-chatterbox-development"
metal: "metal-chatterbox-development"
default: "cpu-chatterbox-development"
- !!merge <<: *chatterbox
name: "cpu-chatterbox"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
mirrors:
- localai/localai-backends:latest-cpu-chatterbox
- !!merge <<: *chatterbox
name: "cpu-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
mirrors:
- localai/localai-backends:master-cpu-chatterbox
- !!merge <<: *chatterbox
name: "metal-chatterbox"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
name: "metal-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
name: "cuda12-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"

View File

@@ -1,190 +1,38 @@
# Python Backends for LocalAI
# Common commands about conda environment
This directory contains Python-based AI backends for LocalAI, providing support for various AI models and hardware acceleration targets.
## Create a new empty conda environment
## Overview
```
conda create --name <env-name> python=<your version> -y
The Python backends use a unified build system based on `libbackend.sh` that provides:
- **Automatic virtual environment management** with support for both `uv` and `pip`
- **Hardware-specific dependency installation** (CPU, CUDA, Intel, MLX, etc.)
- **Portable Python support** for standalone deployments
- **Consistent backend execution** across different environments
## Available Backends
### Core AI Models
- **transformers** - Hugging Face Transformers framework (PyTorch-based)
- **vllm** - High-performance LLM inference engine
- **mlx** - Apple Silicon optimized ML framework
- **exllama2** - ExLlama2 quantized models
### Audio & Speech
- **bark** - Text-to-speech synthesis
- **coqui** - Coqui TTS models
- **faster-whisper** - Fast Whisper speech recognition
- **kitten-tts** - Lightweight TTS
- **mlx-audio** - Apple Silicon audio processing
- **chatterbox** - TTS model
- **kokoro** - TTS models
### Computer Vision
- **diffusers** - Stable Diffusion and image generation
- **mlx-vlm** - Vision-language models for Apple Silicon
- **rfdetr** - Object detection models
### Specialized
- **rerankers** - Text reranking models
## Quick Start
### Prerequisites
- Python 3.10+ (default: 3.10.18)
- `uv` package manager (recommended) or `pip`
- Appropriate hardware drivers for your target (CUDA, Intel, etc.)
### Installation
Each backend can be installed individually:
```bash
# Navigate to a specific backend
cd backend/python/transformers
# Install dependencies
make transformers
# or
bash install.sh
# Run the backend
make run
# or
bash run.sh
conda create --name autogptq python=3.11 -y
```
### Using the Unified Build System
## To activate the environment
The `libbackend.sh` script provides consistent commands across all backends:
```bash
# Source the library in your backend script
source $(dirname $0)/../common/libbackend.sh
# Install requirements (automatically handles hardware detection)
installRequirements
# Start the backend server
startBackend $@
# Run tests
runUnittests
As of conda 4.4
```
conda activate autogptq
```
## Hardware Targets
The conda version older than 4.4
The build system automatically detects the target hardware and configures the build accordingly:
- **CPU** - Standard CPU-only builds
- **CUDA** - NVIDIA GPU acceleration (supports CUDA 11/12)
- **Intel** - Intel XPU/GPU optimization
- **MLX** - Apple Silicon (M1/M2/M3) optimization
- **HIP** - AMD GPU acceleration
### Target-Specific Requirements
Backends can specify hardware-specific dependencies:
- `requirements.txt` - Base requirements
- `requirements-cpu.txt` - CPU-specific packages
- `requirements-cublas11.txt` - CUDA 11 packages
- `requirements-cublas12.txt` - CUDA 12 packages
- `requirements-intel.txt` - Intel-optimized packages
- `requirements-mps.txt` - Apple Silicon packages
## Configuration Options
### Environment Variables
- `PYTHON_VERSION` - Python version (default: 3.10)
- `PYTHON_PATCH` - Python patch version (default: 18)
- `BUILD_TYPE` - Force specific build target
- `USE_PIP` - Use pip instead of uv (default: false)
- `PORTABLE_PYTHON` - Enable portable Python builds
- `LIMIT_TARGETS` - Restrict backend to specific targets
### Example: CUDA 12 Only Backend
```bash
# In your backend script
LIMIT_TARGETS="cublas12"
source $(dirname $0)/../common/libbackend.sh
```
source activate autogptq
```
### Example: Intel-Optimized Backend
## Install the packages to your environment
```bash
# In your backend script
LIMIT_TARGETS="intel"
source $(dirname $0)/../common/libbackend.sh
Sometimes you need to install the packages from the conda-forge channel
By using `conda`
```
conda install <your-package-name>
conda install -c conda-forge <your package-name>
```
## Development
### Adding a New Backend
1. Create a new directory in `backend/python/`
2. Copy the template structure from `common/template/`
3. Implement your `backend.py` with the required gRPC interface (a minimal sketch follows this list)
4. Add appropriate requirements files for your target hardware
5. Use `libbackend.sh` for consistent build and execution
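For step 3, the sketch below shows roughly what a minimal `backend.py` looks like. It is illustrative only: it assumes the gRPC stubs (`backend_pb2`, `backend_pb2_grpc`) have already been generated via the backend's `protogen` target, and it implements only `Health` and `LoadModel`; real backends add the RPCs they actually serve (`Predict`, `TTS`, `AudioTranscription`, ...) depending on their capabilities.
```python
#!/usr/bin/env python3
# Minimal backend.py sketch (illustrative only): Health + LoadModel on top of
# the generated gRPC stubs. Real backends implement the additional RPCs they serve.
import argparse
import os
import sys
from concurrent import futures

import grpc
import backend_pb2        # generated gRPC stubs (assumed present via protogen)
import backend_pb2_grpc

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    def Health(self, request, context):
        # LocalAI polls this to check that the backend is alive
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        try:
            # request.Model is the model name or path passed by LocalAI;
            # load your framework-specific model here.
            self.model_name = request.Model
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
        return backend_pb2.Result(success=True, message="Model loaded successfully")


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the backend gRPC server.")
    parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
    args = parser.parse_args()
    serve(args.addr)
```
With a servicer like this in place, the accompanying `run.sh` only needs to source `libbackend.sh` and call `startBackend $@`, as shown in the build-system example above.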
### Testing
```bash
# Run backend tests
make test
# or
bash test.sh
Or by using `pip`
```
### Building
```bash
# Install dependencies
make <backend-name>
# Clean build artifacts
make clean
pip install <your-package-name>
```
## Architecture
Each backend follows a consistent structure:
```
backend-name/
├── backend.py # Main backend implementation
├── requirements.txt # Base dependencies
├── requirements-*.txt # Hardware-specific dependencies
├── install.sh # Installation script
├── run.sh # Execution script
├── test.sh # Test script
├── Makefile # Build targets
└── test.py # Unit tests
```
## Troubleshooting
### Common Issues
1. **Missing dependencies**: Ensure all requirements files are properly configured
2. **Hardware detection**: Check that `BUILD_TYPE` matches your system
3. **Python version**: Verify Python 3.10+ is available
4. **Virtual environment**: Use `ensureVenv` to create/activate environments
## Contributing
When adding new backends or modifying existing ones:
1. Follow the established directory structure
2. Use `libbackend.sh` for consistent behavior
3. Include appropriate requirements files for all target hardware
4. Add comprehensive tests
5. Update this README if adding new backend types

View File

@@ -1,6 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cpu
accelerate
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts

View File

@@ -2,5 +2,5 @@
torch==2.6.0+cu118
torchaudio==2.6.0+cu118
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -1,5 +1,5 @@
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -2,5 +2,5 @@
torch==2.6.0+rocm6.1
torchaudio==2.6.0+rocm6.1
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -3,8 +3,9 @@ intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
torchaudio==2.3.1+cxx11.abi
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
setuptools
setuptools
accelerate

View File

@@ -286,8 +286,7 @@ _makeVenvPortable() {
function ensureVenv() {
local interpreter=""
if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -e "$(_portable_python)" ]; then
echo "Using portable Python"
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
ensurePortablePython
interpreter="$(_portable_python)"
else
@@ -385,11 +384,6 @@ function installRequirements() {
requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
fi
# This is needed to build wheels that e.g. depends on Python.h
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
fi
for reqFile in ${requirementFiles[@]}; do
if [ -f "${reqFile}" ]; then
echo "starting requirements install for ${reqFile}"

View File

@@ -18,7 +18,7 @@ import backend_pb2_grpc
import grpc
from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from diffusers.utils import load_image, export_to_video
@@ -72,6 +72,13 @@ def is_float(s):
except ValueError:
return False
def is_int(s):
try:
int(s)
return True
except ValueError:
return False
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
# Credits to https://github.com/neggles
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -177,10 +184,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
key, value = opt.split(":")
# if value is a number, convert it to the appropriate type
if is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif is_int(value):
value = int(value)
self.options[key] = value
# From options, extract if present "torch_dtype" and set it to the appropriate type
@@ -328,32 +334,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
torch_dtype=torch.bfloat16)
self.pipe.vae.to(torch.bfloat16)
self.pipe.text_encoder.to(torch.bfloat16)
elif request.PipelineType == "WanPipeline":
# WAN2.2 pipeline requires special VAE handling
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
torch_dtype=torch.float32
)
self.pipe = WanPipeline.from_pretrained(
request.Model,
vae=vae,
torch_dtype=torchType
)
self.txt2vid = True # WAN2.2 is a text-to-video pipeline
elif request.PipelineType == "WanImageToVideoPipeline":
# WAN2.2 image-to-video pipeline
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
torch_dtype=torch.float32
)
self.pipe = WanImageToVideoPipeline.from_pretrained(
request.Model,
vae=vae,
torch_dtype=torchType
)
self.img2vid = True # WAN2.2 image-to-video pipeline
if CLIPSKIP and request.CLIPSkip != 0:
self.clip_skip = request.CLIPSkip
@@ -495,24 +475,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"num_inference_steps": steps,
}
# Handle image source: prioritize RefImages over request.src
image_src = None
if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
# Use the first reference image if available
image_src = request.ref_images[0]
print(f"Using reference image: {image_src}", file=sys.stderr)
elif request.src != "":
# Fall back to request.src if no ref_images
image_src = request.src
print(f"Using source image: {image_src}", file=sys.stderr)
else:
print("No image source provided", file=sys.stderr)
if image_src and not self.controlnet and not self.img2vid:
image = Image.open(image_src)
if request.src != "" and not self.controlnet and not self.img2vid:
image = Image.open(request.src)
options["image"] = image
elif self.controlnet and image_src:
pose_image = load_image(image_src)
elif self.controlnet and request.src:
pose_image = load_image(request.src)
options["image"] = pose_image
if CLIPSKIP and self.clip_skip != 0:
@@ -554,11 +521,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if self.img2vid:
# Load the conditioning image
if image_src:
image = load_image(image_src)
else:
# Fallback to request.src for img2vid if no ref_images
image = load_image(request.src)
image = load_image(request.src)
image = image.resize((1024, 576))
generator = torch.manual_seed(request.seed)
@@ -595,96 +558,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Media generated", success=True)
def GenerateVideo(self, request, context):
try:
prompt = request.prompt
if not prompt:
return backend_pb2.Result(success=False, message="No prompt provided for video generation")
# Set default values from request or use defaults
num_frames = request.num_frames if request.num_frames > 0 else 81
fps = request.fps if request.fps > 0 else 16
cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
num_inference_steps = request.step if request.step > 0 else 40
# Prepare generation parameters
kwargs = {
"prompt": prompt,
"negative_prompt": request.negative_prompt if request.negative_prompt else "",
"height": request.height if request.height > 0 else 720,
"width": request.width if request.width > 0 else 1280,
"num_frames": num_frames,
"guidance_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
}
# Add custom options from self.options (including guidance_scale_2 if specified)
kwargs.update(self.options)
# Set seed if provided
if request.seed > 0:
kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
# Handle start and end images for video generation
if request.start_image:
kwargs["start_image"] = load_image(request.start_image)
if request.end_image:
kwargs["end_image"] = load_image(request.end_image)
print(f"Generating video with {kwargs=}", file=sys.stderr)
# Generate video frames based on pipeline type
if self.PipelineType == "WanPipeline":
# WAN2.2 text-to-video generation
output = self.pipe(**kwargs)
frames = output.frames[0] # WAN2.2 returns frames in this format
elif self.PipelineType == "WanImageToVideoPipeline":
# WAN2.2 image-to-video generation
if request.start_image:
# Load and resize the input image according to WAN2.2 requirements
image = load_image(request.start_image)
# Use request dimensions or defaults, but respect WAN2.2 constraints
request_height = request.height if request.height > 0 else 480
request_width = request.width if request.width > 0 else 832
max_area = request_height * request_width
aspect_ratio = image.height / image.width
mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
image = image.resize((width, height))
kwargs["image"] = image
kwargs["height"] = height
kwargs["width"] = width
output = self.pipe(**kwargs)
frames = output.frames[0]
elif self.img2vid:
# Generic image-to-video generation
if request.start_image:
image = load_image(request.start_image)
image = image.resize((request.width if request.width > 0 else 1024,
request.height if request.height > 0 else 576))
kwargs["image"] = image
output = self.pipe(**kwargs)
frames = output.frames[0]
elif self.txt2vid:
# Generic text-to-video generation
output = self.pipe(**kwargs)
frames = output.frames[0]
else:
return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
# Export video
export_to_video(frames, request.dst, fps=fps)
return backend_pb2.Result(message="Video generated successfully", success=True)
except Exception as err:
print(f"Error generating video: {err}", file=sys.stderr)
traceback.print_exc()
return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),

View File

@@ -8,5 +8,4 @@ compel
peft
sentencepiece
torch==2.7.1
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,12 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.7.1+cu118
torchvision==0.22.1+cu118
git+https://github.com/huggingface/diffusers
opencv-python
transformers
torchvision==0.22.1
accelerate
compel
peft
sentencepiece
torch==2.7.1
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,12 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.7.1
torchvision==0.22.1
git+https://github.com/huggingface/diffusers
opencv-python
transformers
torchvision
accelerate
compel
peft
sentencepiece
torch
ftfy
optimum-quanto
optimum-quanto

View File

@@ -8,5 +8,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -12,5 +12,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -8,5 +8,4 @@ peft
optimum-quanto
numpy<2
sentencepiece
torchvision
ftfy
torchvision

View File

@@ -7,5 +7,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,23 +0,0 @@
.PHONY: mlx-audio
mlx-audio:
bash install.sh
.PHONY: run
run: mlx-audio
@echo "Running mlx-audio..."
bash run.sh
@echo "mlx run."
.PHONY: test
test: mlx-audio
@echo "Testing mlx-audio..."
bash test.sh
@echo "mlx tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

View File

@@ -1,459 +0,0 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
import shutil
import glob
from typing import List
import time
import tempfile
import backend_pb2
import backend_pb2_grpc
import grpc
from mlx_audio.tts.utils import load_model
import soundfile as sf
import numpy as np
import uuid
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
A gRPC servicer that implements the Backend service defined in backend.proto.
This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
"""
def _is_float(self, s):
"""Check if a string can be converted to float."""
try:
float(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
Args:
request: The health check request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The health check reply.
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
async def LoadModel(self, request, context):
"""
Loads a TTS model using MLX-Audio.
Args:
request: The load model request.
context: The gRPC context.
Returns:
backend_pb2.Result: The load model result.
"""
try:
print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
# Parse options like in the kokoro backend
options = request.Options
self.options = {}
# The options are a list of strings in this form optname:optvalue
# We store all the options in a dict for later use
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
self.options[key] = value
print(f"Options: {self.options}", file=sys.stderr)
# Load the model using MLX-Audio's load_model function
try:
self.tts_model = load_model(request.Model)
self.model_path = request.Model
print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
except Exception as model_err:
print(f"Error loading TTS model: {model_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")
except Exception as err:
print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")
print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)
def TTS(self, request, context):
"""
Generates TTS audio from text using MLX-Audio.
Args:
request: A TTSRequest object containing text, model, destination, voice, and language.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object indicating success or failure.
"""
try:
# Check if model is loaded
if not hasattr(self, 'tts_model') or self.tts_model is None:
return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")
print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)
# Handle speed parameter based on model type
speed_value = self._handle_speed_parameter(request, self.model_path)
# Map language names to codes if needed
lang_code = self._map_language_code(request.language, request.voice)
# Prepare generation parameters
gen_params = {
"text": request.text,
"speed": speed_value,
"verbose": False,
}
# Add model-specific parameters
if request.voice and request.voice.strip():
gen_params["voice"] = request.voice
# Check if model supports language codes (primarily Kokoro)
if "kokoro" in self.model_path.lower():
gen_params["lang_code"] = lang_code
# Add pitch and gender for Spark models
if "spark" in self.model_path.lower():
gen_params["pitch"] = 1.0 # Default to moderate
gen_params["gender"] = "female" # Default to female
print(f"Generation parameters: {gen_params}", file=sys.stderr)
# Generate audio using the loaded model
try:
results = self.tts_model.generate(**gen_params)
except Exception as gen_err:
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")
# Process the generated audio segments
audio_arrays = []
for segment in results:
audio_arrays.append(segment.audio)
# If no segments, return error
if not audio_arrays:
print("No audio segments generated", file=sys.stderr)
return backend_pb2.Result(success=False, message="No audio generated")
# Concatenate all segments
cat_audio = np.concatenate(audio_arrays, axis=0)
# Generate output filename and path
if request.dst:
output_path = request.dst
else:
unique_id = str(uuid.uuid4())
filename = f"tts_{unique_id}.wav"
output_path = filename
# Write the audio as a WAV
try:
sf.write(output_path, cat_audio, 24000)
print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)
# Verify the file exists and has content
if not os.path.exists(output_path):
print(f"File was not created at {output_path}", file=sys.stderr)
return backend_pb2.Result(success=False, message="Failed to create audio file")
file_size = os.path.getsize(output_path)
if file_size == 0:
print("File was created but is empty", file=sys.stderr)
return backend_pb2.Result(success=False, message="Generated audio file is empty")
print(f"Audio file size: {file_size} bytes", file=sys.stderr)
except Exception as write_err:
print(f"Error writing audio file: {write_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")
return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")
except Exception as e:
print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")
async def Predict(self, request, context):
"""
Generates TTS audio based on the given prompt using MLX-Audio TTS.
This is a fallback method for compatibility with the Predict endpoint.
Args:
request: The predict request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The predict result.
"""
try:
# Check if model is loaded
if not hasattr(self, 'tts_model') or self.tts_model is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("TTS model not loaded. Please call LoadModel first.")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# For TTS, we expect the prompt to contain the text to synthesize
if not request.Prompt:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details("Prompt is required for TTS generation")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# Handle speed parameter based on model type
speed_value = self._handle_speed_parameter(request, self.model_path)
# Map language names to codes if needed
lang_code = self._map_language_code(None, None) # Use defaults for Predict
# Prepare generation parameters
gen_params = {
"text": request.Prompt,
"speed": speed_value,
"verbose": False,
}
# Add model-specific parameters
if hasattr(self, 'options') and 'voice' in self.options:
gen_params["voice"] = self.options['voice']
# Check if model supports language codes (primarily Kokoro)
if "kokoro" in self.model_path.lower():
gen_params["lang_code"] = lang_code
print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)
# Generate audio using the loaded model
try:
results = self.tts_model.generate(**gen_params)
except Exception as gen_err:
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"TTS generation failed: {gen_err}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# Process the generated audio segments
audio_arrays = []
for segment in results:
audio_arrays.append(segment.audio)
# If no segments, return error
if not audio_arrays:
print("No audio segments generated", file=sys.stderr)
return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))
# Concatenate all segments
cat_audio = np.concatenate(audio_arrays, axis=0)
duration = len(cat_audio) / 24000 # Assuming 24kHz sample rate
# Return success message with audio information
response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
except Exception as e:
print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"TTS generation failed: {str(e)}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
def _handle_speed_parameter(self, request, model_path):
"""
Handle speed parameter based on model type.
Args:
request: The TTSRequest object.
model_path: The model path to determine model type.
Returns:
float: The processed speed value.
"""
# Get speed from options if available
speed = 1.0
if hasattr(self, 'options') and 'speed' in self.options:
speed = self.options['speed']
# Handle speed parameter based on model type
if "spark" in model_path.lower():
# Spark actually expects float values that map to speed descriptions
speed_map = {
"very_low": 0.0,
"low": 0.5,
"moderate": 1.0,
"high": 1.5,
"very_high": 2.0,
}
if isinstance(speed, str) and speed in speed_map:
speed_value = speed_map[speed]
else:
# Try to use as float, default to 1.0 (moderate) if invalid
try:
speed_value = float(speed)
if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
speed_value = 1.0 # Default to moderate
except:
speed_value = 1.0 # Default to moderate
else:
# Other models use float speed values
try:
speed_value = float(speed)
if speed_value < 0.5 or speed_value > 2.0:
speed_value = 1.0 # Default to 1.0 if out of range
except ValueError:
speed_value = 1.0 # Default to 1.0 if invalid
return speed_value
def _map_language_code(self, language, voice):
"""
Map language names to codes if needed.
Args:
language: The language parameter from the request.
voice: The voice parameter from the request.
Returns:
str: The language code.
"""
if not language:
# Default to voice[0] if not found
return voice[0] if voice else "a"
# Map language names to codes if needed
language_map = {
"american_english": "a",
"british_english": "b",
"spanish": "e",
"french": "f",
"hindi": "h",
"italian": "i",
"portuguese": "p",
"japanese": "j",
"mandarin_chinese": "z",
# Also accept direct language codes
"a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
}
return language_map.get(language.lower(), language)
def _build_generation_params(self, request, default_speed=1.0):
"""
Build generation parameters from request attributes and options for MLX-Audio TTS.
Args:
request: The gRPC request.
default_speed: Default speed if not specified.
Returns:
dict: Generation parameters for MLX-Audio
"""
# Initialize generation parameters for MLX-Audio TTS
generation_params = {
'speed': default_speed,
'voice': 'af_heart', # Default voice
'lang_code': 'a', # Default language code
}
# Extract parameters from request attributes
if hasattr(request, 'Temperature') and request.Temperature > 0:
# Temperature could be mapped to speed variation
generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5
# Override with options if available
if hasattr(self, 'options'):
# Speed from options
if 'speed' in self.options:
generation_params['speed'] = self.options['speed']
# Voice from options
if 'voice' in self.options:
generation_params['voice'] = self.options['voice']
# Language code from options
if 'lang_code' in self.options:
generation_params['lang_code'] = self.options['lang_code']
# Model-specific parameters
param_option_mapping = {
'temp': 'speed',
'temperature': 'speed',
'top_p': 'speed', # Map top_p to speed variation
}
for option_key, param_key in param_option_mapping.items():
if option_key in self.options:
if param_key == 'speed':
# Ensure speed is within reasonable bounds
speed_val = float(self.options[option_key])
if 0.5 <= speed_val <= 2.0:
generation_params[param_key] = speed_val
return generation_params
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address
server.add_insecure_port(address)
# Gracefully shutdown the server on SIGTERM or SIGINT
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.ensure_future(server.stop(5))
)
# Start the server
await server.start()
print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
# Wait for the server to be terminated
await server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
asyncio.run(serve(args.addr))

View File

@@ -1,14 +0,0 @@
#!/bin/bash
set -e
USE_PIP=true
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
installRequirements

View File

@@ -1 +0,0 @@
git+https://github.com/Blaizzy/mlx-audio

View File

@@ -1,7 +0,0 @@
grpcio==1.71.0
protobuf
certifi
setuptools
mlx-audio
soundfile
numpy

View File

@@ -1,11 +0,0 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
startBackend $@

View File

@@ -1,142 +0,0 @@
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc
import grpc
import unittest
import subprocess
import time
import grpc
import backend_pb2_grpc
import backend_pb2
class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service.
This class contains methods to test the startup and shutdown of the gRPC service.
"""
def setUp(self):
self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
time.sleep(10)
def tearDown(self) -> None:
self.service.terminate()
self.service.wait()
def test_server_startup(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_load_model(self):
"""
This method tests if the TTS model is loaded successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully")
except Exception as err:
print(err)
self.fail("LoadModel service failed")
finally:
self.tearDown()
def test_tts_generation(self):
"""
This method tests if TTS audio is generated successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
# Test TTS generation
tts_req = backend_pb2.TTSRequest(
text="Hello, this is a test of the MLX-Audio TTS system.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_heart",
language="a"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("TTS service failed")
finally:
self.tearDown()
def test_tts_with_options(self):
"""
This method tests if TTS works with various options and parameters
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(
Model="mlx-community/Kokoro-82M-4bit",
Options=["voice:af_soft", "speed:1.2", "lang_code:b"]
))
self.assertTrue(response.success)
# Test TTS generation with different voice and language
tts_req = backend_pb2.TTSRequest(
text="Hello, this is a test with British English accent.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_soft",
language="b"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("TTS with options service failed")
finally:
self.tearDown()
def test_tts_multilingual(self):
"""
This method tests if TTS works with different languages
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
# Test Spanish TTS
tts_req = backend_pb2.TTSRequest(
text="Hola, esto es una prueba del sistema TTS MLX-Audio.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_heart",
language="e"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("Multilingual TTS service failed")
finally:
self.tearDown()

View File

@@ -1,12 +0,0 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
runUnittests

View File

@@ -40,6 +40,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except ValueError:
return False
def _is_int(self, s):
"""Check if a string can be converted to int."""
try:
int(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
@@ -81,10 +89,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif self._is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"

View File

@@ -38,6 +38,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except ValueError:
return False
def _is_int(self, s):
"""Check if a string can be converted to int."""
try:
int(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
@@ -79,10 +87,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif self._is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
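The same change appears in both hunks above: previously any numeric option string went through the is_integer() shortcut, so a value like "1.0" was silently collapsed to the int 1; with the split checks, float-looking values stay floats and a separate _is_int branch is added. Below is a minimal sketch of the resulting key:value parsing (the "key:value" format is the one used by the tests, e.g. "voice:af_soft", "speed:1.2"; parse_option is a hypothetical wrapper, the real logic lives inline in LoadModel).

def _is_float(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def _is_int(s):
    try:
        int(s)
        return True
    except ValueError:
        return False

def parse_option(opt):
    # Hypothetical helper mirroring the converted-value logic in the hunks above.
    key, value = opt.split(":", 1)
    if _is_float(value):
        # "1.2" -> 1.2 and, because this check runs first, "512" -> 512.0 as well;
        # the _is_int branch only catches strings that float() rejects.
        value = float(value)
    elif _is_int(value):
        value = int(value)
    elif value.lower() in ["true", "false"]:
        value = value.lower() == "true"
    return key, value

assert parse_option("speed:1.2") == ("speed", 1.2)
assert parse_option("voice:af_soft") == ("voice", "af_soft")
assert parse_option("trust_remote_code:true") == ("trust_remote_code", True)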

View File

@@ -1,3 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.7.1
llvmlite==0.43.0
numba==0.60.0

View File

@@ -0,0 +1,9 @@
torch==2.7.1
accelerate
llvmlite==0.43.0
numba==0.60.0
transformers
bitsandbytes
outetts
sentence-transformers==5.1.0
protobuf==6.32.0

View File

@@ -1,16 +0,0 @@
package main
import (
_ "embed"
"fyne.io/fyne/v2"
)
//go:embed logo.png
var logoData []byte
// resourceIconPng is the LocalAI logo icon
var resourceIconPng = &fyne.StaticResource{
StaticName: "logo.png",
StaticContent: logoData,
}

View File

@@ -1,866 +0,0 @@
package launcher
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
)
// Config represents the launcher configuration
type Config struct {
ModelsPath string `json:"models_path"`
BackendsPath string `json:"backends_path"`
Address string `json:"address"`
AutoStart bool `json:"auto_start"`
StartOnBoot bool `json:"start_on_boot"`
LogLevel string `json:"log_level"`
EnvironmentVars map[string]string `json:"environment_vars"`
ShowWelcome *bool `json:"show_welcome"`
}
// Launcher represents the main launcher application
type Launcher struct {
// Core components
releaseManager *ReleaseManager
config *Config
ui *LauncherUI
systray *SystrayManager
ctx context.Context
window fyne.Window
app fyne.App
// Process management
localaiCmd *exec.Cmd
isRunning bool
logBuffer *strings.Builder
logMutex sync.RWMutex
statusChannel chan string
// Logging
logFile *os.File
logPath string
// UI state
lastUpdateCheck time.Time
}
// NewLauncher creates a new launcher instance
func NewLauncher(ui *LauncherUI, window fyne.Window, app fyne.App) *Launcher {
return &Launcher{
releaseManager: NewReleaseManager(),
config: &Config{},
logBuffer: &strings.Builder{},
statusChannel: make(chan string, 100),
ctx: context.Background(),
ui: ui,
window: window,
app: app,
}
}
// setupLogging sets up log file for LocalAI process output
func (l *Launcher) setupLogging() error {
// Create logs directory in data folder
dataPath := l.GetDataPath()
logsDir := filepath.Join(dataPath, "logs")
if err := os.MkdirAll(logsDir, 0755); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
// Create log file with timestamp
timestamp := time.Now().Format("2006-01-02_15-04-05")
l.logPath = filepath.Join(logsDir, fmt.Sprintf("localai_%s.log", timestamp))
logFile, err := os.Create(l.logPath)
if err != nil {
return fmt.Errorf("failed to create log file: %w", err)
}
l.logFile = logFile
return nil
}
// Initialize sets up the launcher
func (l *Launcher) Initialize() error {
if l.app == nil {
return fmt.Errorf("app is nil")
}
log.Printf("Initializing launcher...")
// Setup logging
if err := l.setupLogging(); err != nil {
return fmt.Errorf("failed to setup logging: %w", err)
}
// Load configuration
log.Printf("Loading configuration...")
if err := l.loadConfig(); err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
log.Printf("Configuration loaded, current state: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
// Clean up any partial downloads
log.Printf("Cleaning up partial downloads...")
if err := l.releaseManager.CleanupPartialDownloads(); err != nil {
log.Printf("Warning: failed to cleanup partial downloads: %v", err)
}
if l.config.StartOnBoot {
l.StartLocalAI()
}
// Set default paths if not configured (only if not already loaded from config)
if l.config.ModelsPath == "" {
homeDir, _ := os.UserHomeDir()
l.config.ModelsPath = filepath.Join(homeDir, ".localai", "models")
log.Printf("Setting default ModelsPath: %s", l.config.ModelsPath)
}
if l.config.BackendsPath == "" {
homeDir, _ := os.UserHomeDir()
l.config.BackendsPath = filepath.Join(homeDir, ".localai", "backends")
log.Printf("Setting default BackendsPath: %s", l.config.BackendsPath)
}
if l.config.Address == "" {
l.config.Address = "127.0.0.1:8080"
log.Printf("Setting default Address: %s", l.config.Address)
}
if l.config.LogLevel == "" {
l.config.LogLevel = "info"
log.Printf("Setting default LogLevel: %s", l.config.LogLevel)
}
if l.config.EnvironmentVars == nil {
l.config.EnvironmentVars = make(map[string]string)
log.Printf("Initializing empty EnvironmentVars map")
}
// Set default welcome window preference
if l.config.ShowWelcome == nil {
true := true
l.config.ShowWelcome = &true
log.Printf("Setting default ShowWelcome: true")
}
// Create directories
os.MkdirAll(l.config.ModelsPath, 0755)
os.MkdirAll(l.config.BackendsPath, 0755)
// Save the configuration with default values
if err := l.saveConfig(); err != nil {
log.Printf("Warning: failed to save default configuration: %v", err)
}
// System tray is now handled in main.go using Fyne's built-in approach
// Check if LocalAI is installed
if !l.releaseManager.IsLocalAIInstalled() {
log.Printf("No LocalAI installation found")
fyne.Do(func() {
l.updateStatus("No LocalAI installation found")
if l.ui != nil {
// Show dialog offering to download LocalAI
l.showDownloadLocalAIDialog()
}
})
}
// Check for updates periodically
go l.periodicUpdateCheck()
return nil
}
// StartLocalAI starts the LocalAI server
func (l *Launcher) StartLocalAI() error {
if l.isRunning {
return fmt.Errorf("LocalAI is already running")
}
// Verify binary integrity before starting
if err := l.releaseManager.VerifyInstalledBinary(); err != nil {
// Binary is corrupted, remove it and offer to reinstall
binaryPath := l.releaseManager.GetBinaryPath()
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
return fmt.Errorf("LocalAI binary is corrupted: %v. Please reinstall LocalAI", err)
}
binaryPath := l.releaseManager.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return fmt.Errorf("LocalAI binary not found. Please download a release first")
}
// Build command arguments
args := []string{
"run",
"--models-path", l.config.ModelsPath,
"--backends-path", l.config.BackendsPath,
"--address", l.config.Address,
"--log-level", l.config.LogLevel,
}
l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...)
// Apply environment variables
if len(l.config.EnvironmentVars) > 0 {
env := os.Environ()
for key, value := range l.config.EnvironmentVars {
env = append(env, fmt.Sprintf("%s=%s", key, value))
}
l.localaiCmd.Env = env
}
// Setup logging
stdout, err := l.localaiCmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderr, err := l.localaiCmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
// Start the process
if err := l.localaiCmd.Start(); err != nil {
return fmt.Errorf("failed to start LocalAI: %w", err)
}
l.isRunning = true
fyne.Do(func() {
l.updateStatus("LocalAI is starting...")
l.updateRunningState(true)
})
// Start log monitoring
go l.monitorLogs(stdout, "STDOUT")
go l.monitorLogs(stderr, "STDERR")
// Monitor process with startup timeout
go func() {
// Wait for process to start or fail
err := l.localaiCmd.Wait()
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
if err != nil {
l.updateStatus(fmt.Sprintf("LocalAI stopped with error: %v", err))
} else {
l.updateStatus("LocalAI stopped")
}
})
}()
// Add startup timeout detection
go func() {
time.Sleep(10 * time.Second) // Wait 10 seconds for startup
if l.isRunning {
// Check if process is still alive
if l.localaiCmd.Process != nil {
if err := l.localaiCmd.Process.Signal(syscall.Signal(0)); err != nil {
// Process is dead, mark as not running
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
l.updateStatus("LocalAI failed to start properly")
})
}
}
}
}()
return nil
}
// StopLocalAI stops the LocalAI server
func (l *Launcher) StopLocalAI() error {
if !l.isRunning || l.localaiCmd == nil {
return fmt.Errorf("LocalAI is not running")
}
// Gracefully terminate the process
if err := l.localaiCmd.Process.Signal(os.Interrupt); err != nil {
// If graceful termination fails, force kill
if killErr := l.localaiCmd.Process.Kill(); killErr != nil {
return fmt.Errorf("failed to kill LocalAI process: %w", killErr)
}
}
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
l.updateStatus("LocalAI stopped")
})
return nil
}
// IsRunning returns whether LocalAI is currently running
func (l *Launcher) IsRunning() bool {
return l.isRunning
}
// Shutdown performs cleanup when the application is closing
func (l *Launcher) Shutdown() error {
log.Printf("Launcher shutting down, stopping LocalAI...")
// Stop LocalAI if it's running
if l.isRunning {
if err := l.StopLocalAI(); err != nil {
log.Printf("Error stopping LocalAI during shutdown: %v", err)
}
}
// Close log file if open
if l.logFile != nil {
if err := l.logFile.Close(); err != nil {
log.Printf("Error closing log file: %v", err)
}
l.logFile = nil
}
log.Printf("Launcher shutdown complete")
return nil
}
// GetLogs returns the current log buffer
func (l *Launcher) GetLogs() string {
l.logMutex.RLock()
defer l.logMutex.RUnlock()
return l.logBuffer.String()
}
// GetRecentLogs returns the most recent logs (last 50 lines) for better error display
func (l *Launcher) GetRecentLogs() string {
l.logMutex.RLock()
defer l.logMutex.RUnlock()
content := l.logBuffer.String()
lines := strings.Split(content, "\n")
// Get last 50 lines
if len(lines) > 50 {
lines = lines[len(lines)-50:]
}
return strings.Join(lines, "\n")
}
// GetConfig returns the current configuration
func (l *Launcher) GetConfig() *Config {
return l.config
}
// SetConfig updates the configuration
func (l *Launcher) SetConfig(config *Config) error {
l.config = config
return l.saveConfig()
}
func (l *Launcher) GetUI() *LauncherUI {
return l.ui
}
func (l *Launcher) SetSystray(systray *SystrayManager) {
l.systray = systray
}
// GetReleaseManager returns the release manager
func (l *Launcher) GetReleaseManager() *ReleaseManager {
return l.releaseManager
}
// GetWebUIURL returns the URL for the WebUI
func (l *Launcher) GetWebUIURL() string {
address := l.config.Address
if strings.HasPrefix(address, ":") {
address = "localhost" + address
}
if !strings.HasPrefix(address, "http") {
address = "http://" + address
}
return address
}
// GetDataPath returns the path where LocalAI data and logs are stored
func (l *Launcher) GetDataPath() string {
// LocalAI typically stores data in the current working directory or a models directory
// First check if models path is configured
if l.config != nil && l.config.ModelsPath != "" {
// Return the parent directory of models path
return filepath.Dir(l.config.ModelsPath)
}
// Fallback to home directory LocalAI folder
homeDir, err := os.UserHomeDir()
if err != nil {
return "."
}
return filepath.Join(homeDir, ".localai")
}
// CheckForUpdates checks if there are any available updates
func (l *Launcher) CheckForUpdates() (bool, string, error) {
log.Printf("CheckForUpdates: checking for available updates...")
available, version, err := l.releaseManager.IsUpdateAvailable()
if err != nil {
log.Printf("CheckForUpdates: error occurred: %v", err)
return false, "", err
}
log.Printf("CheckForUpdates: result - available=%v, version=%s", available, version)
l.lastUpdateCheck = time.Now()
return available, version, nil
}
// DownloadUpdate downloads the latest version
func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
return l.releaseManager.DownloadRelease(version, progressCallback)
}
// GetCurrentVersion returns the current installed version
func (l *Launcher) GetCurrentVersion() string {
return l.releaseManager.GetInstalledVersion()
}
// GetCurrentStatus returns the current status
func (l *Launcher) GetCurrentStatus() string {
select {
case status := <-l.statusChannel:
return status
default:
if l.isRunning {
return "LocalAI is running"
}
return "Ready"
}
}
// GetLastStatus returns the last known status without consuming from channel
func (l *Launcher) GetLastStatus() string {
if l.isRunning {
return "LocalAI is running"
}
// Check if LocalAI is installed
if !l.releaseManager.IsLocalAIInstalled() {
return "LocalAI not installed"
}
return "Ready"
}
func (l *Launcher) githubReleaseNotesURL(version string) (*url.URL, error) {
// Construct GitHub release URL
releaseURL := fmt.Sprintf("https://github.com/%s/%s/releases/tag/%s",
l.releaseManager.GitHubOwner,
l.releaseManager.GitHubRepo,
version)
// Convert string to *url.URL
return url.Parse(releaseURL)
}
// showDownloadLocalAIDialog shows a dialog offering to download LocalAI
func (l *Launcher) showDownloadLocalAIDialog() {
if l.app == nil {
log.Printf("Cannot show download dialog: app is nil")
return
}
fyne.DoAndWait(func() {
// Create a standalone window for the download dialog
dialogWindow := l.app.NewWindow("LocalAI Installation Required")
dialogWindow.Resize(fyne.NewSize(500, 350))
dialogWindow.CenterOnScreen()
dialogWindow.SetCloseIntercept(func() {
dialogWindow.Close()
})
// Create the dialog content
titleLabel := widget.NewLabel("LocalAI Not Found")
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
messageLabel := widget.NewLabel("LocalAI is not installed on your system.\n\nWould you like to download and install the latest version?")
messageLabel.Wrapping = fyne.TextWrapWord
messageLabel.Alignment = fyne.TextAlignCenter
// Buttons
downloadButton := widget.NewButton("Download & Install", func() {
dialogWindow.Close()
l.downloadAndInstallLocalAI()
if l.systray != nil {
l.systray.recreateMenu()
}
})
downloadButton.Importance = widget.HighImportance
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
// Get latest release info and open release notes
go func() {
release, err := l.releaseManager.GetLatestRelease()
if err != nil {
log.Printf("Failed to get latest release info: %v", err)
return
}
releaseNotesURL, err := l.githubReleaseNotesURL(release.Version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
l.app.OpenURL(releaseNotesURL)
}()
})
skipButton := widget.NewButton("Skip for Now", func() {
dialogWindow.Close()
})
// Layout - put release notes button above the main action buttons
actionButtons := container.NewHBox(skipButton, downloadButton)
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
messageLabel,
widget.NewSeparator(),
releaseNotesButton,
widget.NewSeparator(),
actionButtons,
)
dialogWindow.SetContent(content)
dialogWindow.Show()
})
}
// downloadAndInstallLocalAI downloads and installs the latest LocalAI version
func (l *Launcher) downloadAndInstallLocalAI() {
if l.app == nil {
log.Printf("Cannot download LocalAI: app is nil")
return
}
// First check what the latest version is
go func() {
log.Printf("Checking for latest LocalAI version...")
available, version, err := l.CheckForUpdates()
if err != nil {
log.Printf("Failed to check for updates: %v", err)
l.showDownloadError("Failed to check for latest version", err.Error())
return
}
if !available {
log.Printf("No updates available, but LocalAI is not installed")
l.showDownloadError("No Version Available", "Could not determine the latest LocalAI version. Please check your internet connection and try again.")
return
}
log.Printf("Latest version available: %s", version)
// Show progress window with the specific version
l.showDownloadProgress(version, fmt.Sprintf("Downloading LocalAI %s...", version))
}()
}
// showDownloadError shows an error dialog for download failures
func (l *Launcher) showDownloadError(title, message string) {
fyne.DoAndWait(func() {
// Create error window
errorWindow := l.app.NewWindow("Download Error")
errorWindow.Resize(fyne.NewSize(400, 200))
errorWindow.CenterOnScreen()
errorWindow.SetCloseIntercept(func() {
errorWindow.Close()
})
// Error content
titleLabel := widget.NewLabel(title)
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
messageLabel := widget.NewLabel(message)
messageLabel.Wrapping = fyne.TextWrapWord
messageLabel.Alignment = fyne.TextAlignCenter
// Close button
closeButton := widget.NewButton("Close", func() {
errorWindow.Close()
})
// Layout
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
messageLabel,
widget.NewSeparator(),
closeButton,
)
errorWindow.SetContent(content)
errorWindow.Show()
})
}
// showDownloadProgress shows a standalone progress window for downloading LocalAI
func (l *Launcher) showDownloadProgress(version, title string) {
fyne.DoAndWait(func() {
// Create progress window
progressWindow := l.app.NewWindow("Downloading LocalAI")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
progressWindow.SetCloseIntercept(func() {
progressWindow.Close()
})
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := l.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
l.app.OpenURL(releaseNotesURL)
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(title),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := l.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show success dialog
dialog.ShowConfirm("Installation Complete",
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
func(close bool) {
progressWindow.Close()
// Update status and refresh systray menu
l.updateStatus("LocalAI installed successfully")
if l.systray != nil {
l.systray.recreateMenu()
}
}, progressWindow)
}
})
}()
})
}
// monitorLogs monitors the output of LocalAI and adds it to the log buffer
func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
timestamp := time.Now().Format("15:04:05")
logLine := fmt.Sprintf("[%s] %s: %s\n", timestamp, prefix, line)
l.logMutex.Lock()
l.logBuffer.WriteString(logLine)
// Keep log buffer size reasonable
if l.logBuffer.Len() > 100000 { // 100KB
content := l.logBuffer.String()
// Keep last 50KB
if len(content) > 50000 {
l.logBuffer.Reset()
l.logBuffer.WriteString(content[len(content)-50000:])
}
}
l.logMutex.Unlock()
// Write to log file if available
if l.logFile != nil {
if _, err := l.logFile.WriteString(logLine); err != nil {
log.Printf("Failed to write to log file: %v", err)
}
}
fyne.Do(func() {
// Notify UI of new log content
if l.ui != nil {
l.ui.OnLogUpdate(logLine)
}
// Check for startup completion
if strings.Contains(line, "API server listening") {
l.updateStatus("LocalAI is running")
}
})
}
}
// updateStatus updates the status and notifies UI
func (l *Launcher) updateStatus(status string) {
select {
case l.statusChannel <- status:
default:
// Channel full, skip
}
if l.ui != nil {
l.ui.UpdateStatus(status)
}
if l.systray != nil {
l.systray.UpdateStatus(status)
}
}
// updateRunningState updates the running state in UI and systray
func (l *Launcher) updateRunningState(isRunning bool) {
if l.ui != nil {
l.ui.UpdateRunningState(isRunning)
}
if l.systray != nil {
l.systray.UpdateRunningState(isRunning)
}
}
// periodicUpdateCheck checks for updates periodically
func (l *Launcher) periodicUpdateCheck() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
available, version, err := l.CheckForUpdates()
if err == nil && available {
fyne.Do(func() {
l.updateStatus(fmt.Sprintf("Update available: %s", version))
if l.systray != nil {
l.systray.NotifyUpdateAvailable(version)
}
if l.ui != nil {
l.ui.NotifyUpdateAvailable(version)
}
})
}
case <-l.ctx.Done():
return
}
}
}
// loadConfig loads configuration from file
func (l *Launcher) loadConfig() error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
}
configPath := filepath.Join(homeDir, ".localai", "launcher.json")
log.Printf("Loading config from: %s", configPath)
if _, err := os.Stat(configPath); os.IsNotExist(err) {
log.Printf("Config file not found, creating default config")
// Create default config
return l.saveConfig()
}
// Load existing config
configData, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
log.Printf("Config file content: %s", string(configData))
log.Printf("loadConfig: about to unmarshal JSON data")
if err := json.Unmarshal(configData, l.config); err != nil {
return fmt.Errorf("failed to parse config file: %w", err)
}
log.Printf("loadConfig: JSON unmarshaled successfully")
log.Printf("Loaded config: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
log.Printf("Environment vars: %v", l.config.EnvironmentVars)
return nil
}
// saveConfig saves configuration to file
func (l *Launcher) saveConfig() error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
}
configDir := filepath.Join(homeDir, ".localai")
if err := os.MkdirAll(configDir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Marshal config to JSON
log.Printf("saveConfig: marshaling config with EnvironmentVars: %v", l.config.EnvironmentVars)
configData, err := json.MarshalIndent(l.config, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
log.Printf("saveConfig: JSON marshaled successfully, length: %d", len(configData))
configPath := filepath.Join(configDir, "launcher.json")
log.Printf("Saving config to: %s", configPath)
log.Printf("Config content: %s", string(configData))
if err := os.WriteFile(configPath, configData, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
log.Printf("Config saved successfully")
return nil
}

View File

@@ -1,13 +0,0 @@
package launcher_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLauncher(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Launcher Suite")
}

View File

@@ -1,213 +0,0 @@
package launcher_test
import (
"os"
"path/filepath"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"fyne.io/fyne/v2/app"
launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
var _ = Describe("Launcher", func() {
var (
launcherInstance *launcher.Launcher
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "launcher-test-*")
Expect(err).ToNot(HaveOccurred())
ui := launcher.NewLauncherUI()
app := app.NewWithID("com.localai.launcher")
launcherInstance = launcher.NewLauncher(ui, nil, app)
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
Describe("NewLauncher", func() {
It("should create a launcher with default configuration", func() {
Expect(launcherInstance.GetConfig()).ToNot(BeNil())
})
})
Describe("Initialize", func() {
It("should set default paths when not configured", func() {
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
config := launcherInstance.GetConfig()
Expect(config.ModelsPath).ToNot(BeEmpty())
Expect(config.BackendsPath).ToNot(BeEmpty())
})
It("should set default ShowWelcome to true", func() {
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
config := launcherInstance.GetConfig()
Expect(config.ShowWelcome).To(BeTrue())
Expect(config.Address).To(Equal("127.0.0.1:8080"))
Expect(config.LogLevel).To(Equal("info"))
})
It("should create models and backends directories", func() {
// Set custom paths for testing
config := launcherInstance.GetConfig()
config.ModelsPath = filepath.Join(tempDir, "models")
config.BackendsPath = filepath.Join(tempDir, "backends")
launcherInstance.SetConfig(config)
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
// Check if directories were created
_, err = os.Stat(config.ModelsPath)
Expect(err).ToNot(HaveOccurred())
_, err = os.Stat(config.BackendsPath)
Expect(err).ToNot(HaveOccurred())
})
})
Describe("Configuration", func() {
It("should get and set configuration", func() {
config := launcherInstance.GetConfig()
config.ModelsPath = "/test/models"
config.BackendsPath = "/test/backends"
config.Address = ":9090"
config.LogLevel = "debug"
err := launcherInstance.SetConfig(config)
Expect(err).ToNot(HaveOccurred())
retrievedConfig := launcherInstance.GetConfig()
Expect(retrievedConfig.ModelsPath).To(Equal("/test/models"))
Expect(retrievedConfig.BackendsPath).To(Equal("/test/backends"))
Expect(retrievedConfig.Address).To(Equal(":9090"))
Expect(retrievedConfig.LogLevel).To(Equal("debug"))
})
})
Describe("WebUI URL", func() {
It("should return correct WebUI URL for localhost", func() {
config := launcherInstance.GetConfig()
config.Address = ":8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://localhost:8080"))
})
It("should return correct WebUI URL for full address", func() {
config := launcherInstance.GetConfig()
config.Address = "127.0.0.1:8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://127.0.0.1:8080"))
})
It("should handle http prefix correctly", func() {
config := launcherInstance.GetConfig()
config.Address = "http://localhost:8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://localhost:8080"))
})
})
Describe("Process Management", func() {
It("should not be running initially", func() {
Expect(launcherInstance.IsRunning()).To(BeFalse())
})
It("should handle start when binary doesn't exist", func() {
err := launcherInstance.StartLocalAI()
Expect(err).To(HaveOccurred())
// Could be either "not found" or "permission denied" depending on test environment
errMsg := err.Error()
hasExpectedError := strings.Contains(errMsg, "LocalAI binary") ||
strings.Contains(errMsg, "permission denied")
Expect(hasExpectedError).To(BeTrue(), "Expected error about binary not found or permission denied, got: %s", errMsg)
})
It("should handle stop when not running", func() {
err := launcherInstance.StopLocalAI()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LocalAI is not running"))
})
})
Describe("Logs", func() {
It("should return empty logs initially", func() {
logs := launcherInstance.GetLogs()
Expect(logs).To(BeEmpty())
})
})
Describe("Version Management", func() {
It("should return empty version when no binary installed", func() {
version := launcherInstance.GetCurrentVersion()
Expect(version).To(BeEmpty()) // No binary installed in test environment
})
It("should handle update checks", func() {
// This test would require mocking HTTP responses
// For now, we'll just test that the method doesn't panic
_, _, err := launcherInstance.CheckForUpdates()
// We expect either success or a network error, not a panic
if err != nil {
// Network error is acceptable in tests
Expect(err.Error()).To(ContainSubstring("failed to fetch"))
}
})
})
})
var _ = Describe("Config", func() {
It("should have proper JSON tags", func() {
config := &launcher.Config{
ModelsPath: "/test/models",
BackendsPath: "/test/backends",
Address: ":8080",
AutoStart: true,
LogLevel: "info",
EnvironmentVars: map[string]string{"TEST": "value"},
}
Expect(config.ModelsPath).To(Equal("/test/models"))
Expect(config.BackendsPath).To(Equal("/test/backends"))
Expect(config.Address).To(Equal(":8080"))
Expect(config.AutoStart).To(BeTrue())
Expect(config.LogLevel).To(Equal("info"))
Expect(config.EnvironmentVars).To(HaveKeyWithValue("TEST", "value"))
})
It("should initialize environment variables map", func() {
config := &launcher.Config{}
Expect(config.EnvironmentVars).To(BeNil())
ui := launcher.NewLauncherUI()
app := app.NewWithID("com.localai.launcher")
launcher := launcher.NewLauncher(ui, nil, app)
err := launcher.Initialize()
Expect(err).ToNot(HaveOccurred())
retrievedConfig := launcher.GetConfig()
Expect(retrievedConfig.EnvironmentVars).ToNot(BeNil())
Expect(retrievedConfig.EnvironmentVars).To(BeEmpty())
})
})

View File

@@ -1,502 +0,0 @@
package launcher
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/mudler/LocalAI/internal"
)
// Release represents a LocalAI release
type Release struct {
Version string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
PublishedAt time.Time `json:"published_at"`
Assets []Asset `json:"assets"`
}
// Asset represents a release asset
type Asset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
}
// ReleaseManager handles LocalAI release management
type ReleaseManager struct {
// GitHubOwner is the GitHub repository owner
GitHubOwner string
// GitHubRepo is the GitHub repository name
GitHubRepo string
// BinaryPath is where the LocalAI binary is stored locally
BinaryPath string
// CurrentVersion is the currently installed version
CurrentVersion string
// ChecksumsPath is where checksums are stored
ChecksumsPath string
// MetadataPath is where version metadata is stored
MetadataPath string
}
// NewReleaseManager creates a new release manager
func NewReleaseManager() *ReleaseManager {
homeDir, _ := os.UserHomeDir()
binaryPath := filepath.Join(homeDir, ".localai", "bin")
checksumsPath := filepath.Join(homeDir, ".localai", "checksums")
metadataPath := filepath.Join(homeDir, ".localai", "metadata")
return &ReleaseManager{
GitHubOwner: "mudler",
GitHubRepo: "LocalAI",
BinaryPath: binaryPath,
CurrentVersion: internal.PrintableVersion(),
ChecksumsPath: checksumsPath,
MetadataPath: metadataPath,
}
}
// GetLatestRelease fetches the latest release information from GitHub
func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to fetch latest release: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
}
// Parse the JSON response properly
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
release := &Release{}
if err := json.Unmarshal(body, release); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
// Validate the release data
if release.Version == "" {
return nil, fmt.Errorf("no version found in release data")
}
return release, nil
}
// DownloadRelease downloads a specific version of LocalAI
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
// Ensure the binary directory exists
if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
return fmt.Errorf("failed to create binary directory: %w", err)
}
// Determine the binary name based on OS and architecture
binaryName := rm.GetBinaryName(version)
localPath := filepath.Join(rm.BinaryPath, "local-ai")
// Download the binary
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
return fmt.Errorf("failed to download binary: %w", err)
}
// Download and verify checksums
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
rm.GitHubOwner, rm.GitHubRepo, version, version)
checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
if err := rm.downloadFile(checksumURL, checksumPath, nil); err != nil {
return fmt.Errorf("failed to download checksums: %w", err)
}
// Verify the checksum
if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
return fmt.Errorf("checksum verification failed: %w", err)
}
// Save checksums persistently for future verification
if err := rm.saveChecksums(version, checksumPath, binaryName); err != nil {
log.Printf("Warning: failed to save checksums: %v", err)
}
// Make the binary executable
if err := os.Chmod(localPath, 0755); err != nil {
return fmt.Errorf("failed to make binary executable: %w", err)
}
return nil
}
// GetBinaryName returns the appropriate binary name for the current platform
func (rm *ReleaseManager) GetBinaryName(version string) string {
versionStr := strings.TrimPrefix(version, "v")
os := runtime.GOOS
arch := runtime.GOARCH
// Map Go arch names to the release naming convention
switch arch {
case "amd64":
arch = "amd64"
case "arm64":
arch = "arm64"
default:
arch = "amd64" // fallback
}
return fmt.Sprintf("local-ai-v%s-%s-%s", versionStr, os, arch)
}
// downloadFile downloads a file from a URL to a local path with optional progress callback
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
}
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
// Create a progress reader if callback is provided
var reader io.Reader = resp.Body
if progressCallback != nil && resp.ContentLength > 0 {
reader = &progressReader{
Reader: resp.Body,
Total: resp.ContentLength,
Callback: progressCallback,
}
}
_, err = io.Copy(out, reader)
return err
}
// saveChecksums saves checksums persistently for future verification
func (rm *ReleaseManager) saveChecksums(version, checksumPath, binaryName string) error {
// Ensure checksums directory exists
if err := os.MkdirAll(rm.ChecksumsPath, 0755); err != nil {
return fmt.Errorf("failed to create checksums directory: %w", err)
}
// Read the downloaded checksums file
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return fmt.Errorf("failed to read checksums file: %w", err)
}
// Save to persistent location with version info
persistentPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
if err := os.WriteFile(persistentPath, checksumData, 0644); err != nil {
return fmt.Errorf("failed to write persistent checksums: %w", err)
}
// Also save a "latest" checksums file for the current version
latestPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
if err := os.WriteFile(latestPath, checksumData, 0644); err != nil {
return fmt.Errorf("failed to write latest checksums: %w", err)
}
// Save version metadata
if err := rm.saveVersionMetadata(version); err != nil {
log.Printf("Warning: failed to save version metadata: %v", err)
}
log.Printf("Checksums saved for version %s", version)
return nil
}
// saveVersionMetadata saves the installed version information
func (rm *ReleaseManager) saveVersionMetadata(version string) error {
// Ensure metadata directory exists
if err := os.MkdirAll(rm.MetadataPath, 0755); err != nil {
return fmt.Errorf("failed to create metadata directory: %w", err)
}
// Create metadata structure
metadata := struct {
Version string `json:"version"`
InstalledAt time.Time `json:"installed_at"`
BinaryPath string `json:"binary_path"`
}{
Version: version,
InstalledAt: time.Now(),
BinaryPath: rm.GetBinaryPath(),
}
// Marshal to JSON
metadataData, err := json.MarshalIndent(metadata, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
// Save metadata file
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
if err := os.WriteFile(metadataPath, metadataData, 0644); err != nil {
return fmt.Errorf("failed to write metadata file: %w", err)
}
log.Printf("Version metadata saved: %s", version)
return nil
}
// progressReader wraps an io.Reader to provide download progress
type progressReader struct {
io.Reader
Total int64
Current int64
Callback func(float64)
}
func (pr *progressReader) Read(p []byte) (int, error) {
n, err := pr.Reader.Read(p)
pr.Current += int64(n)
if pr.Callback != nil {
progress := float64(pr.Current) / float64(pr.Total)
pr.Callback(progress)
}
return n, err
}
// VerifyChecksum verifies the downloaded file against the provided checksums
func (rm *ReleaseManager) VerifyChecksum(filePath, checksumPath, binaryName string) error {
// Calculate the SHA256 of the downloaded file
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file for checksum: %w", err)
}
defer file.Close()
hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return fmt.Errorf("failed to calculate checksum: %w", err)
}
calculatedHash := hex.EncodeToString(hasher.Sum(nil))
// Read the checksums file
checksumFile, err := os.Open(checksumPath)
if err != nil {
return fmt.Errorf("failed to open checksums file: %w", err)
}
defer checksumFile.Close()
scanner := bufio.NewScanner(checksumFile)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, binaryName) {
parts := strings.Fields(line)
if len(parts) >= 2 {
expectedHash := parts[0]
if calculatedHash == expectedHash {
return nil // Checksum verified
}
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, calculatedHash)
}
}
}
return fmt.Errorf("checksum not found for %s", binaryName)
}
// GetInstalledVersion returns the currently installed version
func (rm *ReleaseManager) GetInstalledVersion() string {
// Fallback: Check if the LocalAI binary exists and try to get its version
binaryPath := rm.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return "" // No version installed
}
// try to get version from metadata
if version := rm.loadVersionMetadata(); version != "" {
return version
}
// Try to run the binary to get the version (fallback method)
version, err := exec.Command(binaryPath, "--version").Output()
if err != nil {
// If binary exists but --version fails, try to determine from filename or other means
log.Printf("Binary exists but --version failed: %v", err)
return ""
}
stringVersion := strings.TrimSpace(string(version))
stringVersion = strings.TrimRight(stringVersion, "\n")
return stringVersion
}
// loadVersionMetadata loads the installed version from metadata file
func (rm *ReleaseManager) loadVersionMetadata() string {
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
// Check if metadata file exists
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
return ""
}
// Read metadata file
metadataData, err := os.ReadFile(metadataPath)
if err != nil {
log.Printf("Failed to read metadata file: %v", err)
return ""
}
// Parse metadata
var metadata struct {
Version string `json:"version"`
InstalledAt time.Time `json:"installed_at"`
BinaryPath string `json:"binary_path"`
}
if err := json.Unmarshal(metadataData, &metadata); err != nil {
log.Printf("Failed to parse metadata file: %v", err)
return ""
}
// Verify that the binary path in metadata matches current binary path
if metadata.BinaryPath != rm.GetBinaryPath() {
log.Printf("Binary path mismatch in metadata, ignoring")
return ""
}
log.Printf("Loaded version from metadata: %s (installed at %s)", metadata.Version, metadata.InstalledAt.Format("2006-01-02 15:04:05"))
return metadata.Version
}
// GetBinaryPath returns the path to the LocalAI binary
func (rm *ReleaseManager) GetBinaryPath() string {
return filepath.Join(rm.BinaryPath, "local-ai")
}
// IsUpdateAvailable checks if an update is available
func (rm *ReleaseManager) IsUpdateAvailable() (bool, string, error) {
log.Printf("IsUpdateAvailable: checking for updates...")
latest, err := rm.GetLatestRelease()
if err != nil {
log.Printf("IsUpdateAvailable: failed to get latest release: %v", err)
return false, "", err
}
log.Printf("IsUpdateAvailable: latest release version: %s", latest.Version)
current := rm.GetInstalledVersion()
log.Printf("IsUpdateAvailable: current installed version: %s", current)
if current == "" {
// No version installed, offer to download latest
log.Printf("IsUpdateAvailable: no version installed, offering latest: %s", latest.Version)
return true, latest.Version, nil
}
updateAvailable := latest.Version != current
log.Printf("IsUpdateAvailable: update available: %v (latest: %s, current: %s)", updateAvailable, latest.Version, current)
return updateAvailable, latest.Version, nil
}
// IsLocalAIInstalled checks if LocalAI binary exists and is valid
func (rm *ReleaseManager) IsLocalAIInstalled() bool {
binaryPath := rm.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return false
}
// Verify the binary integrity
if err := rm.VerifyInstalledBinary(); err != nil {
log.Printf("Binary integrity check failed: %v", err)
// Remove corrupted binary
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
return false
}
return true
}
// VerifyInstalledBinary verifies the installed binary against saved checksums
func (rm *ReleaseManager) VerifyInstalledBinary() error {
binaryPath := rm.GetBinaryPath()
// Check if we have saved checksums
latestChecksumsPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
if _, err := os.Stat(latestChecksumsPath); os.IsNotExist(err) {
return fmt.Errorf("no saved checksums found")
}
// Get the binary name for the current version from metadata
currentVersion := rm.loadVersionMetadata()
if currentVersion == "" {
return fmt.Errorf("cannot determine current version from metadata")
}
binaryName := rm.GetBinaryName(currentVersion)
// Verify against saved checksums
return rm.VerifyChecksum(binaryPath, latestChecksumsPath, binaryName)
}
// CleanupPartialDownloads removes any partial or corrupted downloads
func (rm *ReleaseManager) CleanupPartialDownloads() error {
binaryPath := rm.GetBinaryPath()
// Check if binary exists but is corrupted
if _, err := os.Stat(binaryPath); err == nil {
// Binary exists, verify it
if verifyErr := rm.VerifyInstalledBinary(); verifyErr != nil {
log.Printf("Found corrupted binary, removing: %v", verifyErr)
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
// Clear metadata since binary is corrupted
rm.clearVersionMetadata()
}
}
// Clean up any temporary checksum files
tempChecksumsPath := filepath.Join(rm.BinaryPath, "checksums.txt")
if _, err := os.Stat(tempChecksumsPath); err == nil {
if removeErr := os.Remove(tempChecksumsPath); removeErr != nil {
log.Printf("Failed to remove temporary checksums: %v", removeErr)
}
}
return nil
}
// clearVersionMetadata clears the version metadata (used when binary is corrupted or removed)
func (rm *ReleaseManager) clearVersionMetadata() {
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
log.Printf("Failed to clear version metadata: %v", err)
} else {
log.Printf("Version metadata cleared")
}
}

View File

@@ -1,178 +0,0 @@
package launcher_test
import (
"os"
"path/filepath"
"runtime"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
var _ = Describe("ReleaseManager", func() {
var (
rm *launcher.ReleaseManager
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "launcher-test-*")
Expect(err).ToNot(HaveOccurred())
rm = launcher.NewReleaseManager()
// Override binary path for testing
rm.BinaryPath = tempDir
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
Describe("NewReleaseManager", func() {
It("should create a release manager with correct defaults", func() {
newRM := launcher.NewReleaseManager()
Expect(newRM.GitHubOwner).To(Equal("mudler"))
Expect(newRM.GitHubRepo).To(Equal("LocalAI"))
Expect(newRM.BinaryPath).To(ContainSubstring(".localai"))
})
})
Describe("GetBinaryName", func() {
It("should return correct binary name for current platform", func() {
binaryName := rm.GetBinaryName("v3.4.0")
expectedOS := runtime.GOOS
expectedArch := runtime.GOARCH
expected := "local-ai-v3.4.0-" + expectedOS + "-" + expectedArch
Expect(binaryName).To(Equal(expected))
})
It("should handle version with and without 'v' prefix", func() {
withV := rm.GetBinaryName("v3.4.0")
withoutV := rm.GetBinaryName("3.4.0")
// Both should produce the same result
Expect(withV).To(Equal(withoutV))
})
})
Describe("GetBinaryPath", func() {
It("should return the correct binary path", func() {
path := rm.GetBinaryPath()
expected := filepath.Join(tempDir, "local-ai")
Expect(path).To(Equal(expected))
})
})
Describe("GetInstalledVersion", func() {
It("should return empty when no binary exists", func() {
version := rm.GetInstalledVersion()
Expect(version).To(BeEmpty()) // No binary installed in test
})
It("should return empty version when binary exists but no metadata", func() {
// Create a fake binary for testing
err := os.MkdirAll(rm.BinaryPath, 0755)
Expect(err).ToNot(HaveOccurred())
binaryPath := rm.GetBinaryPath()
err = os.WriteFile(binaryPath, []byte("fake binary"), 0755)
Expect(err).ToNot(HaveOccurred())
version := rm.GetInstalledVersion()
Expect(version).To(BeEmpty())
})
})
Context("with mocked responses", func() {
// Note: In a real implementation, we'd mock HTTP responses
// For now, we'll test the structure and error handling
Describe("GetLatestRelease", func() {
It("should handle network errors gracefully", func() {
// This test would require mocking HTTP client
// For demonstration, we're just testing the method exists
_, err := rm.GetLatestRelease()
// We expect either success or a network error, not a panic
// In a real test, we'd mock the HTTP response
if err != nil {
Expect(err.Error()).To(ContainSubstring("failed to fetch"))
}
})
})
Describe("DownloadRelease", func() {
It("should create binary directory if it doesn't exist", func() {
// Remove the temp directory to test creation
os.RemoveAll(tempDir)
// This will fail due to network, but should create the directory
rm.DownloadRelease("v3.4.0", nil)
// Check if directory was created
_, err := os.Stat(tempDir)
Expect(err).ToNot(HaveOccurred())
})
})
})
Describe("VerifyChecksum functionality", func() {
var (
testFile string
checksumFile string
)
BeforeEach(func() {
testFile = filepath.Join(tempDir, "test-binary")
checksumFile = filepath.Join(tempDir, "checksums.txt")
})
It("should verify checksums correctly", func() {
// Create a test file with known content
testContent := []byte("test content for checksum")
err := os.WriteFile(testFile, testContent, 0644)
Expect(err).ToNot(HaveOccurred())
// Calculate expected SHA256
// This is a simplified test - in practice we'd use the actual checksum
checksumContent := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 test-binary\n"
err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
Expect(err).ToNot(HaveOccurred())
// Test checksum verification
// Note: This will fail because our content doesn't match the empty string hash
// In a real test, we'd calculate the actual hash
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
// We expect this to fail since we're using a dummy checksum
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("checksum mismatch"))
})
It("should handle missing checksum file", func() {
// Create test file but no checksum file
err := os.WriteFile(testFile, []byte("test"), 0644)
Expect(err).ToNot(HaveOccurred())
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to open checksums file"))
})
It("should handle missing binary in checksums", func() {
// Create files but checksum doesn't contain our binary
err := os.WriteFile(testFile, []byte("test"), 0644)
Expect(err).ToNot(HaveOccurred())
checksumContent := "hash other-binary\n"
err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
Expect(err).ToNot(HaveOccurred())
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("checksum not found"))
})
})
})

View File

@@ -1,523 +0,0 @@
package launcher
import (
"fmt"
"log"
"net/url"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/driver/desktop"
"fyne.io/fyne/v2/widget"
)
// SystrayManager manages the system tray functionality
type SystrayManager struct {
launcher *Launcher
window fyne.Window
app fyne.App
desk desktop.App
// Menu items that need dynamic updates
startStopItem *fyne.MenuItem
hasUpdateAvailable bool
latestVersion string
icon *fyne.StaticResource
}
// NewSystrayManager creates a new systray manager
func NewSystrayManager(launcher *Launcher, window fyne.Window, desktop desktop.App, app fyne.App, icon *fyne.StaticResource) *SystrayManager {
sm := &SystrayManager{
launcher: launcher,
window: window,
app: app,
desk: desktop,
icon: icon,
}
sm.setupMenu(desktop)
return sm
}
// setupMenu sets up the system tray menu
func (sm *SystrayManager) setupMenu(desk desktop.App) {
sm.desk = desk
// Create the start/stop toggle item
sm.startStopItem = fyne.NewMenuItem("Start LocalAI", func() {
sm.toggleLocalAI()
})
desk.SetSystemTrayIcon(sm.icon)
// Initialize the menu state using recreateMenu
sm.recreateMenu()
}
// toggleLocalAI starts or stops LocalAI based on current state
func (sm *SystrayManager) toggleLocalAI() {
if sm.launcher.IsRunning() {
go func() {
if err := sm.launcher.StopLocalAI(); err != nil {
log.Printf("Failed to stop LocalAI: %v", err)
sm.showErrorDialog("Failed to Stop LocalAI", err.Error())
}
}()
} else {
go func() {
if err := sm.launcher.StartLocalAI(); err != nil {
log.Printf("Failed to start LocalAI: %v", err)
sm.showStartupErrorDialog(err)
}
}()
}
}
// openWebUI opens the LocalAI WebUI in the default browser
func (sm *SystrayManager) openWebUI() {
if !sm.launcher.IsRunning() {
return // LocalAI is not running
}
webURL := sm.launcher.GetWebUIURL()
if parsedURL, err := url.Parse(webURL); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// openDocumentation opens the LocalAI documentation
func (sm *SystrayManager) openDocumentation() {
if parsedURL, err := url.Parse("https://localai.io"); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// updateStartStopItem updates the start/stop menu item based on current state
func (sm *SystrayManager) updateStartStopItem() {
// Since Fyne menu items can't change text dynamically, we recreate the menu
sm.recreateMenu()
}
// recreateMenu recreates the entire menu with updated state
func (sm *SystrayManager) recreateMenu() {
if sm.desk == nil {
return
}
// Determine the action based on LocalAI installation and running state
var actionItem *fyne.MenuItem
if !sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
// LocalAI not installed - show install option
actionItem = fyne.NewMenuItem("📥 Install Latest Version", func() {
sm.launcher.showDownloadLocalAIDialog()
})
} else if sm.launcher.IsRunning() {
// LocalAI is running - show stop option
actionItem = fyne.NewMenuItem("🛑 Stop LocalAI", func() {
sm.toggleLocalAI()
})
} else {
// LocalAI is installed but not running - show start option
actionItem = fyne.NewMenuItem("▶️ Start LocalAI", func() {
sm.toggleLocalAI()
})
}
menuItems := []*fyne.MenuItem{}
// Add status at the top (clickable for details)
status := sm.launcher.GetLastStatus()
statusText := sm.truncateText(status, 30)
statusItem := fyne.NewMenuItem("📊 Status: "+statusText, func() {
sm.showStatusDetails(status, "")
})
menuItems = append(menuItems, statusItem)
// Only show version if LocalAI is installed
if sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
version := sm.launcher.GetCurrentVersion()
versionText := sm.truncateText(version, 25)
versionItem := fyne.NewMenuItem("🔧 Version: "+versionText, func() {
sm.showStatusDetails(status, version)
})
menuItems = append(menuItems, versionItem)
}
menuItems = append(menuItems, fyne.NewMenuItemSeparator())
// Add update notification if available
if sm.hasUpdateAvailable {
updateItem := fyne.NewMenuItem("🔔 New version available ("+sm.latestVersion+")", func() {
sm.downloadUpdate()
})
menuItems = append(menuItems, updateItem)
menuItems = append(menuItems, fyne.NewMenuItemSeparator())
}
// Core actions
menuItems = append(menuItems,
actionItem,
)
// Only show WebUI option if LocalAI is installed
if sm.launcher.GetReleaseManager().IsLocalAIInstalled() && sm.launcher.IsRunning() {
menuItems = append(menuItems,
fyne.NewMenuItem("Open WebUI", func() {
sm.openWebUI()
}),
)
}
menuItems = append(menuItems,
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Check for Updates", func() {
sm.checkForUpdates()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Settings", func() {
sm.showSettings()
}),
fyne.NewMenuItem("Show Welcome Window", func() {
sm.showWelcomeWindow()
}),
fyne.NewMenuItem("Open Data Folder", func() {
sm.openDataFolder()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Documentation", func() {
sm.openDocumentation()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Quit", func() {
// Perform cleanup before quitting
if err := sm.launcher.Shutdown(); err != nil {
log.Printf("Error during shutdown: %v", err)
}
sm.app.Quit()
}),
)
menu := fyne.NewMenu("LocalAI", menuItems...)
sm.desk.SetSystemTrayMenu(menu)
}
// UpdateRunningState updates the systray based on running state
func (sm *SystrayManager) UpdateRunningState(isRunning bool) {
sm.updateStartStopItem()
}
// UpdateStatus updates the systray menu to reflect status changes
func (sm *SystrayManager) UpdateStatus(status string) {
sm.recreateMenu()
}
// checkForUpdates checks for available updates
func (sm *SystrayManager) checkForUpdates() {
go func() {
log.Printf("Checking for updates...")
available, version, err := sm.launcher.CheckForUpdates()
if err != nil {
log.Printf("Failed to check for updates: %v", err)
return
}
log.Printf("Update check result: available=%v, version=%s", available, version)
if available {
sm.hasUpdateAvailable = true
sm.latestVersion = version
sm.recreateMenu()
}
}()
}
// downloadUpdate downloads the latest update
func (sm *SystrayManager) downloadUpdate() {
if !sm.hasUpdateAvailable {
return
}
// Show progress window
sm.showDownloadProgress(sm.latestVersion)
}
// showSettings shows the settings window
func (sm *SystrayManager) showSettings() {
sm.window.Show()
sm.window.RequestFocus()
}
// showWelcomeWindow shows the welcome window
func (sm *SystrayManager) showWelcomeWindow() {
if sm.launcher.GetUI() != nil {
sm.launcher.GetUI().ShowWelcomeWindow()
}
}
// openDataFolder opens the data folder in file manager
func (sm *SystrayManager) openDataFolder() {
dataPath := sm.launcher.GetDataPath()
if parsedURL, err := url.Parse("file://" + dataPath); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// NotifyUpdateAvailable sets update notification in systray
func (sm *SystrayManager) NotifyUpdateAvailable(version string) {
sm.hasUpdateAvailable = true
sm.latestVersion = version
sm.recreateMenu()
}
// truncateText truncates text to specified length and adds ellipsis if needed
func (sm *SystrayManager) truncateText(text string, maxLength int) string {
if len(text) <= maxLength {
return text
}
return text[:maxLength-3] + "..."
}
// showStatusDetails shows a detailed status window with full information
func (sm *SystrayManager) showStatusDetails(status, version string) {
fyne.DoAndWait(func() {
// Create status details window
statusWindow := sm.app.NewWindow("LocalAI Status Details")
statusWindow.Resize(fyne.NewSize(500, 400))
statusWindow.CenterOnScreen()
// Status information
statusLabel := widget.NewLabel("Current Status:")
statusValue := widget.NewLabel(status)
statusValue.Wrapping = fyne.TextWrapWord
// Version information (only show if version exists)
var versionContainer fyne.CanvasObject
if version != "" {
versionLabel := widget.NewLabel("Installed Version:")
versionValue := widget.NewLabel(version)
versionValue.Wrapping = fyne.TextWrapWord
versionContainer = container.NewVBox(versionLabel, versionValue)
}
// Running state
runningLabel := widget.NewLabel("Running State:")
runningValue := widget.NewLabel("")
if sm.launcher.IsRunning() {
runningValue.SetText("🟢 Running")
} else {
runningValue.SetText("🔴 Stopped")
}
// WebUI URL
webuiLabel := widget.NewLabel("WebUI URL:")
webuiValue := widget.NewLabel(sm.launcher.GetWebUIURL())
webuiValue.Wrapping = fyne.TextWrapWord
// Recent logs (last 20 lines)
logsLabel := widget.NewLabel("Recent Logs:")
logsText := widget.NewMultiLineEntry()
logsText.SetText(sm.launcher.GetRecentLogs())
logsText.Wrapping = fyne.TextWrapWord
logsText.Disable() // Make it read-only
// Buttons
closeButton := widget.NewButton("Close", func() {
statusWindow.Close()
})
refreshButton := widget.NewButton("Refresh", func() {
// Refresh the status information
statusValue.SetText(sm.launcher.GetLastStatus())
// Note: Version refresh is not implemented for simplicity
// The version will be updated when the status details window is reopened
if sm.launcher.IsRunning() {
runningValue.SetText("🟢 Running")
} else {
runningValue.SetText("🔴 Stopped")
}
logsText.SetText(sm.launcher.GetRecentLogs())
})
openWebUIButton := widget.NewButton("Open WebUI", func() {
sm.openWebUI()
})
// Layout
buttons := container.NewHBox(closeButton, refreshButton, openWebUIButton)
// Build info container dynamically
infoItems := []fyne.CanvasObject{
statusLabel, statusValue,
widget.NewSeparator(),
}
// Add version section if it exists
if versionContainer != nil {
infoItems = append(infoItems, versionContainer, widget.NewSeparator())
}
infoItems = append(infoItems,
runningLabel, runningValue,
widget.NewSeparator(),
webuiLabel, webuiValue,
)
infoContainer := container.NewVBox(infoItems...)
content := container.NewVBox(
infoContainer,
widget.NewSeparator(),
logsLabel,
logsText,
widget.NewSeparator(),
buttons,
)
statusWindow.SetContent(content)
statusWindow.Show()
})
}
// showErrorDialog shows a simple error dialog
func (sm *SystrayManager) showErrorDialog(title, message string) {
fyne.DoAndWait(func() {
dialog.ShowError(fmt.Errorf("%s", message), sm.window)
})
}
// showStartupErrorDialog shows a detailed error dialog with process logs
func (sm *SystrayManager) showStartupErrorDialog(err error) {
fyne.DoAndWait(func() {
// Get the recent process logs (more useful for debugging)
logs := sm.launcher.GetRecentLogs()
// Create error window
errorWindow := sm.app.NewWindow("LocalAI Startup Failed")
errorWindow.Resize(fyne.NewSize(600, 500))
errorWindow.CenterOnScreen()
// Error message
errorLabel := widget.NewLabel(fmt.Sprintf("Failed to start LocalAI:\n%s", err.Error()))
errorLabel.Wrapping = fyne.TextWrapWord
// Logs display
logsLabel := widget.NewLabel("Process Logs:")
logsText := widget.NewMultiLineEntry()
logsText.SetText(logs)
logsText.Wrapping = fyne.TextWrapWord
logsText.Disable() // Make it read-only
// Buttons
closeButton := widget.NewButton("Close", func() {
errorWindow.Close()
})
retryButton := widget.NewButton("Retry", func() {
errorWindow.Close()
// Try to start again
go func() {
if retryErr := sm.launcher.StartLocalAI(); retryErr != nil {
sm.showStartupErrorDialog(retryErr)
}
}()
})
openLogsButton := widget.NewButton("Open Logs Folder", func() {
sm.openDataFolder()
})
// Layout
buttons := container.NewHBox(closeButton, retryButton, openLogsButton)
content := container.NewVBox(
errorLabel,
widget.NewSeparator(),
logsLabel,
logsText,
widget.NewSeparator(),
buttons,
)
errorWindow.SetContent(content)
errorWindow.Show()
})
}
// showDownloadProgress shows a progress window for downloading updates
func (sm *SystrayManager) showDownloadProgress(version string) {
// Create a new window for download progress
progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
sm.app.OpenURL(releaseNotesURL)
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := sm.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show restart dialog
dialog.ShowConfirm("Update Downloaded",
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
func(restart bool) {
if restart {
sm.app.Quit()
}
progressWindow.Close()
}, progressWindow)
}
})
// Update systray menu
if err == nil {
sm.hasUpdateAvailable = false
sm.latestVersion = ""
sm.recreateMenu()
}
}()
}

View File

@@ -1,795 +0,0 @@
package launcher
import (
"fmt"
"log"
"net/url"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
)
// EnvVar represents an environment variable
type EnvVar struct {
Key string
Value string
}
// LauncherUI handles the user interface
type LauncherUI struct {
// Status display
statusLabel *widget.Label
versionLabel *widget.Label
// Control buttons
startStopButton *widget.Button
webUIButton *widget.Button
updateButton *widget.Button
downloadButton *widget.Button
// Configuration
modelsPathEntry *widget.Entry
backendsPathEntry *widget.Entry
addressEntry *widget.Entry
logLevelSelect *widget.Select
startOnBootCheck *widget.Check
// Environment Variables
envVarsData []EnvVar
newEnvKeyEntry *widget.Entry
newEnvValueEntry *widget.Entry
updateEnvironmentDisplay func()
// Logs
logText *widget.Entry
// Progress
progressBar *widget.ProgressBar
// Update management
latestVersion string
// Reference to launcher
launcher *Launcher
}
// NewLauncherUI creates a new UI instance
func NewLauncherUI() *LauncherUI {
return &LauncherUI{
statusLabel: widget.NewLabel("Initializing..."),
versionLabel: widget.NewLabel("Version: Unknown"),
startStopButton: widget.NewButton("Start LocalAI", nil),
webUIButton: widget.NewButton("Open WebUI", nil),
updateButton: widget.NewButton("Check for Updates", nil),
modelsPathEntry: widget.NewEntry(),
backendsPathEntry: widget.NewEntry(),
addressEntry: widget.NewEntry(),
logLevelSelect: widget.NewSelect([]string{"error", "warn", "info", "debug", "trace"}, nil),
startOnBootCheck: widget.NewCheck("Start LocalAI on system boot", nil),
logText: widget.NewMultiLineEntry(),
progressBar: widget.NewProgressBar(),
envVarsData: []EnvVar{}, // Initialize the environment variables slice
}
}
// CreateMainUI creates the main UI layout
func (ui *LauncherUI) CreateMainUI(launcher *Launcher) *fyne.Container {
ui.launcher = launcher
ui.setupBindings()
// Main tab with status and controls
// Configuration is now the main content
configTab := ui.createConfigTab()
// Create a simple container instead of tabs since we only have settings
tabs := container.NewVBox(
widget.NewCard("LocalAI Launcher Settings", "", configTab),
)
return tabs
}
// createConfigTab creates the configuration tab
func (ui *LauncherUI) createConfigTab() *fyne.Container {
// Path configuration
pathsCard := widget.NewCard("Paths", "", container.NewGridWithColumns(2,
widget.NewLabel("Models Path:"),
ui.modelsPathEntry,
widget.NewLabel("Backends Path:"),
ui.backendsPathEntry,
))
// Server configuration
serverCard := widget.NewCard("Server", "", container.NewVBox(
container.NewGridWithColumns(2,
widget.NewLabel("Address:"),
ui.addressEntry,
widget.NewLabel("Log Level:"),
ui.logLevelSelect,
),
ui.startOnBootCheck,
))
// Save button
saveButton := widget.NewButton("Save Configuration", func() {
ui.saveConfiguration()
})
// Environment Variables section
envCard := ui.createEnvironmentSection()
return container.NewVBox(
pathsCard,
serverCard,
envCard,
saveButton,
)
}
// createEnvironmentSection creates the environment variables section for the config tab
func (ui *LauncherUI) createEnvironmentSection() *fyne.Container {
// Initialize environment variables widgets
ui.newEnvKeyEntry = widget.NewEntry()
ui.newEnvKeyEntry.SetPlaceHolder("Environment Variable Name")
ui.newEnvValueEntry = widget.NewEntry()
ui.newEnvValueEntry.SetPlaceHolder("Environment Variable Value")
// Add button
addButton := widget.NewButton("Add Environment Variable", func() {
ui.addEnvironmentVariable()
})
// Environment variables list with delete buttons
ui.envVarsData = []EnvVar{}
// Create container for environment variables
envVarsContainer := container.NewVBox()
// Update function to rebuild the environment variables display
ui.updateEnvironmentDisplay = func() {
envVarsContainer.Objects = nil
for i, envVar := range ui.envVarsData {
index := i // Capture index for closure
// Create row with label and delete button
envLabel := widget.NewLabel(fmt.Sprintf("%s = %s", envVar.Key, envVar.Value))
deleteBtn := widget.NewButton("Delete", func() {
ui.confirmDeleteEnvironmentVariable(index)
})
deleteBtn.Importance = widget.DangerImportance
row := container.NewBorder(nil, nil, nil, deleteBtn, envLabel)
envVarsContainer.Add(row)
}
envVarsContainer.Refresh()
}
// Create a scrollable container for the environment variables
envScroll := container.NewScroll(envVarsContainer)
envScroll.SetMinSize(fyne.NewSize(400, 150))
// Input section for adding new environment variables
inputSection := container.NewVBox(
container.NewGridWithColumns(2,
ui.newEnvKeyEntry,
ui.newEnvValueEntry,
),
addButton,
)
// Environment variables card
envCard := widget.NewCard("Environment Variables", "", container.NewVBox(
inputSection,
widget.NewSeparator(),
envScroll,
))
return container.NewVBox(envCard)
}
// addEnvironmentVariable adds a new environment variable
func (ui *LauncherUI) addEnvironmentVariable() {
key := ui.newEnvKeyEntry.Text
value := ui.newEnvValueEntry.Text
log.Printf("addEnvironmentVariable: attempting to add %s=%s", key, value)
log.Printf("addEnvironmentVariable: current ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
if key == "" {
log.Printf("addEnvironmentVariable: key is empty, showing error")
dialog.ShowError(fmt.Errorf("environment variable name cannot be empty"), ui.launcher.window)
return
}
// Check if key already exists
for _, envVar := range ui.envVarsData {
if envVar.Key == key {
log.Printf("addEnvironmentVariable: key %s already exists, showing error", key)
dialog.ShowError(fmt.Errorf("environment variable '%s' already exists", key), ui.launcher.window)
return
}
}
log.Printf("addEnvironmentVariable: adding new env var %s=%s", key, value)
ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
log.Printf("addEnvironmentVariable: after adding, ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
fyne.Do(func() {
if ui.updateEnvironmentDisplay != nil {
ui.updateEnvironmentDisplay()
}
// Clear input fields
ui.newEnvKeyEntry.SetText("")
ui.newEnvValueEntry.SetText("")
})
log.Printf("addEnvironmentVariable: calling saveEnvironmentVariables")
// Save to configuration
ui.saveEnvironmentVariables()
}
// removeEnvironmentVariable removes an environment variable by index
func (ui *LauncherUI) removeEnvironmentVariable(index int) {
if index >= 0 && index < len(ui.envVarsData) {
ui.envVarsData = append(ui.envVarsData[:index], ui.envVarsData[index+1:]...)
fyne.Do(func() {
if ui.updateEnvironmentDisplay != nil {
ui.updateEnvironmentDisplay()
}
})
ui.saveEnvironmentVariables()
}
}
// saveEnvironmentVariables saves environment variables to the configuration
func (ui *LauncherUI) saveEnvironmentVariables() {
if ui.launcher == nil {
log.Printf("saveEnvironmentVariables: launcher is nil")
return
}
config := ui.launcher.GetConfig()
log.Printf("saveEnvironmentVariables: before - Environment vars: %v", config.EnvironmentVars)
config.EnvironmentVars = make(map[string]string)
for _, envVar := range ui.envVarsData {
config.EnvironmentVars[envVar.Key] = envVar.Value
log.Printf("saveEnvironmentVariables: adding %s=%s", envVar.Key, envVar.Value)
}
log.Printf("saveEnvironmentVariables: after - Environment vars: %v", config.EnvironmentVars)
log.Printf("saveEnvironmentVariables: calling SetConfig with %d environment variables", len(config.EnvironmentVars))
err := ui.launcher.SetConfig(config)
if err != nil {
log.Printf("saveEnvironmentVariables: failed to save config: %v", err)
} else {
log.Printf("saveEnvironmentVariables: config saved successfully")
}
}
// confirmDeleteEnvironmentVariable shows confirmation dialog for deleting an environment variable
func (ui *LauncherUI) confirmDeleteEnvironmentVariable(index int) {
if index >= 0 && index < len(ui.envVarsData) {
envVar := ui.envVarsData[index]
dialog.ShowConfirm("Remove Environment Variable",
fmt.Sprintf("Remove environment variable '%s'?", envVar.Key),
func(remove bool) {
if remove {
ui.removeEnvironmentVariable(index)
}
}, ui.launcher.window)
}
}
// setupBindings sets up event handlers for UI elements
func (ui *LauncherUI) setupBindings() {
// Start/Stop button
ui.startStopButton.OnTapped = func() {
if ui.launcher.IsRunning() {
ui.stopLocalAI()
} else {
ui.startLocalAI()
}
}
// WebUI button
ui.webUIButton.OnTapped = func() {
ui.openWebUI()
}
ui.webUIButton.Disable() // Disabled until LocalAI is running
// Update button
ui.updateButton.OnTapped = func() {
ui.checkForUpdates()
}
// Log level selection
ui.logLevelSelect.OnChanged = func(selected string) {
if ui.launcher != nil {
config := ui.launcher.GetConfig()
config.LogLevel = selected
ui.launcher.SetConfig(config)
}
}
}
// startLocalAI starts the LocalAI service
func (ui *LauncherUI) startLocalAI() {
fyne.Do(func() {
ui.startStopButton.Disable()
})
ui.UpdateStatus("Starting LocalAI...")
go func() {
err := ui.launcher.StartLocalAI()
if err != nil {
ui.UpdateStatus("Failed to start: " + err.Error())
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.Do(func() {
ui.startStopButton.SetText("Stop LocalAI")
ui.webUIButton.Enable()
})
}
fyne.Do(func() {
ui.startStopButton.Enable()
})
}()
}
// stopLocalAI stops the LocalAI service
func (ui *LauncherUI) stopLocalAI() {
fyne.Do(func() {
ui.startStopButton.Disable()
})
ui.UpdateStatus("Stopping LocalAI...")
go func() {
err := ui.launcher.StopLocalAI()
if err != nil {
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.Do(func() {
ui.startStopButton.SetText("Start LocalAI")
ui.webUIButton.Disable()
})
}
fyne.Do(func() {
ui.startStopButton.Enable()
})
}()
}
// openWebUI opens the LocalAI WebUI in the default browser
func (ui *LauncherUI) openWebUI() {
webURL := ui.launcher.GetWebUIURL()
parsedURL, err := url.Parse(webURL)
if err != nil {
dialog.ShowError(err, ui.launcher.window)
return
}
// Open URL in default browser
fyne.CurrentApp().OpenURL(parsedURL)
}
// saveConfiguration saves the current configuration
func (ui *LauncherUI) saveConfiguration() {
log.Printf("saveConfiguration: starting to save configuration")
config := ui.launcher.GetConfig()
log.Printf("saveConfiguration: current config Environment vars: %v", config.EnvironmentVars)
log.Printf("saveConfiguration: ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
config.ModelsPath = ui.modelsPathEntry.Text
config.BackendsPath = ui.backendsPathEntry.Text
config.Address = ui.addressEntry.Text
config.LogLevel = ui.logLevelSelect.Selected
config.StartOnBoot = ui.startOnBootCheck.Checked
// Ensure environment variables are included in the configuration
config.EnvironmentVars = make(map[string]string)
for _, envVar := range ui.envVarsData {
config.EnvironmentVars[envVar.Key] = envVar.Value
log.Printf("saveConfiguration: adding env var %s=%s", envVar.Key, envVar.Value)
}
log.Printf("saveConfiguration: final config Environment vars: %v", config.EnvironmentVars)
err := ui.launcher.SetConfig(config)
if err != nil {
log.Printf("saveConfiguration: failed to save config: %v", err)
dialog.ShowError(err, ui.launcher.window)
} else {
log.Printf("saveConfiguration: config saved successfully")
dialog.ShowInformation("Configuration", "Configuration saved successfully", ui.launcher.window)
}
}
// checkForUpdates checks for available updates
func (ui *LauncherUI) checkForUpdates() {
fyne.Do(func() {
ui.updateButton.Disable()
})
ui.UpdateStatus("Checking for updates...")
go func() {
available, version, err := ui.launcher.CheckForUpdates()
if err != nil {
ui.UpdateStatus("Failed to check updates: " + err.Error())
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else if available {
ui.latestVersion = version // Store the latest version
ui.UpdateStatus("Update available: " + version)
fyne.Do(func() {
if ui.downloadButton != nil {
ui.downloadButton.Enable()
}
})
ui.NotifyUpdateAvailable(version)
} else {
ui.UpdateStatus("No updates available")
fyne.DoAndWait(func() {
dialog.ShowInformation("Updates", "You are running the latest version", ui.launcher.window)
})
}
fyne.Do(func() {
ui.updateButton.Enable()
})
}()
}
// downloadUpdate downloads the latest update
func (ui *LauncherUI) downloadUpdate() {
// Use stored version or check for updates
version := ui.latestVersion
if version == "" {
_, v, err := ui.launcher.CheckForUpdates()
if err != nil {
dialog.ShowError(err, ui.launcher.window)
return
}
version = v
ui.latestVersion = version
}
if version == "" {
dialog.ShowError(fmt.Errorf("no version information available"), ui.launcher.window)
return
}
// Disable buttons during download
if ui.downloadButton != nil {
fyne.Do(func() {
ui.downloadButton.Disable()
})
}
fyne.Do(func() {
ui.progressBar.Show()
ui.progressBar.SetValue(0)
})
ui.UpdateStatus("Downloading update " + version + "...")
go func() {
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
ui.progressBar.SetValue(progress)
})
// Update status with percentage
percentage := int(progress * 100)
ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
})
fyne.Do(func() {
ui.progressBar.Hide()
})
// Re-enable buttons after download
if ui.downloadButton != nil {
fyne.Do(func() {
ui.downloadButton.Enable()
})
}
if err != nil {
fyne.DoAndWait(func() {
ui.UpdateStatus("Failed to download update: " + err.Error())
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.DoAndWait(func() {
ui.UpdateStatus("Update downloaded successfully")
dialog.ShowInformation("Update", "Update downloaded successfully. Please restart the launcher to use the new version.", ui.launcher.window)
})
}
}()
}
// UpdateStatus updates the status label
func (ui *LauncherUI) UpdateStatus(status string) {
if ui.statusLabel != nil {
fyne.Do(func() {
ui.statusLabel.SetText(status)
})
}
}
// OnLogUpdate handles new log content
func (ui *LauncherUI) OnLogUpdate(logLine string) {
if ui.logText != nil {
fyne.Do(func() {
currentText := ui.logText.Text
ui.logText.SetText(currentText + logLine)
// Auto-scroll to bottom (simplified)
ui.logText.CursorRow = len(ui.logText.Text)
})
}
}
// NotifyUpdateAvailable shows an update notification
func (ui *LauncherUI) NotifyUpdateAvailable(version string) {
if ui.launcher != nil && ui.launcher.window != nil {
fyne.DoAndWait(func() {
dialog.ShowConfirm("Update Available",
"A new version ("+version+") is available. Would you like to download it?",
func(confirmed bool) {
if confirmed {
ui.downloadUpdate()
}
}, ui.launcher.window)
})
}
}
// LoadConfiguration loads the current configuration into UI elements
func (ui *LauncherUI) LoadConfiguration() {
if ui.launcher == nil {
log.Printf("UI LoadConfiguration: launcher is nil")
return
}
config := ui.launcher.GetConfig()
log.Printf("UI LoadConfiguration: loading config - ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
config.ModelsPath, config.BackendsPath, config.Address, config.LogLevel)
log.Printf("UI LoadConfiguration: Environment vars: %v", config.EnvironmentVars)
ui.modelsPathEntry.SetText(config.ModelsPath)
ui.backendsPathEntry.SetText(config.BackendsPath)
ui.addressEntry.SetText(config.Address)
ui.logLevelSelect.SetSelected(config.LogLevel)
ui.startOnBootCheck.SetChecked(config.StartOnBoot)
// Load environment variables
ui.envVarsData = []EnvVar{}
for key, value := range config.EnvironmentVars {
ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
}
if ui.updateEnvironmentDisplay != nil {
fyne.Do(func() {
ui.updateEnvironmentDisplay()
})
}
// Update version display
version := ui.launcher.GetCurrentVersion()
ui.versionLabel.SetText("Version: " + version)
log.Printf("UI LoadConfiguration: configuration loaded successfully")
}
// showDownloadProgress shows a progress window for downloading LocalAI
func (ui *LauncherUI) showDownloadProgress(version, title string) {
fyne.DoAndWait(func() {
// Create progress window using the launcher's app
progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
ui.launcher.app.OpenURL(releaseNotesURL)
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(title),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show success dialog
dialog.ShowConfirm("Installation Complete",
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
func(close bool) {
progressWindow.Close()
// Update status
ui.UpdateStatus("LocalAI installed successfully")
}, progressWindow)
}
})
}()
})
}
// UpdateRunningState updates UI based on LocalAI running state
func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
fyne.Do(func() {
if isRunning {
ui.startStopButton.SetText("Stop LocalAI")
ui.webUIButton.Enable()
} else {
ui.startStopButton.SetText("Start LocalAI")
ui.webUIButton.Disable()
}
})
}
// ShowWelcomeWindow displays the welcome window with helpful information
func (ui *LauncherUI) ShowWelcomeWindow() {
if ui.launcher == nil || ui.launcher.window == nil {
log.Printf("Cannot show welcome window: launcher or window is nil")
return
}
fyne.DoAndWait(func() {
// Create welcome window
welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher")
welcomeWindow.Resize(fyne.NewSize(600, 500))
welcomeWindow.CenterOnScreen()
welcomeWindow.SetCloseIntercept(func() {
welcomeWindow.Close()
})
// Title
titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!")
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
// Welcome message
welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system.
What you can do:
• Start and stop LocalAI server
• Configure models and backends paths
• Set environment variables
• Check for updates automatically
• Access LocalAI WebUI when running
Getting Started:
1. Configure your models and backends paths
2. Click "Start LocalAI" to begin
3. Use "Open WebUI" to access the interface
4. Check the system tray for quick access`
welcomeLabel := widget.NewLabel(welcomeText)
welcomeLabel.Wrapping = fyne.TextWrapWord
// Useful links section
linksTitle := widget.NewLabel("Useful Links:")
linksTitle.TextStyle = fyne.TextStyle{Bold: true}
// Create link buttons
docsButton := widget.NewButton("📚 Documentation", func() {
ui.openURL("https://localai.io/docs/")
})
githubButton := widget.NewButton("🐙 GitHub Repository", func() {
ui.openURL("https://github.com/mudler/LocalAI")
})
modelsButton := widget.NewButton("🤖 Model Gallery", func() {
ui.openURL("https://localai.io/models/")
})
communityButton := widget.NewButton("💬 Community", func() {
ui.openURL("https://discord.gg/XgwjKptP7Z")
})
// Checkbox to disable welcome window
dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) {
if ui.launcher != nil {
config := ui.launcher.GetConfig()
v := !checked
config.ShowWelcome = &v
ui.launcher.SetConfig(config)
}
})
config := ui.launcher.GetConfig()
if config.ShowWelcome != nil {
dontShowAgainCheck.SetChecked(*config.ShowWelcome)
}
// Close button
closeButton := widget.NewButton("Get Started", func() {
welcomeWindow.Close()
})
closeButton.Importance = widget.HighImportance
// Layout
linksContainer := container.NewVBox(
linksTitle,
docsButton,
githubButton,
modelsButton,
communityButton,
)
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
welcomeLabel,
widget.NewSeparator(),
linksContainer,
widget.NewSeparator(),
dontShowAgainCheck,
widget.NewSeparator(),
closeButton,
)
welcomeWindow.SetContent(content)
welcomeWindow.Show()
})
}
// openURL opens a URL in the default browser
func (ui *LauncherUI) openURL(urlString string) {
parsedURL, err := url.Parse(urlString)
if err != nil {
log.Printf("Failed to parse URL %s: %v", urlString, err)
return
}
fyne.CurrentApp().OpenURL(parsedURL)
}

View File

Binary file not shown.


View File

@@ -1,92 +0,0 @@
package main
import (
"log"
"os"
"os/signal"
"syscall"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/driver/desktop"
coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
func main() {
// Create the application with unique ID
myApp := app.NewWithID("com.localai.launcher")
myApp.SetIcon(resourceIconPng)
myWindow := myApp.NewWindow("LocalAI Launcher")
myWindow.Resize(fyne.NewSize(800, 600))
// Create the launcher UI
ui := coreLauncher.NewLauncherUI()
// Initialize the launcher with UI context
launcher := coreLauncher.NewLauncher(ui, myWindow, myApp)
// Setup the UI
content := ui.CreateMainUI(launcher)
myWindow.SetContent(content)
// Setup window close behavior - minimize to tray instead of closing
myWindow.SetCloseIntercept(func() {
myWindow.Hide()
})
// Setup system tray using Fyne's built-in approach
if desk, ok := myApp.(desktop.App); ok {
// Create a dynamic systray manager
systray := coreLauncher.NewSystrayManager(launcher, myWindow, desk, myApp, resourceIconPng)
launcher.SetSystray(systray)
}
// Setup signal handling for graceful shutdown
setupSignalHandling(launcher)
// Initialize the launcher state
go func() {
if err := launcher.Initialize(); err != nil {
log.Printf("Failed to initialize launcher: %v", err)
if launcher.GetUI() != nil {
launcher.GetUI().UpdateStatus("Failed to initialize: " + err.Error())
}
} else {
// Load configuration into UI
launcher.GetUI().LoadConfiguration()
launcher.GetUI().UpdateStatus("Ready")
// Show welcome window if configured to do so
config := launcher.GetConfig()
if *config.ShowWelcome {
launcher.GetUI().ShowWelcomeWindow()
}
}
}()
// Run the application in background (window only shown when "Settings" is clicked)
myApp.Run()
}
// setupSignalHandling sets up signal handlers for graceful shutdown
func setupSignalHandling(launcher *coreLauncher.Launcher) {
// Create a channel to receive OS signals
sigChan := make(chan os.Signal, 1)
// Register for interrupt and terminate signals
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Handle signals in a separate goroutine
go func() {
sig := <-sigChan
log.Printf("Received signal %v, shutting down gracefully...", sig)
// Perform cleanup
if err := launcher.Shutdown(); err != nil {
log.Printf("Error during shutdown: %v", err)
}
// Exit the application
os.Exit(0)
}()
}

View File

@@ -56,12 +56,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
@@ -87,13 +87,13 @@ func New(opts ...config.AppOption) (*Application, error) {
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
if err := services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
return nil, err
}
}

View File

@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err

View File

@@ -78,12 +78,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
b = c.Batch
}
flashAttention := "auto"
if c.FlashAttention != nil {
flashAttention = *c.FlashAttention
}
f16 := false
if c.F16 != nil {
f16 = *c.F16
@@ -172,7 +166,7 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
MMProj: c.MMProj,
FlashAttention: flashAttention,
FlashAttention: c.FlashAttention,
CacheTypeKey: c.CacheTypeK,
CacheTypeValue: c.CacheTypeV,
NoKVOffload: c.NoKVOffloading,

View File

@@ -12,7 +12,7 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
@@ -34,7 +34,6 @@ func ModelTranscription(audio, language string, translate bool, diarize bool, ml
Dst: audio,
Language: language,
Translate: translate,
Diarize: diarize,
Threads: uint32(*modelConfig.Threads),
})
if err != nil {

View File

@@ -7,7 +7,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(
@@ -22,18 +22,12 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
_, err := inferenceModel.GenerateVideo(
appConfig.Context,
&proto.GenerateVideoRequest{
Height: height,
Width: width,
Prompt: prompt,
NegativePrompt: negativePrompt,
StartImage: startImage,
EndImage: endImage,
NumFrames: numFrames,
Fps: fps,
Seed: seed,
CfgScale: cfgScale,
Step: step,
Dst: dst,
Height: height,
Width: width,
Prompt: prompt,
StartImage: startImage,
EndImage: endImage,
Dst: dst,
})
return err
}

View File

@@ -6,7 +6,6 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/core/gallery"
@@ -101,8 +100,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -5,7 +5,6 @@ import (
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http"
)
@@ -46,7 +45,5 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)
signals.Handler(nil)
return appHTTP.Listen(e.Address)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/p2p"
)
@@ -20,7 +19,5 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
signals.Handler(nil)
return fs.Start(context.Background())
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
"github.com/schollz/progressbar/v3"
@@ -126,8 +125,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -10,11 +10,9 @@ import (
"github.com/mudler/LocalAI/core/application"
cli_api "github.com/mudler/LocalAI/core/cli/api"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -75,16 +73,9 @@ type RunCMD struct {
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
Version bool
}
func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.Version {
fmt.Println(internal.Version)
return nil
}
os.MkdirAll(r.BackendsPath, 0750)
os.MkdirAll(r.ModelsPath, 0750)
@@ -225,8 +216,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
return err
}
// Catch signals from the OS requesting us to exit, and stop all backends
signals.Handler(app.ModelLoader())
return appHTTP.Listen(r.Address)
}

View File

@@ -1,25 +0,0 @@
package signals
import (
"os"
"os/signal"
"syscall"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
func Handler(m *model.ModelLoader) {
// Catch signals from the OS requesting us to exit, and stop all backends
go func(m *model.ModelLoader) {
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
<-c
if m != nil {
if err := m.StopAllGRPC(); err != nil {
log.Error().Err(err).Msg("error while stopping all grpc backends")
}
}
os.Exit(1)
}(m)
}

View File

@@ -20,7 +20,6 @@ type TranscriptCMD struct {
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"c" help:"Translate the transcription to english"`
Diarize bool `short:"d" help:"Mark speaker turns"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
}
@@ -57,7 +56,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
}
}()
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
if err != nil {
return err
}

View File

@@ -2,7 +2,6 @@ package worker
type WorkerFlags struct {
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
}

View File

@@ -1,7 +1,6 @@
package worker
import (
"encoding/json"
"errors"
"fmt"
"os"
@@ -10,10 +9,7 @@ import (
"syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -24,10 +20,9 @@ type LLamaCPP struct {
const (
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
llamaCPPGalleryName = "llama-cpp"
)
func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
log.Warn().Msgf("Failed listing system backends: %s", err)
@@ -35,19 +30,9 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
}
log.Debug().Msgf("System backends: %v", backends)
backend, ok := backends.Get(llamaCPPGalleryName)
backend, ok := backends.Get("llama-cpp")
if !ok {
ml := model.NewModelLoader(systemState, true)
var gals []config.Gallery
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err
}
return "", errors.New("llama-cpp backend not found, install it first")
}
backendPath := filepath.Dir(backend.RunFile)
@@ -76,7 +61,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
if err != nil {
return err
}
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
return err
}
@@ -84,9 +69,6 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
args := strings.Split(r.ExtraLLamaCPPArgs, " ")
args = append([]string{grpcProcess}, args...)
signals.Handler(nil)
return syscall.Exec(
grpcProcess,
args,

View File

@@ -9,7 +9,6 @@ import (
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/pkg/system"
"github.com/phayes/freeport"
@@ -70,7 +69,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
for {
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
return
@@ -107,8 +106,6 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
}
}
signals.Handler(nil)
for {
time.Sleep(1 * time.Second)
}

View File

@@ -164,10 +164,10 @@ type LLMConfig struct {
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj" json:"mmproj"`
FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
FlashAttention bool `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
ModelType string `yaml:"type" json:"type"`

View File

@@ -1,5 +1,3 @@
// Package gallery provides installation and registration utilities for LocalAI backends,
// including meta-backend resolution based on system capabilities.
package gallery
import (
@@ -7,7 +5,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/mudler/LocalAI/core/config"
@@ -23,12 +20,6 @@ const (
runFile = "run.sh"
)
// backendCandidate represents an installed concrete backend option for a given alias
type backendCandidate struct {
name string
runFile string
}
// readBackendMetadata reads the metadata JSON file for a backend
func readBackendMetadata(backendPath string) (*BackendMetadata, error) {
metadataPath := filepath.Join(backendPath, metadataFile)
@@ -67,8 +58,8 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
return nil
}
// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
// Installs a model from the gallery
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -108,7 +99,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(systemState, bestBackend, downloadStatus); err != nil {
return err
}
@@ -133,10 +124,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
return InstallBackend(systemState, backend, downloadStatus)
}
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(systemState *system.SystemState, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -194,7 +185,7 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
return fmt.Errorf("failed to write metadata for backend %q: %v", name, err)
}
return RegisterBackends(systemState, modelLoader)
return nil
}
func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
@@ -291,18 +282,23 @@ func (b SystemBackends) GetAll() []SystemBackend {
}
func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
// Gather backends from system and user paths, then resolve alias conflicts by capability.
potentialBackends, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return nil, err
}
backends := make(SystemBackends)
// System-provided backends
if systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath); err == nil {
systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath)
if err == nil {
// system backends are special, they are provided by the system and not managed by LocalAI
for _, systemBackend := range systemBackends {
if systemBackend.IsDir() {
run := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
if _, err := os.Stat(run); err == nil {
systemBackendRunFile := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
if _, err := os.Stat(systemBackendRunFile); err == nil {
backends[systemBackend.Name()] = SystemBackend{
Name: systemBackend.Name(),
RunFile: run,
RunFile: filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile),
IsMeta: false,
IsSystem: true,
Metadata: nil,
@@ -311,103 +307,63 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
}
}
} else {
log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
log.Warn().Err(err).Msg("Failed to read system backends, but that's ok, we will just use the backends managed by LocalAI")
}
// User-managed backends and alias collection
entries, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return nil, err
}
for _, potentialBackend := range potentialBackends {
if potentialBackend.IsDir() {
potentialBackendRunFile := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), runFile)
aliasGroups := make(map[string][]backendCandidate)
metaMap := make(map[string]*BackendMetadata)
var metadata *BackendMetadata
for _, e := range entries {
if !e.IsDir() {
continue
}
dir := e.Name()
run := filepath.Join(systemState.Backend.BackendsPath, dir, runFile)
var metadata *BackendMetadata
metadataPath := filepath.Join(systemState.Backend.BackendsPath, dir, metadataFile)
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
metadata = &BackendMetadata{Name: dir}
} else {
m, rerr := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, dir))
if rerr != nil {
return nil, rerr
}
if m == nil {
metadata = &BackendMetadata{Name: dir}
// If metadata file does not exist, we just use the directory name
// and we do not fill the other metadata (such as potential backend Aliases)
metadataFilePath := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), metadataFile)
if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
metadata = &BackendMetadata{
Name: potentialBackend.Name(),
}
} else {
metadata = m
}
}
metaMap[dir] = metadata
// Concrete backend entry
if _, err := os.Stat(run); err == nil {
backends[dir] = SystemBackend{
Name: dir,
RunFile: run,
IsMeta: false,
Metadata: metadata,
}
}
// Alias candidates
if metadata.Alias != "" {
aliasGroups[metadata.Alias] = append(aliasGroups[metadata.Alias], backendCandidate{name: dir, runFile: run})
}
// Meta backends indirection
if metadata.MetaBackendFor != "" {
backends[metadata.Name] = SystemBackend{
Name: metadata.Name,
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
IsMeta: true,
Metadata: metadata,
}
}
}
// Resolve aliases using system capability preferences
tokens := systemState.BackendPreferenceTokens()
for alias, cands := range aliasGroups {
chosen := backendCandidate{}
// Try preference tokens
for _, t := range tokens {
for _, c := range cands {
if strings.Contains(strings.ToLower(c.name), t) && c.runFile != "" {
chosen = c
break
// Check for alias in metadata
metadata, err = readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name()))
if err != nil {
return nil, err
}
}
if chosen.runFile != "" {
break
}
}
// Fallback: first runnable
if chosen.runFile == "" {
for _, c := range cands {
if c.runFile != "" {
chosen = c
break
if !backends.Exists(potentialBackend.Name()) {
// We don't want to override aliases if already set, and if we are meta backend
if _, err := os.Stat(potentialBackendRunFile); err == nil {
backends[potentialBackend.Name()] = SystemBackend{
Name: potentialBackend.Name(),
RunFile: potentialBackendRunFile,
IsMeta: false,
Metadata: metadata,
}
}
}
if metadata == nil {
continue
}
if metadata.Alias != "" {
backends[metadata.Alias] = SystemBackend{
Name: metadata.Alias,
RunFile: potentialBackendRunFile,
IsMeta: false,
Metadata: metadata,
}
}
if metadata.MetaBackendFor != "" {
backends[metadata.Name] = SystemBackend{
Name: metadata.Name,
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
IsMeta: true,
Metadata: metadata,
}
}
}
if chosen.runFile == "" {
continue
}
md := metaMap[chosen.name]
backends[alias] = SystemBackend{
Name: alias,
RunFile: chosen.runFile,
IsMeta: false,
Metadata: md,
}
}

View File

@@ -7,7 +7,6 @@ import (
"runtime"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -18,79 +17,10 @@ const (
testImage = "quay.io/mudler/tests:localai-backend-test"
)
var _ = Describe("Runtime capability-based backend selection", func() {
var tempDir string
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "gallery-caps-*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
It("ListSystemBackends prefers optimal alias candidate", func() {
// Arrange two installed backends sharing the same alias
must := func(err error) { Expect(err).NotTo(HaveOccurred()) }
cpuDir := filepath.Join(tempDir, "cpu-llama-cpp")
must(os.MkdirAll(cpuDir, 0o750))
cpuMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cpu-llama-cpp"}
b, _ := json.Marshal(cpuMeta)
must(os.WriteFile(filepath.Join(cpuDir, "metadata.json"), b, 0o644))
must(os.WriteFile(filepath.Join(cpuDir, "run.sh"), []byte(""), 0o755))
cudaDir := filepath.Join(tempDir, "cuda12-llama-cpp")
must(os.MkdirAll(cudaDir, 0o750))
cudaMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cuda12-llama-cpp"}
b, _ = json.Marshal(cudaMeta)
must(os.WriteFile(filepath.Join(cudaDir, "metadata.json"), b, 0o644))
must(os.WriteFile(filepath.Join(cudaDir, "run.sh"), []byte(""), 0o755))
// Default system: alias should point to CPU
sysDefault, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expect(aliasBack.RunFile).To(Equal(filepath.Join(cpuDir, "run.sh")))
// concrete entries remain
_, ok = backs.Get("cpu-llama-cpp")
Expect(ok).To(BeTrue())
_, ok = backs.Get("cuda12-llama-cpp")
Expect(ok).To(BeTrue())
// NVIDIA system: alias should point to CUDA
// Force capability to nvidia to make the test deterministic on platforms like darwin/arm64 (which default to metal)
os.Setenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY", "nvidia")
defer os.Unsetenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY")
sysNvidia, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expect(aliasBack.RunFile).To(Equal(filepath.Join(cudaDir, "run.sh")))
})
})
var _ = Describe("Gallery Backends", func() {
var (
tempDir string
galleries []config.Gallery
ml *model.ModelLoader
systemState *system.SystemState
tempDir string
galleries []config.Gallery
)
BeforeEach(func() {
@@ -105,9 +35,6 @@ var _ = Describe("Gallery Backends", func() {
URL: "https://gist.githubusercontent.com/mudler/71d5376bc2aa168873fa519fa9f4bd56/raw/0557f9c640c159fa8e4eab29e8d98df6a3d6e80f/backend-gallery.yaml",
},
}
systemState, err = system.GetSystemState(system.WithBackendPath(tempDir))
Expect(err).NotTo(HaveOccurred())
ml = model.NewModelLoader(systemState, true)
})
AfterEach(func() {
@@ -116,13 +43,21 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -298,7 +233,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +313,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +397,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +496,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -593,7 +528,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +561,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +582,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
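Across these hunks the *model.ModelLoader argument is dropped from both InstallBackendFromGallery and InstallBackend, so a caller now supplies only the galleries, the system state, and the backend to install. A minimal sketch of the new call shape, assuming the functions live in the core/gallery package as in these tests, with a placeholder backends path and gallery URL:

package main

import (
	"log"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/system"
)

func main() {
	// Placeholder gallery URL and backends path; adjust for a real setup.
	galleries := []config.Gallery{{URL: "https://example.com/backend-gallery.yaml"}}

	systemState, err := system.GetSystemState(system.WithBackendPath("/tmp/backends"))
	if err != nil {
		log.Fatal(err)
	}

	// Updated shape: the *model.ModelLoader argument is gone.
	// The trailing flag mirrors the `true` used by the tests above.
	if err := gallery.InstallBackendFromGallery(galleries, systemState, "llama-cpp", nil, true); err != nil {
		log.Fatal(err)
	}
}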

View File

@@ -11,7 +11,6 @@ import (
"github.com/mudler/LocalAI/core/config"
lconfig "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/utils"
@@ -74,7 +73,6 @@ type PromptTemplate struct {
func InstallModelFromGallery(
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
applyModel := func(model *GalleryModel) error {
@@ -133,7 +131,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
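The same simplification applies to the model path: InstallModelFromGallery no longer threads a model loader through, and when automaticallyInstallBackend is true it forwards straight to InstallBackendFromGallery with just the system state. A minimal sketch of a call against the updated signature, with placeholder galleries and a placeholder progress callback (the core/gallery import path and the Gallery fields used are assumptions):

package main

import (
	"log"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/system"
)

func main() {
	// Placeholder galleries; URLs are illustrative only.
	modelGalleries := []config.Gallery{{URL: "https://example.com/model-gallery.yaml"}}
	backendGalleries := []config.Gallery{{URL: "https://example.com/backend-gallery.yaml"}}

	systemState, err := system.GetSystemState(system.WithBackendPath("/tmp/backends"))
	if err != nil {
		log.Fatal(err)
	}

	// Download progress callback, matching the func(string, string, string, float64) shape above.
	progress := func(fileName, current, total string, percent float64) {
		log.Printf("%s: %s/%s (%.1f%%)", fileName, current, total, percent)
	}

	// Updated signature: no *model.ModelLoader; with the last flag set, the model's
	// declared backend is installed via InstallBackendFromGallery.
	if err := gallery.InstallModelFromGallery(
		modelGalleries, backendGalleries, systemState,
		"test@bert", gallery.GalleryModel{},
		progress,
		true, // enforceScan
		true, // automaticallyInstallBackend
	); err != nil {
		log.Fatal(err)
	}
}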

View File

@@ -88,7 +88,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))

View File

@@ -836,40 +836,27 @@ var _ = Describe("API test", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
embeddingModel := openai.AdaEmbeddingV2
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: embeddingModel,
Model: openai.AdaEmbeddingV2,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred(), err)
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096))
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 2048))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 2048))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: embeddingModel,
Model: openai.AdaEmbeddingV2,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
resp3, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: embeddingModel,
Input: []string{"cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding))
Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding))
})
Context("External gRPC calls", func() {

View File

@@ -70,24 +70,6 @@ func infoButton(m *gallery.GalleryModel) elem.Node {
)
}
func getConfigButton(galleryName string) elem.Node {
return elem.Button(
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right ml-2 inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
"hx-swap": "outerHTML",
"hx-post": "browse/config/model/" + galleryName,
},
elem.I(
attrs.Props{
"class": "fa-solid fa-download pr-2",
},
),
elem.Text("Get Config"),
)
}
func deleteButton(galleryID string) elem.Node {
return elem.Button(
attrs.Props{

View File

@@ -339,12 +339,7 @@ func modelActionItems(m *gallery.GalleryModel, processTracker ProcessTracker, ga
reInstallButton(m.ID()),
deleteButton(m.ID()),
)),
// otherwise, show the install button, and the get config button
elem.Node(elem.Div(
attrs.Props{},
getConfigButton(m.ID()),
installButton(m.ID()),
)),
installButton(m.ID()),
),
),
),

View File

@@ -226,33 +226,3 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
return c.JSON(response)
}
}
// ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
}
// Preload the models
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to preload models: " + err.Error(),
}
return c.Status(500).JSON(response)
}
// Return success response
response := ModelResponse{
Success: true,
Message: "Model configurations reloaded successfully",
}
return c.Status(fiber.StatusOK).JSON(response)
}
}

Some files were not shown because too many files have changed in this diff.