Compare commits


2 Commits

Author: Ettore Di Giacinto
Commit: 478d2adfb7 Trigger CI
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Date: 2025-08-24 16:48:33 +02:00

Author: Ettore Di Giacinto
Commit: 909fdd1b0e feat(transformers): add support for CPU and MPS
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Date: 2025-08-24 16:47:51 +02:00
128 changed files with 941 additions and 6863 deletions

View File

@@ -6,10 +6,6 @@ models
backends
examples/chatbot-ui/models
backend/go/image/stablediffusion-ggml/build/
backend/go/*/build
backend/go/*/.cache
backend/go/*/sources
backend/go/*/package
examples/rwkv/models
examples/**/models
Dockerfile*

View File

@@ -2,6 +2,7 @@
name: 'build backend container images'
on:
pull_request:
push:
branches:
- master
@@ -63,6 +64,18 @@ jobs:
backend: "llama-cpp"
dockerfile: "./backend/Dockerfile.llama-cpp"
context: "./"
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-transformers'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'true'
backend: "transformers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "11"
cuda-minor-version: "7"
@@ -89,7 +102,7 @@ jobs:
context: "./backend"
- build-type: 'l4t'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-l4t-diffusers'
@@ -111,18 +124,6 @@ jobs:
backend: "diffusers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-chatterbox'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'true'
backend: "chatterbox"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
# CUDA 11 additional backends
- build-type: 'cublas'
cuda-major-version: "11"
@@ -187,7 +188,7 @@ jobs:
# CUDA 12 builds
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-rerankers'
@@ -199,7 +200,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-llama-cpp'
@@ -211,7 +212,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-vllm'
@@ -223,7 +224,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-transformers'
@@ -235,20 +236,20 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-diffusers'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
skip-drivers: 'false'
backend: "diffusers"
backend: "diffusers"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
# CUDA 12 additional backends
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-kokoro'
@@ -260,7 +261,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-faster-whisper'
@@ -272,7 +273,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-coqui'
@@ -284,7 +285,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-bark'
@@ -296,7 +297,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-chatterbox'
@@ -578,7 +579,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'true'
tag-latest: 'auto'
@@ -615,7 +616,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-stablediffusion-ggml'
@@ -675,7 +676,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'true'
tag-latest: 'auto'
@@ -700,7 +701,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-whisper'
@@ -760,7 +761,7 @@ jobs:
context: "./"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'true'
tag-latest: 'auto'
@@ -836,7 +837,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-rfdetr'
@@ -872,7 +873,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'true'
tag-latest: 'auto'
@@ -897,7 +898,7 @@ jobs:
context: "./backend"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-exllama2'
@@ -969,41 +970,54 @@ jobs:
backend: "kitten-tts"
dockerfile: "./backend/Dockerfile.python"
context: "./backend"
backend-jobs-darwin:
transformers-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
strategy:
matrix:
include:
- backend: "diffusers"
tag-suffix: "-metal-darwin-arm64-diffusers"
build-type: "mps"
- backend: "mlx"
tag-suffix: "-metal-darwin-arm64-mlx"
build-type: "mps"
- backend: "chatterbox"
tag-suffix: "-metal-darwin-arm64-chatterbox"
build-type: "mps"
- backend: "mlx-vlm"
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
build-type: "mps"
- backend: "mlx-audio"
tag-suffix: "-metal-darwin-arm64-mlx-audio"
build-type: "mps"
- backend: "stablediffusion-ggml"
tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
build-type: "metal"
lang: "go"
- backend: "whisper"
tag-suffix: "-metal-darwin-arm64-whisper"
build-type: "metal"
lang: "go"
with:
backend: ${{ matrix.backend }}
build-type: ${{ matrix.build-type }}
backend: "transformers"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: ${{ matrix.tag-suffix }}
lang: ${{ matrix.lang || 'python' }}
use-pip: ${{ matrix.backend == 'diffusers' }}
tag-suffix: "-metal-darwin-arm64-transformers"
use-pip: true
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
diffusers-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "diffusers"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-diffusers"
use-pip: true
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "mlx"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-mlx"
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-vlm-darwin:
uses: ./.github/workflows/backend_build_darwin.yml
with:
backend: "mlx-vlm"
build-type: "mps"
go-version: "1.24.x"
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
runs-on: "macOS-14"
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}

View File

@@ -16,10 +16,6 @@ on:
description: 'Use pip to install dependencies'
default: false
type: boolean
lang:
description: 'Programming language (e.g. go)'
default: 'python'
type: string
go-version:
description: 'Go version to use'
default: '1.24.x'
@@ -53,26 +49,26 @@ jobs:
uses: actions/checkout@v5
with:
submodules: true
- name: Setup Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: false
# You can test your matrix by printing the current Go version
- name: Display Go version
run: go version
- name: Dependencies
run: |
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
- name: Build ${{ inputs.backend }}-darwin
run: |
make protogen-go
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-python-backend
- name: Upload ${{ inputs.backend }}.tar
uses: actions/upload-artifact@v4
with:
@@ -89,20 +85,20 @@ jobs:
with:
name: ${{ inputs.backend }}-tar
path: .
- name: Install crane
run: |
curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
sudo mv crane /usr/local/bin/
- name: Log in to DockerHub
run: |
echo "${{ secrets.dockerPassword }}" | crane auth login docker.io -u "${{ secrets.dockerUsername }}" --password-stdin
- name: Log in to quay.io
run: |
echo "${{ secrets.quayPassword }}" | crane auth login quay.io -u "${{ secrets.quayUsername }}" --password-stdin
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -116,7 +112,7 @@ jobs:
flavor: |
latest=auto
suffix=${{ inputs.tag-suffix }},onlatest=true
- name: Docker meta
id: quaymeta
uses: docker/metadata-action@v5
@@ -130,13 +126,13 @@ jobs:
flavor: |
latest=auto
suffix=${{ inputs.tag-suffix }},onlatest=true
- name: Push Docker image (DockerHub)
run: |
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do
crane push ${{ inputs.backend }}.tar $tag
done
- name: Push Docker image (Quay)
run: |
for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do

View File

@@ -12,9 +12,7 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }}
has-backends: ${{ steps.set-matrix.outputs.has-backends }}
has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
steps:
- name: Checkout repository
uses: actions/checkout@v5
@@ -58,21 +56,3 @@ jobs:
strategy:
fail-fast: true
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
backend-jobs-darwin:
needs: generate-matrix
uses: ./.github/workflows/backend_build_darwin.yml
if: needs.generate-matrix.outputs.has-backends-darwin == 'true'
with:
backend: ${{ matrix.backend }}
build-type: ${{ matrix.build-type }}
go-version: "1.24.x"
tag-suffix: ${{ matrix.tag-suffix }}
lang: ${{ matrix.lang || 'python' }}
use-pip: ${{ matrix.backend == 'diffusers' }}
runs-on: "macOS-14"
secrets:
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
strategy:
fail-fast: true
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix-darwin) }}

View File

@@ -21,47 +21,3 @@ jobs:
- name: Run GoReleaser
run: |
make dev-dist
launcher-build-darwin:
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for macOS ARM64
run: |
make build-launcher-darwin
ls -liah dist
- name: Upload macOS launcher artifacts
uses: actions/upload-artifact@v4
with:
name: launcher-macos
path: dist/
retention-days: 30
launcher-build-linux:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for Linux
run: |
sudo apt-get update
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
make build-launcher-linux
- name: Upload Linux launcher artifacts
uses: actions/upload-artifact@v4
with:
name: launcher-linux
path: local-ai-launcher-linux.tar.xz
retention-days: 30

View File

@@ -36,7 +36,7 @@ jobs:
include:
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-gpu-nvidia-cuda-12'

View File

@@ -91,7 +91,7 @@ jobs:
aio: "-aio-gpu-nvidia-cuda-11"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12'
@@ -144,7 +144,7 @@ jobs:
include:
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
cuda-minor-version: "0"
platforms: 'linux/arm64'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-arm64'

View File

@@ -9,4 +9,4 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@v5

View File

@@ -23,42 +23,4 @@ jobs:
version: v2.11.0
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
launcher-build-darwin:
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for macOS ARM64
run: |
make build-launcher-darwin
- name: Upload DMG to Release
uses: softprops/action-gh-release@v2
with:
files: ./dist/LocalAI.dmg
launcher-build-linux:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: 1.23
- name: Build launcher for Linux
run: |
sudo apt-get update
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
make build-launcher-linux
- name: Upload Linux launcher artifacts
uses: softprops/action-gh-release@v2
with:
files: ./local-ai-launcher-linux.tar.xz
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -10,7 +10,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v9
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9
with:
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'

.gitignore (vendored), 2 lines changed
View File

@@ -24,7 +24,7 @@ go-bert
# LocalAI build binary
LocalAI
/local-ai
local-ai
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder

View File

@@ -8,7 +8,7 @@ source:
enabled: true
name_template: '{{ .ProjectName }}-{{ .Tag }}-source'
builds:
- main: ./cmd/local-ai
-
env:
- CGO_ENABLED=0
ldflags:

View File

@@ -18,7 +18,7 @@ FROM requirements AS requirements-drivers
ARG BUILD_TYPE
ARG CUDA_MAJOR_VERSION=12
ARG CUDA_MINOR_VERSION=8
ARG CUDA_MINOR_VERSION=0
ARG SKIP_DRIVERS=false
ARG TARGETARCH
ARG TARGETVARIANT
@@ -100,10 +100,6 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
ldconfig \
; fi
RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
; fi
RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
# Cuda

View File

@@ -2,7 +2,6 @@ GOCMD=go
GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
LAUNCHER_BINARY_NAME=local-ai-launcher
GORELEASER?=
@@ -91,17 +90,7 @@ build: protogen-go install-go-tools ## Build the project
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
$(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET})
rm -rf $(BINARY_NAME) || true
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai
build-launcher: ## Build the launcher application
$(info ${GREEN}I local-ai launcher build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
rm -rf $(LAUNCHER_BINARY_NAME) || true
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(LAUNCHER_BINARY_NAME) ./cmd/launcher
build-all: build build-launcher ## Build both server and launcher
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
dev-dist:
$(GORELEASER) build --snapshot --clean
@@ -170,7 +159,7 @@ prepare-e2e:
mkdir -p $(TEST_DIR)
cp -rfv $(abspath ./tests/e2e-fixtures)/gpu.yaml $(TEST_DIR)/gpu.yaml
test -e $(TEST_DIR)/ggllm-test-model.bin || wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O $(TEST_DIR)/ggllm-test-model.bin
docker build --build-arg IMAGE_TYPE=core --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg CUDA_MAJOR_VERSION=12 --build-arg CUDA_MINOR_VERSION=8 -t localai-tests .
docker build --build-arg IMAGE_TYPE=core --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg CUDA_MAJOR_VERSION=12 --build-arg CUDA_MINOR_VERSION=0 -t localai-tests .
run-e2e-image:
ls -liah $(abspath ./tests/e2e-fixtures)
@@ -369,9 +358,6 @@ backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
backends/kokoro: docker-build-kokoro docker-save-kokoro build
./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
backends/chatterbox: docker-build-chatterbox docker-save-chatterbox build
./local-ai backends install "ocifile://$(abspath ./backend-images/chatterbox.tar)"
backends/llama-cpp-darwin: build
bash ./scripts/build/llama-cpp-darwin.sh
./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
@@ -379,9 +365,6 @@ backends/llama-cpp-darwin: build
build-darwin-python-backend: build
bash ./scripts/build/python-darwin.sh
build-darwin-go-backend: build
bash ./scripts/build/golang-darwin.sh
backends/mlx:
BACKEND=mlx $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)"
@@ -394,14 +377,6 @@ backends/mlx-vlm:
BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"
backends/mlx-audio:
BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
backends/stablediffusion-ggml-darwin:
BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"
backend-images:
mkdir -p backend-images
@@ -496,7 +471,7 @@ docker-build-bark:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .
docker-build-chatterbox:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox ./backend
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .
docker-build-exllama2:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .
@@ -532,19 +507,3 @@ docs-clean:
.PHONY: docs
docs: docs/static/gallery.html
cd docs && hugo serve
########################################################
## Platform-specific builds
########################################################
## fyne cross-platform build
build-launcher-darwin: build-launcher
go run github.com/tiagomelo/macos-dmg-creator/cmd/createdmg@latest \
--appName "LocalAI" \
--appBinaryPath "$(LAUNCHER_BINARY_NAME)" \
--bundleIdentifier "com.localai.launcher" \
--iconPath "core/http/static/logo.png" \
--outputDir "dist/"
build-launcher-linux:
cd cmd/launcher && go run fyne.io/tools/cmd/fyne@latest package -os linux -icon ../../core/http/static/logo.png --executable $(LAUNCHER_BINARY_NAME)-linux && mv launcher.tar.xz ../../$(LAUNCHER_BINARY_NAME)-linux.tar.xz

View File

@@ -110,12 +110,6 @@ curl https://localai.io/install.sh | sh
For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).
### macOS Download:
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a>
Or run with docker:
### CPU only image:
@@ -239,60 +233,6 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
- 🔊 Voice activity detection (Silero-VAD support)
- 🌍 Integrated WebUI!
## 🧩 Supported Backends & Acceleration
LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
### Text Generation & Language Models
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **llama.cpp** | LLM inference in C/C++ | CUDA 11/12, ROCm, Intel SYCL, Vulkan, Metal, CPU |
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12, ROCm, Intel |
| **transformers** | HuggingFace transformers framework | CUDA 11/12, ROCm, Intel, CPU |
| **exllama2** | GPTQ inference library | CUDA 12 |
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
### Audio & Speech Processing
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12, ROCm, Intel SYCL, Vulkan, CPU |
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12, ROCm, Intel, CPU |
| **bark** | Text-to-audio generation | CUDA 12, ROCm, Intel |
| **bark-cpp** | C++ implementation of Bark | CUDA, Metal, CPU |
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12, ROCm, Intel, CPU |
| **kokoro** | Lightweight TTS model | CUDA 12, ROCm, Intel, CPU |
| **chatterbox** | Production-grade TTS | CUDA 11/12, CPU |
| **piper** | Fast neural TTS system | CPU |
| **kitten-tts** | Kitten TTS models | CPU |
| **silero-vad** | Voice Activity Detection | CPU |
### Image & Video Generation
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12, Intel SYCL, Vulkan, CPU |
| **diffusers** | HuggingFace diffusion models | CUDA 11/12, ROCm, Intel, Metal, CPU |
### Specialized AI Tasks
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **rfdetr** | Real-time object detection | CUDA 12, Intel, CPU |
| **rerankers** | Document reranking API | CUDA 11/12, ROCm, Intel, CPU |
| **local-store** | Vector database | CPU |
| **huggingface** | HuggingFace API integration | API-based |
### Hardware Acceleration Matrix
| Acceleration Type | Supported Backends | Hardware Support |
|-------------------|-------------------|------------------|
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark | AMD Graphics |
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
| **NVIDIA Jetson** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI |
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
### 🔗 Community and integrations
@@ -307,9 +247,6 @@ WebUIs:
Model galleries
- https://github.com/go-skynet/model-gallery
Voice:
- https://github.com/richiejp/VoxInput
Other:
- Helm chart https://github.com/go-skynet/helm-charts
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin

View File

@@ -1,213 +0,0 @@
# LocalAI Backend Architecture
This directory contains the core backend infrastructure for LocalAI, including the gRPC protocol definition, multi-language Dockerfiles, and language-specific backend implementations.
## Overview
LocalAI uses a unified gRPC-based architecture that allows different programming languages to implement AI backends while maintaining consistent interfaces and capabilities. The backend system supports multiple hardware acceleration targets and provides a standardized way to integrate various AI models and frameworks.
## Architecture Components
### 1. Protocol Definition (`backend.proto`)
The `backend.proto` file defines the gRPC service interface that all backends must implement. This ensures consistency across different language implementations and provides a contract for communication between LocalAI core and backend services.
#### Core Services
- **Text Generation**: `Predict`, `PredictStream` for LLM inference
- **Embeddings**: `Embedding` for text vectorization
- **Image Generation**: `GenerateImage` for stable diffusion and image models
- **Audio Processing**: `AudioTranscription`, `TTS`, `SoundGeneration`
- **Video Generation**: `GenerateVideo` for video synthesis
- **Object Detection**: `Detect` for computer vision tasks
- **Vector Storage**: `StoresSet`, `StoresGet`, `StoresFind` for RAG operations
- **Reranking**: `Rerank` for document relevance scoring
- **Voice Activity Detection**: `VAD` for audio segmentation
#### Key Message Types
- **`PredictOptions`**: Comprehensive configuration for text generation
- **`ModelOptions`**: Model loading and configuration parameters
- **`Result`**: Standardized response format
- **`StatusResponse`**: Backend health and memory usage information
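To make these message types concrete, here is a minimal Go sketch of how they might be populated. Only `ModelFile`, `Options`, and `Prompt` are visible elsewhere in this diff; treat everything else as illustrative and check `backend.proto` for the authoritative field list.
```go
package main

import (
	"fmt"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

func main() {
	// ModelOptions drives LoadModel: which weights to load and backend-specific flags.
	// ModelFile and Options are used by the whisper backend shown later in this diff.
	load := &pb.ModelOptions{
		ModelFile: "/models/ggml-base.en.bin", // hypothetical path
		Options:   []string{"vad_only"},       // free-form per-backend options
	}

	// PredictOptions drives Predict/PredictStream: the actual inference request.
	req := &pb.PredictOptions{
		Prompt: "Explain what a gRPC backend is in one sentence.",
	}

	fmt.Println(load.ModelFile, req.Prompt)
}
```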
### 2. Multi-Language Dockerfiles
The backend system provides language-specific Dockerfiles that handle the build environment and dependencies for different programming languages:
- `Dockerfile.python`
- `Dockerfile.golang`
- `Dockerfile.llama-cpp`
### 3. Language-Specific Implementations
#### Python Backends (`python/`)
- **transformers**: Hugging Face Transformers framework
- **vllm**: High-performance LLM inference
- **mlx**: Apple Silicon optimization
- **diffusers**: Stable Diffusion models
- **Audio**: bark, coqui, faster-whisper, kitten-tts
- **Vision**: mlx-vlm, rfdetr
- **Specialized**: rerankers, chatterbox, kokoro
#### Go Backends (`go/`)
- **whisper**: OpenAI Whisper speech recognition in Go, backed by whisper.cpp (GGML)
- **stablediffusion-ggml**: Stable Diffusion in Go, backed by stable-diffusion.cpp (GGML)
- **huggingface**: Hugging Face model integration
- **piper**: Text-to-speech synthesis in Go with C bindings to rhasspy/piper
- **bark-cpp**: Bark TTS models in Go with C++ bindings
- **local-store**: Vector storage backend
#### C++ Backends (`cpp/`)
- **llama-cpp**: Llama.cpp integration
- **grpc**: GRPC utilities and helpers
## Hardware Acceleration Support
### CUDA (NVIDIA)
- **Versions**: CUDA 11.x, 12.x
- **Features**: cuBLAS, cuDNN, TensorRT optimization
- **Targets**: x86_64, ARM64 (Jetson)
### ROCm (AMD)
- **Features**: HIP, rocBLAS, MIOpen
- **Targets**: AMD GPUs with ROCm support
### Intel
- **Features**: oneAPI, Intel Extension for PyTorch
- **Targets**: Intel GPUs, XPUs, CPUs
### Vulkan
- **Features**: Cross-platform GPU acceleration
- **Targets**: Windows, Linux, Android, macOS
### Apple Silicon
- **Features**: MLX framework, Metal Performance Shaders
- **Targets**: M1/M2/M3 Macs
## Backend Registry (`index.yaml`)
The `index.yaml` file serves as a central registry for all available backends, providing:
- **Metadata**: Name, description, license, icons
- **Capabilities**: Hardware targets and optimization profiles
- **Tags**: Categorization for discovery
- **URLs**: Source code and documentation links
## Building Backends
### Prerequisites
- Docker with multi-architecture support
- Appropriate hardware drivers (CUDA, ROCm, etc.)
- Build tools (make, cmake, compilers)
### Build Commands
Example of build commands with Docker
```bash
# Build Python backend
docker build -f backend/Dockerfile.python \
--build-arg BACKEND=transformers \
--build-arg BUILD_TYPE=cublas12 \
--build-arg CUDA_MAJOR_VERSION=12 \
--build-arg CUDA_MINOR_VERSION=8 \
-t localai-backend-transformers .
# Build Go backend
docker build -f backend/Dockerfile.golang \
--build-arg BACKEND=whisper \
--build-arg BUILD_TYPE=cpu \
-t localai-backend-whisper .
# Build C++ backend
docker build -f backend/Dockerfile.llama-cpp \
--build-arg BACKEND=llama-cpp \
--build-arg BUILD_TYPE=cublas12 \
-t localai-backend-llama-cpp .
```
For ARM64/macOS builds, Docker cannot be used; build with the Makefile in the respective backend instead.
### Build Types
- **`cpu`**: CPU-only optimization
- **`cublas11`**: CUDA 11.x with cuBLAS
- **`cublas12`**: CUDA 12.x with cuBLAS
- **`hipblas`**: ROCm with rocBLAS
- **`intel`**: Intel oneAPI optimization
- **`vulkan`**: Vulkan-based acceleration
- **`metal`**: Apple Metal optimization
## Backend Development
### Creating a New Backend
1. **Choose Language**: Select Python, Go, or C++ based on requirements
2. **Implement Interface**: Implement the gRPC service defined in `backend.proto`
3. **Add Dependencies**: Create appropriate requirements files
4. **Configure Build**: Set up Dockerfile and build scripts
5. **Register Backend**: Add entry to `index.yaml`
6. **Test Integration**: Verify gRPC communication and functionality
### Backend Structure
```
backend-name/
├── backend.py/go/cpp # Main implementation
├── requirements.txt # Dependencies
├── Dockerfile # Build configuration
├── install.sh # Installation script
├── run.sh # Execution script
├── test.sh # Test script
└── README.md # Backend documentation
```
### Required gRPC Methods
At minimum, backends must implement:
- `Health()` - Service health check
- `LoadModel()` - Model loading and initialization
- `Predict()` - Main inference endpoint
- `Status()` - Backend status and metrics
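A minimal Go skeleton in the style of the whisper backend that appears later in this diff: it embeds `base.SingleThread` (which appears to supply default implementations for the rest of the interface) and overrides only what it needs. The `Predict` signature shown here is an assumption for illustration; `backend.proto` and the `pkg/grpc` interfaces remain the authoritative contract.
```go
package main

import (
	"flag"
	"fmt"

	grpc "github.com/mudler/LocalAI/pkg/grpc"
	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

var addr = flag.String("addr", "localhost:50051", "the address to listen on")

// MyBackend is a hypothetical backend. base.SingleThread provides the shared
// plumbing (health, locking, defaults), as in the whisper backend below.
type MyBackend struct {
	base.SingleThread
	modelPath string
}

// Load handles the LoadModel RPC with the ModelOptions defined in backend.proto.
func (b *MyBackend) Load(opts *pb.ModelOptions) error {
	if opts.ModelFile == "" {
		return fmt.Errorf("no model file provided")
	}
	b.modelPath = opts.ModelFile // load the real model here
	return nil
}

// Predict handles the main inference endpoint. The return type here is an
// assumption; check the pkg/grpc interface for the exact signature.
func (b *MyBackend) Predict(opts *pb.PredictOptions) (string, error) {
	return "echo: " + opts.Prompt, nil
}

func main() {
	flag.Parse()
	if err := grpc.StartServer(*addr, &MyBackend{}); err != nil {
		panic(err)
	}
}
```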
## Integration with LocalAI Core
Backends communicate with LocalAI core through gRPC:
1. **Service Discovery**: Core discovers available backends
2. **Model Loading**: Core requests model loading via `LoadModel`
3. **Inference**: Core sends requests via `Predict` or specialized endpoints
4. **Streaming**: Core handles streaming responses for real-time generation
5. **Monitoring**: Core tracks backend health and performance
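A rough sketch of that call sequence from the core side, assuming the generated Go client is exposed as `pb.NewBackendClient`; service discovery, streaming (`PredictStream`), and health monitoring are omitted.
```go
package main

import (
	"context"
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// 1. Service discovery is simplified to a fixed address here.
	conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn) // assumed name of the generated client
	ctx := context.Background()

	// 2. Model loading via LoadModel.
	if _, err := client.LoadModel(ctx, &pb.ModelOptions{ModelFile: "/models/model.bin"}); err != nil {
		log.Fatalf("LoadModel failed: %v", err)
	}

	// 3. Inference via Predict (PredictStream would be used for streaming).
	reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "Hello"})
	if err != nil {
		log.Fatalf("Predict failed: %v", err)
	}
	log.Printf("reply: %v", reply)
}
```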
## Performance Optimization
### Memory Management
- **Model Caching**: Efficient model loading and caching
- **Batch Processing**: Optimize for multiple concurrent requests
- **Memory Pinning**: GPU memory optimization for CUDA/ROCm
### Hardware Utilization
- **Multi-GPU**: Support for tensor parallelism
- **Mixed Precision**: FP16/BF16 for memory efficiency
- **Kernel Fusion**: Optimized CUDA/ROCm kernels
## Troubleshooting
### Common Issues
1. **GRPC Connection**: Verify backend service is running and accessible
2. **Model Loading**: Check model paths and dependencies
3. **Hardware Detection**: Ensure appropriate drivers and libraries
4. **Memory Issues**: Monitor GPU memory usage and model sizes
## Contributing
When contributing to the backend system:
1. **Follow Protocol**: Implement the exact gRPC interface
2. **Add Tests**: Include comprehensive test coverage
3. **Document**: Provide clear usage examples
4. **Optimize**: Consider performance and resource usage
5. **Validate**: Test across different hardware targets

View File

@@ -242,7 +242,7 @@ message ModelOptions {
string Type = 49;
string FlashAttention = 56;
bool FlashAttention = 56;
bool NoKVOffload = 57;
string ModelPath = 59;
@@ -276,7 +276,6 @@ message TranscriptRequest {
string language = 3;
uint32 threads = 4;
bool translate = 5;
bool diarize = 6;
}
message TranscriptResult {
@@ -306,24 +305,22 @@ message GenerateImageRequest {
// Diffusers
string EnableParameters = 10;
int32 CLIPSkip = 11;
// Reference images for models that support them (e.g., Flux Kontext)
repeated string ref_images = 12;
}
message GenerateVideoRequest {
string prompt = 1;
string negative_prompt = 2; // Negative prompt for video generation
string start_image = 3; // Path or base64 encoded image for the start frame
string end_image = 4; // Path or base64 encoded image for the end frame
int32 width = 5;
int32 height = 6;
int32 num_frames = 7; // Number of frames to generate
int32 fps = 8; // Frames per second
int32 seed = 9;
float cfg_scale = 10; // Classifier-free guidance scale
int32 step = 11; // Number of inference steps
string dst = 12; // Output path for the generated video
string start_image = 2; // Path or base64 encoded image for the start frame
string end_image = 3; // Path or base64 encoded image for the end frame
int32 width = 4;
int32 height = 5;
int32 num_frames = 6; // Number of frames to generate
int32 fps = 7; // Frames per second
int32 seed = 8;
float cfg_scale = 9; // Classifier-free guidance scale
string dst = 10; // Output path for the generated video
}
message TTSRequest {

View File

@@ -1,6 +1,6 @@
LLAMA_VERSION?=fe4eb4f8ec25a1239b0923f1c7f87adf5730c3e5
LLAMA_REPO?=https://github.com/JohannesGaessler/llama.cpp
LLAMA_VERSION?=710dfc465a68f7443b87d9f792cffba00ed739fe
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=
BUILD_TYPE?=

View File

@@ -304,15 +304,7 @@ static void params_parse(const backend::ModelOptions* request,
}
params.use_mlock = request->mlock();
params.use_mmap = request->mmap();
if (request->flashattention() == "on" || request->flashattention() == "enabled") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
} else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (request->flashattention() == "auto") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
}
params.flash_attn = request->flashattention();
params.no_kv_offload = request->nokvoffload();
params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)

View File

@@ -1,6 +0,0 @@
package/
sources/
.cache/
build/
libgosd.so
stablediffusion-ggml

View File

@@ -5,11 +5,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory(./sources/stablediffusion-ggml.cpp)
add_library(gosd MODULE gosd.cpp)
target_link_libraries(gosd PRIVATE stable-diffusion ggml)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
target_link_libraries(gosd PRIVATE stdc++fs)
endif()
target_link_libraries(gosd PRIVATE stable-diffusion ggml stdc++fs)
target_include_directories(gosd PUBLIC
stable-diffusion.cpp

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=b0179181069254389ccad604e44f17a2c25b4094
STABLEDIFFUSION_GGML_VERSION?=5900ef6605c6fbf7934239f795c13c97bc993853
CMAKE_ARGS+=-DGGML_MAX_NAME=128
@@ -29,6 +29,8 @@ else ifeq ($(BUILD_TYPE),clblas)
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DSD_HIPBLAS=ON -DGGML_HIPBLAS=ON
# If it's OSX, DO NOT embed the metal library - -DGGML_METAL_EMBED_LIBRARY=ON requires further investigation
# But if it's OSX without metal, disable it here
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DSD_VULKAN=ON -DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)
@@ -72,10 +74,10 @@ libgosd.so: sources/stablediffusion-ggml.cpp CMakeLists.txt gosd.cpp gosd.h
stablediffusion-ggml: main.go gosd.go libgosd.so
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml ./
package: stablediffusion-ggml
package:
bash package.sh
build: package
build: stablediffusion-ggml package
clean:
rm -rf libgosd.so build stablediffusion-ggml package sources
rm -rf libgosd.o build stablediffusion-ggml

View File

@@ -44,7 +44,7 @@ const char* sample_method_str[] = {
};
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedulers[] = {
const char* schedule_str[] = {
"default",
"discrete",
"karras",
@@ -54,8 +54,6 @@ const char* schedulers[] = {
};
sd_ctx_t* sd_c;
// Moved from the context (load time) to generation time params
scheduler_t scheduler = scheduler_t::DEFAULT;
sample_method_t sample_method;
@@ -107,7 +105,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *clip_g_path = "";
const char *t5xxl_path = "";
const char *vae_path = "";
const char *scheduler_str = "";
const char *scheduler = "";
const char *sampler = "";
char *lora_dir = model_path;
bool lora_dir_allocated = false;
@@ -135,7 +133,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
vae_path = optval;
}
if (!strcmp(optname, "scheduler")) {
scheduler_str = optval;
scheduler = optval;
}
if (!strcmp(optname, "sampler")) {
sampler = optval;
@@ -172,13 +170,22 @@ int load_model(const char *model, char *model_path, char* options[], int threads
}
sample_method = (sample_method_t)sample_method_found;
int schedule_found = -1;
for (int d = 0; d < SCHEDULE_COUNT; d++) {
if (!strcmp(scheduler_str, schedulers[d])) {
scheduler = (scheduler_t)d;
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
if (!strcmp(scheduler, schedule_str[d])) {
schedule_found = d;
fprintf (stderr, "Found scheduler: %s\n", scheduler);
}
}
if (schedule_found == -1) {
fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
schedule_found = DEFAULT;
}
schedule_t schedule = (schedule_t)schedule_found;
fprintf (stderr, "Creating context\n");
sd_ctx_params_t ctx_params;
sd_ctx_params_init(&ctx_params);
@@ -198,6 +205,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.free_params_immediately = false;
ctx_params.n_threads = threads;
ctx_params.rng_type = STD_DEFAULT_RNG;
ctx_params.schedule = schedule;
sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
if (sd_ctx == NULL) {
@@ -233,16 +241,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
p.prompt = text;
p.negative_prompt = negativeText;
p.sample_params.guidance.txt_cfg = cfg_scale;
p.sample_params.guidance.slg.layers = skip_layers.data();
p.sample_params.guidance.slg.layer_count = skip_layers.size();
p.guidance.txt_cfg = cfg_scale;
p.guidance.slg.layers = skip_layers.data();
p.guidance.slg.layer_count = skip_layers.size();
p.width = width;
p.height = height;
p.sample_params.sample_method = sample_method;
p.sample_params.sample_steps = steps;
p.sample_method = sample_method;
p.sample_steps = steps;
p.seed = seed;
p.input_id_images_path = "";
p.sample_params.scheduler = scheduler;
// Handle input image for img2img
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);

View File

@@ -10,9 +10,9 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib
cp -avf $CURDIR/libgosd.so $CURDIR/package/
cp -avf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/
cp -avrf $CURDIR/libgosd.so $CURDIR/package/
cp -avrf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -43,8 +43,6 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
echo "Detected Darwin"
else
echo "Error: Could not detect architecture"
exit 1

View File

@@ -1,7 +0,0 @@
.cache/
sources/
build/
package/
whisper
libgowhisper.so

View File

@@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.12)
project(gowhisper LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
add_subdirectory(./sources/whisper.cpp)
add_library(gowhisper MODULE gowhisper.cpp)
target_link_libraries(gowhisper PRIVATE whisper ggml)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
target_link_libraries(gowhisper PRIVATE stdc++fs)
endif()
set_property(TARGET gowhisper PROPERTY CXX_STANDARD 17)
set_target_properties(gowhisper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

View File

@@ -1,53 +1,110 @@
CMAKE_ARGS?=
BUILD_TYPE?=
GOCMD=go
NATIVE?=false
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)
BUILD_TYPE?=
CMAKE_ARGS?=
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=edea8a9c3cf0eb7676dcdb604991eb2f95c3d984
WHISPER_CPP_VERSION?=fc45bb86251f774ef817e89878bb4c2636c8a58f
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
export WHISPER_CMAKE_ARGS?=-DBUILD_SHARED_LIBS=OFF
export WHISPER_DIR=$(abspath ./sources/whisper.cpp)
export WHISPER_INCLUDE_PATH=$(WHISPER_DIR)/include:$(WHISPER_DIR)/ggml/include
export WHISPER_LIBRARY_PATH=$(WHISPER_DIR)/build/src/:$(WHISPER_DIR)/build/ggml/src
CGO_LDFLAGS_WHISPER?=
CGO_LDFLAGS_WHISPER+=-lggml
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
CUDA_LIBPATH?=/usr/local/cuda/lib64/
ONEAPI_VERSION?=2025.2
# IF native is false, we add -DGGML_NATIVE=OFF to CMAKE_ARGS
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
WHISPER_CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH) -L$(CUDA_LIBPATH)/stubs/ -lcuda
CMAKE_ARGS+=-DGGML_CUDA=ON
CGO_LDFLAGS_WHISPER+=-lcufft -lggml-cuda
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-cuda/
# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblas)
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
ROCM_HOME ?= /opt/rocm
ROCM_PATH ?= /opt/rocm
LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib
export STABLE_BUILD_TYPE=
export CXX=$(ROCM_HOME)/llvm/bin/clang++
export CC=$(ROCM_HOME)/llvm/bin/clang
# GPU_TARGETS ?= gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102
# AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
CMAKE_ARGS+=-DGGML_HIP=ON
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib -L$(CURRENT_MAKEFILE_DIR)/sources/whisper.cpp/build/ggml/src/ggml-hip/ -lggml-hip
# CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DGGML_VULKAN=ON
CMAKE_ARGS+=-DGGML_VULKAN=1
CGO_LDFLAGS_WHISPER+=-lggml-vulkan -lvulkan
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-vulkan/
else ifeq ($(OS),Darwin)
ifeq ($(BUILD_TYPE),)
BUILD_TYPE=metal
endif
ifneq ($(BUILD_TYPE),metal)
CMAKE_ARGS+=-DGGML_METAL=OFF
CGO_LDFLAGS_WHISPER+=-lggml-blas
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-blas
else
CMAKE_ARGS+=-DGGML_METAL=ON
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
CMAKE_ARGS+=-DGGML_METAL_USE_BF16=ON
CMAKE_ARGS+=-DGGML_OPENMP=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_EXAMPLES=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_TESTS=OFF
CMAKE_ARGS+=-DWHISPER_BUILD_SERVER=OFF
CGO_LDFLAGS += -framework Accelerate
CGO_LDFLAGS_WHISPER+=-lggml-metal -lggml-blas
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-metal/:$(WHISPER_DIR)/build/ggml/src/ggml-blas
endif
TARGET+=--target ggml-metal
endif
ifeq ($(BUILD_TYPE),sycl_f16)
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
export CC=icx
export CXX=icpx
CGO_LDFLAGS_WHISPER += -fsycl -L${DNNLROOT}/lib -rpath ${ONEAPI_ROOT}/${ONEAPI_VERSION}/lib -ldnnl ${MKLROOT}/lib/intel64/libmkl_sycl.a -fiopenmp -fopenmp-targets=spir64 -lOpenCL -lggml-sycl
CGO_LDFLAGS_WHISPER += $(shell pkg-config --libs mkl-static-lp64-gomp)
CGO_CXXFLAGS_WHISPER += -fiopenmp -fopenmp-targets=spir64
CGO_CXXFLAGS_WHISPER += $(shell pkg-config --cflags mkl-static-lp64-gomp )
export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-sycl/
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
-DGGML_SYCL_F16=ON
-DCMAKE_CXX_FLAGS="-fsycl"
endif
ifeq ($(BUILD_TYPE),sycl_f32)
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
ifeq ($(BUILD_TYPE),sycl_f16)
CMAKE_ARGS+=-DGGML_SYCL_F16=ON
endif
ifneq ($(OS),Darwin)
CGO_LDFLAGS_WHISPER+=-lgomp
endif
## whisper
sources/whisper.cpp:
mkdir -p sources/whisper.cpp
cd sources/whisper.cpp && \
@@ -57,21 +114,18 @@ sources/whisper.cpp:
git checkout $(WHISPER_CPP_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch
libgowhisper.so: sources/whisper.cpp CMakeLists.txt gowhisper.cpp gowhisper.h
mkdir -p build && \
cd build && \
cmake .. $(CMAKE_ARGS) && \
cmake --build . --config Release -j$(JOBS) && \
cd .. && \
mv build/libgowhisper.so ./
sources/whisper.cpp/build/src/libwhisper.a: sources/whisper.cpp
cd sources/whisper.cpp && cmake $(CMAKE_ARGS) $(WHISPER_CMAKE_ARGS) . -B ./build
cd sources/whisper.cpp/build && cmake --build . --config Release
whisper: main.go gowhisper.go libgowhisper.so
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./
whisper: sources/whisper.cpp sources/whisper.cpp/build/src/libwhisper.a
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(CURDIR)/sources/whisper.cpp/bindings/go
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="${WHISPER_INCLUDE_PATH}" LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" LD_LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" \
CGO_CXXFLAGS="$(CGO_CXXFLAGS_WHISPER)" \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o whisper ./
package: whisper
package:
bash package.sh
build: package
clean:
rm -rf libgowhisper.o build whisper
build: whisper package

View File

@@ -1,154 +0,0 @@
#include "gowhisper.h"
#include "ggml-backend.h"
#include "whisper.h"
#include <vector>
static struct whisper_vad_context *vctx;
static struct whisper_context *ctx;
static std::vector<float> flat_segs;
static void ggml_log_cb(enum ggml_log_level level, const char *log,
void *data) {
const char *level_str;
if (!log) {
return;
}
switch (level) {
case GGML_LOG_LEVEL_DEBUG:
level_str = "DEBUG";
break;
case GGML_LOG_LEVEL_INFO:
level_str = "INFO";
break;
case GGML_LOG_LEVEL_WARN:
level_str = "WARN";
break;
case GGML_LOG_LEVEL_ERROR:
level_str = "ERROR";
break;
default: /* Potential future-proofing */
level_str = "?????";
break;
}
fprintf(stderr, "[%-5s] ", level_str);
fputs(log, stderr);
fflush(stderr);
}
int load_model(const char *const model_path) {
whisper_log_set(ggml_log_cb, nullptr);
ggml_backend_load_all();
struct whisper_context_params cparams = whisper_context_default_params();
ctx = whisper_init_from_file_with_params(model_path, cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: Also failed to init model as transcriber\n");
return 1;
}
return 0;
}
int load_model_vad(const char *const model_path) {
whisper_log_set(ggml_log_cb, nullptr);
ggml_backend_load_all();
struct whisper_vad_context_params vcparams =
whisper_vad_default_context_params();
// XXX: Overridden to false in upstream due to performance?
// vcparams.use_gpu = true;
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
if (vctx == nullptr) {
fprintf(stderr, "error: Failed to init model as VAD\n");
return 1;
}
return 0;
}
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
size_t *segs_out_len) {
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
fprintf(stderr, "error: failed to detect speech\n");
return 1;
}
struct whisper_vad_params params = whisper_vad_default_params();
struct whisper_vad_segments *segs =
whisper_vad_segments_from_probs(vctx, params);
size_t segn = whisper_vad_segments_n_segments(segs);
// fprintf(stderr, "Got segments %zd\n", segn);
flat_segs.clear();
for (int i = 0; i < segn; i++) {
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
}
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
// flat_segs.size());
*segs_out = flat_segs.data();
*segs_out_len = flat_segs.size();
// fprintf(stderr, "freeing segs\n");
whisper_vad_free_segments(segs);
// fprintf(stderr, "returning\n");
return 0;
}
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
whisper_full_params wparams =
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.n_threads = threads;
if (*lang != '\0')
wparams.language = lang;
else {
wparams.language = nullptr;
}
wparams.translate = translate;
wparams.debug_mode = true;
wparams.print_progress = true;
wparams.tdrz_enable = tdrz;
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
fprintf(stderr, "error: transcription failed\n");
return 1;
}
*segs_out_len = whisper_full_n_segments(ctx);
return 0;
}
const char *get_segment_text(int i) {
return whisper_full_get_segment_text(ctx, i);
}
int64_t get_segment_t0(int i) { return whisper_full_get_segment_t0(ctx, i); }
int64_t get_segment_t1(int i) { return whisper_full_get_segment_t1(ctx, i); }
int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }
int32_t get_token_id(int i, int j) {
return whisper_full_get_token_id(ctx, i, j);
}
bool get_segment_speaker_turn_next(int i) {
return whisper_full_get_segment_speaker_turn_next(ctx, i);
}

View File

@@ -1,161 +0,0 @@
package main
import (
"fmt"
"os"
"path/filepath"
"strings"
"unsafe"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
var (
CppLoadModel func(modelPath string) int
CppLoadModelVAD func(modelPath string) int
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
CppGetSegmentText func(i int) string
CppGetSegmentStart func(i int) int64
CppGetSegmentEnd func(i int) int64
CppNTokens func(i int) int
CppGetTokenID func(i int, j int) int
CppGetSegmentSpeakerTurnNext func(i int) bool
)
type Whisper struct {
base.SingleThread
}
func (w *Whisper) Load(opts *pb.ModelOptions) error {
vadOnly := false
for _, oo := range opts.Options {
if oo == "vad_only" {
vadOnly = true
} else {
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
}
}
if vadOnly {
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
return fmt.Errorf("Failed to load Whisper VAD model")
}
return nil
}
if ret := CppLoadModel(opts.ModelFile); ret != 0 {
return fmt.Errorf("Failed to load Whisper transcription model")
}
return nil
}
func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
audio := req.Audio
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
}
// Happens when CPP vector has not had any elements pushed to it
if segsPtr == 0 {
return pb.VADResponse{
Segments: []*pb.VADSegment{},
}, nil
}
// unsafeptr warning is caused by segsPtr being on the stack and therefore being subject to stack copying AFAICT
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen)
vadSegments := []*pb.VADSegment{}
for i := range len(segs) >> 1 {
s := segs[2*i] / 100
t := segs[2*i+1] / 100
vadSegments = append(vadSegments, &pb.VADSegment{
Start: s,
End: t,
})
}
return pb.VADResponse{
Segments: vadSegments,
}, nil
}
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}
data := buf.AsFloat32Buffer().Data
segsLen := uintptr(0xdeadbeef)
segsLenPtr := unsafe.Pointer(&segsLen)
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
}
segments := []*pb.TranscriptSegment{}
text := ""
for i := range int(segsLen) {
s := CppGetSegmentStart(i)
t := CppGetSegmentEnd(i)
txt := strings.Clone(CppGetSegmentText(i))
tokens := make([]int32, CppNTokens(i))
if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
txt += " [SPEAKER_TURN]"
}
for j := range tokens {
tokens[j] = int32(CppGetTokenID(i, j))
}
segment := &pb.TranscriptSegment{
Id: int32(i),
Text: txt,
Start: s, End: t,
Tokens: tokens,
}
segments = append(segments, segment)
text += " " + strings.TrimSpace(txt)
}
return pb.TranscriptResult{
Segments: segments,
Text: strings.TrimSpace(text),
}, nil
}

View File

@@ -1,17 +0,0 @@
#include <cstddef>
#include <cstdint>
extern "C" {
int load_model(const char *const model_path);
int load_model_vad(const char *const model_path);
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
size_t *segs_out_len);
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
const char *get_segment_text(int i);
int64_t get_segment_t0(int i);
int64_t get_segment_t1(int i);
int n_tokens(int i);
int32_t get_token_id(int i, int j);
bool get_segment_speaker_turn_next(int i);
}

View File

@@ -1,10 +1,10 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
@@ -12,34 +12,7 @@ var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
func main() {
gosd, err := purego.Dlopen("./libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
panic(err)
}
libFuncs := []LibFuncs{
{&CppLoadModel, "load_model"},
{&CppLoadModelVAD, "load_model_vad"},
{&CppVAD, "vad"},
{&CppTranscribe, "transcribe"},
{&CppGetSegmentText, "get_segment_text"},
{&CppGetSegmentStart, "get_segment_t0"},
{&CppGetSegmentEnd, "get_segment_t1"},
{&CppNTokens, "n_tokens"},
{&CppGetTokenID, "get_token_id"},
{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
}
flag.Parse()
if err := grpc.StartServer(*addr, &Whisper{}); err != nil {

View File

@@ -10,8 +10,8 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib
cp -avf $CURDIR/whisper $CURDIR/libgowhisper.so $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/
cp -avrf $CURDIR/whisper $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -42,13 +42,11 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
echo "Detected Darwin"
else
echo "Error: Could not detect architecture"
exit 1
fi
echo "Packaging completed successfully"
echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
ls -liah $CURDIR/package/lib/

View File

@@ -0,0 +1,105 @@
package main
// This is a wrapper to satisfy the gRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"os"
"path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
type Whisper struct {
base.SingleThread
whisper whisper.Model
}
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.ModelFile)
sd.whisper = w
return err
}
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}
data := buf.AsFloat32Buffer().Data
// Process samples
context, err := sd.whisper.NewContext()
if err != nil {
return pb.TranscriptResult{}, err
}
context.SetThreads(uint(opts.Threads))
if opts.Language != "" {
context.SetLanguage(opts.Language)
} else {
context.SetLanguage("auto")
}
if opts.Translate {
context.SetTranslate(true)
}
if err := context.Process(data, nil, nil, nil); err != nil {
return pb.TranscriptResult{}, err
}
segments := []*pb.TranscriptSegment{}
text := ""
for {
s, err := context.NextSegment()
if err != nil {
break
}
var tokens []int32
for _, t := range s.Tokens {
tokens = append(tokens, int32(t.Id))
}
segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
segments = append(segments, segment)
text += s.Text
}
return pb.TranscriptResult{
Segments: segments,
Text: text,
}, nil
}

View File

@@ -45,7 +45,6 @@
default: "cpu-whisper"
nvidia: "cuda12-whisper"
intel: "intel-sycl-f16-whisper"
metal: "metal-whisper"
amd: "rocm-whisper"
vulkan: "vulkan-whisper"
nvidia-l4t: "nvidia-l4t-arm64-whisper"
@@ -72,7 +71,7 @@
# amd: "rocm-stablediffusion-ggml"
vulkan: "vulkan-stablediffusion-ggml"
nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml"
metal: "metal-stablediffusion-ggml"
# metal: "metal-stablediffusion-ggml"
# darwin-x86: "darwin-x86-stablediffusion-ggml"
- &rfdetr
name: "rfdetr"
@@ -148,7 +147,7 @@
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm"
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
urls:
- https://github.com/Blaizzy/mlx-vlm
- https://github.com/ml-explore/mlx-vlm
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm
license: MIT
@@ -160,23 +159,6 @@
- vision-language
- LLM
- MLX
- &mlx-audio
name: "mlx-audio"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio"
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
urls:
- https://github.com/Blaizzy/mlx-audio
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-mlx-audio
license: MIT
description: |
Run Audio Models with MLX
tags:
- audio-to-text
- audio-generation
- text-to-audio
- LLM
- MLX
- &rerankers
name: "rerankers"
alias: "rerankers"
@@ -201,6 +183,8 @@
nvidia: "cuda12-transformers"
intel: "intel-transformers"
amd: "rocm-transformers"
metal: "metal-transformers"
default: "cpu-transformers"
- &diffusers
name: "diffusers"
icon: https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg
@@ -350,8 +334,6 @@
alias: "chatterbox"
capabilities:
nvidia: "cuda12-chatterbox"
metal: "metal-chatterbox"
default: "cpu-chatterbox"
- &piper
name: "piper"
uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
@@ -435,11 +417,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
- !!merge <<: *mlx-audio
name: "mlx-audio-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-audio
- !!merge <<: *kitten-tts
name: "kitten-tts-development"
uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
@@ -582,16 +559,6 @@
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisper"
mirrors:
- localai/localai-backends:latest-cpu-whisper
- !!merge <<: *whispercpp
name: "metal-whisper"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
name: "metal-whisper-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
name: "cpu-whisper-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisper"
@@ -678,16 +645,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-stablediffusion-ggml"
mirrors:
- localai/localai-backends:master-cpu-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "metal-stablediffusion-ggml"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-stablediffusion-ggml"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "metal-stablediffusion-ggml-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-stablediffusion-ggml"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "vulkan-stablediffusion-ggml"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-stablediffusion-ggml"
@@ -896,6 +853,28 @@
nvidia: "cuda12-transformers-development"
intel: "intel-transformers-development"
amd: "rocm-transformers-development"
default: "cpu-transformers-development"
metal: "metal-transformers-development"
- !!merge <<: *transformers
name: "cpu-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-transformers"
mirrors:
- localai/localai-backends:latest-cpu-transformers
- !!merge <<: *transformers
name: "cpu-transformers-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-transformers"
mirrors:
- localai/localai-backends:master-cpu-transformers
- !!merge <<: *transformers
name: "metal-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-transformers"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-transformers
- !!merge <<: *transformers
name: "metal-transformers-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-transformers"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-transformers
- !!merge <<: *transformers
name: "cuda12-transformers"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-transformers"
@@ -1225,28 +1204,6 @@
name: "chatterbox-development"
capabilities:
nvidia: "cuda12-chatterbox-development"
metal: "metal-chatterbox-development"
default: "cpu-chatterbox-development"
- !!merge <<: *chatterbox
name: "cpu-chatterbox"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
mirrors:
- localai/localai-backends:latest-cpu-chatterbox
- !!merge <<: *chatterbox
name: "cpu-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
mirrors:
- localai/localai-backends:master-cpu-chatterbox
- !!merge <<: *chatterbox
name: "metal-chatterbox"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
name: "metal-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
name: "cuda12-chatterbox-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"

View File

@@ -1,190 +1,38 @@
# Python Backends for LocalAI
# Common commands about conda environment
This directory contains Python-based AI backends for LocalAI, providing support for various AI models and hardware acceleration targets.
## Create a new empty conda environment
## Overview
```
conda create --name <env-name> python=<your version> -y
The Python backends use a unified build system based on `libbackend.sh` that provides:
- **Automatic virtual environment management** with support for both `uv` and `pip`
- **Hardware-specific dependency installation** (CPU, CUDA, Intel, MLX, etc.)
- **Portable Python support** for standalone deployments
- **Consistent backend execution** across different environments
## Available Backends
### Core AI Models
- **transformers** - Hugging Face Transformers framework (PyTorch-based)
- **vllm** - High-performance LLM inference engine
- **mlx** - Apple Silicon optimized ML framework
- **exllama2** - ExLlama2 quantized models
### Audio & Speech
- **bark** - Text-to-speech synthesis
- **coqui** - Coqui TTS models
- **faster-whisper** - Fast Whisper speech recognition
- **kitten-tts** - Lightweight TTS
- **mlx-audio** - Apple Silicon audio processing
- **chatterbox** - TTS model
- **kokoro** - TTS models
### Computer Vision
- **diffusers** - Stable Diffusion and image generation
- **mlx-vlm** - Vision-language models for Apple Silicon
- **rfdetr** - Object detection models
### Specialized
- **rerankers** - Text reranking models
## Quick Start
### Prerequisites
- Python 3.10+ (default: 3.10.18)
- `uv` package manager (recommended) or `pip`
- Appropriate hardware drivers for your target (CUDA, Intel, etc.)
### Installation
Each backend can be installed individually:
```bash
# Navigate to a specific backend
cd backend/python/transformers
# Install dependencies
make transformers
# or
bash install.sh
# Run the backend
make run
# or
bash run.sh
conda create --name autogptq python=3.11 -y
```
### Using the Unified Build System
## To activate the environment
The `libbackend.sh` script provides consistent commands across all backends:
```bash
# Source the library in your backend script
source $(dirname $0)/../common/libbackend.sh
# Install requirements (automatically handles hardware detection)
installRequirements
# Start the backend server
startBackend $@
# Run tests
runUnittests
As of conda 4.4
```
conda activate autogptq
```
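In practice each backend wires into these functions through two tiny wrapper scripts. Below is a minimal sketch of an `install.sh`, modeled on the pattern the existing backends use; the `run.sh` counterpart is identical except that it calls `startBackend $@` instead of `installRequirements`:
```bash
#!/bin/bash
set -e

# Locate and source the shared backend library
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# Create/activate the virtual environment and install the
# requirements files matching the detected build target
installRequirements
```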
## Hardware Targets
The conda version older than 4.4
The build system automatically detects and configures for different hardware:
- **CPU** - Standard CPU-only builds
- **CUDA** - NVIDIA GPU acceleration (supports CUDA 11/12)
- **Intel** - Intel XPU/GPU optimization
- **MLX** - Apple Silicon (M1/M2/M3) optimization
- **HIP** - AMD GPU acceleration
### Target-Specific Requirements
Backends can specify hardware-specific dependencies (see the sketch after this list):
- `requirements.txt` - Base requirements
- `requirements-cpu.txt` - CPU-specific packages
- `requirements-cublas11.txt` - CUDA 11 packages
- `requirements-cublas12.txt` - CUDA 12 packages
- `requirements-intel.txt` - Intel-optimized packages
- `requirements-mps.txt` - Apple Silicon packages
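Roughly speaking, `installRequirements` assembles a list of candidate files from the detected build profile and installs every one that exists. The snippet below is a simplified illustration only; the exact candidate list and ordering in `libbackend.sh` may differ:
```bash
# Simplified sketch of target-specific requirements selection.
# BUILD_PROFILE is assumed to come from hardware detection / BUILD_TYPE
# (e.g. "cpu", "cublas12", "intel").
candidates=(
    "requirements.txt"
    "requirements-${BUILD_PROFILE}.txt"
    "requirements-${BUILD_PROFILE}-after.txt"
)

for reqFile in "${candidates[@]}"; do
    if [ -f "${reqFile}" ]; then
        echo "starting requirements install for ${reqFile}"
        uv pip install -r "${reqFile}"   # plain pip is used when USE_PIP=true
    fi
done
```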
## Configuration Options
### Environment Variables
- `PYTHON_VERSION` - Python version (default: 3.10)
- `PYTHON_PATCH` - Python patch version (default: 18)
- `BUILD_TYPE` - Force specific build target
- `USE_PIP` - Use pip instead of uv (default: false)
- `PORTABLE_PYTHON` - Enable portable Python builds
- `LIMIT_TARGETS` - Restrict backend to specific targets
### Example: CUDA 12 Only Backend
```bash
# In your backend script
LIMIT_TARGETS="cublas12"
source $(dirname $0)/../common/libbackend.sh
```
source activate autogptq
```
### Example: Intel-Optimized Backend
## Install the packages to your environment
```bash
# In your backend script
LIMIT_TARGETS="intel"
source $(dirname $0)/../common/libbackend.sh
Sometimes you need to install the packages from the conda-forge channel
By using `conda`
```
conda install <your-package-name>
conda install -c conda-forge <your package-name>
```
## Development
### Adding a New Backend
1. Create a new directory in `backend/python/`
2. Copy the template structure from `common/template/`
3. Implement your `backend.py` with the required gRPC interface
4. Add appropriate requirements files for your target hardware
5. Use `libbackend.sh` for consistent build and execution (a command-line sketch of these steps follows below)
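Putting those steps together, a hypothetical command sequence for scaffolding a backend named `my-backend` could look like this (directory names follow the layout described above; the `--addr` flag is the address the gRPC server binds to):
```bash
# Scaffold the new backend from the shared template
cd backend/python
mkdir -p my-backend
cp -r common/template/* my-backend/
cd my-backend

# Implement backend.py and declare per-target dependencies
$EDITOR backend.py requirements.txt requirements-cpu.txt requirements-cublas12.txt

# Install dependencies and start the gRPC server via the shared build system
bash install.sh
bash run.sh --addr localhost:50051
```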
### Testing
```bash
# Run backend tests
make test
# or
bash test.sh
Or by using `pip`
```
### Building
```bash
# Install dependencies
make <backend-name>
# Clean build artifacts
make clean
pip install <your-package-name>
```
## Architecture
Each backend follows a consistent structure:
```
backend-name/
├── backend.py # Main backend implementation
├── requirements.txt # Base dependencies
├── requirements-*.txt # Hardware-specific dependencies
├── install.sh # Installation script
├── run.sh # Execution script
├── test.sh # Test script
├── Makefile # Build targets
└── test.py # Unit tests
```
## Troubleshooting
### Common Issues
1. **Missing dependencies**: Ensure all requirements files are properly configured
2. **Hardware detection**: Check that `BUILD_TYPE` matches your system
3. **Python version**: Verify Python 3.10+ is available
4. **Virtual environment**: Use `ensureVenv` to create/activate environments
## Contributing
When adding new backends or modifying existing ones:
1. Follow the established directory structure
2. Use `libbackend.sh` for consistent behavior
3. Include appropriate requirements files for all target hardware
4. Add comprehensive tests
5. Update this README if adding new backend types

View File

@@ -1,6 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cpu
accelerate
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts

View File

@@ -2,5 +2,5 @@
torch==2.6.0+cu118
torchaudio==2.6.0+cu118
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -1,5 +1,5 @@
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -2,5 +2,5 @@
torch==2.6.0+rocm6.1
torchaudio==2.6.0+rocm6.1
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate

View File

@@ -3,8 +3,9 @@ intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
torchaudio==2.3.1+cxx11.abi
transformers==4.46.3
chatterbox-tts==0.1.2
chatterbox-tts
accelerate
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
setuptools
setuptools
accelerate

View File

@@ -286,8 +286,7 @@ _makeVenvPortable() {
function ensureVenv() {
local interpreter=""
if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -e "$(_portable_python)" ]; then
echo "Using portable Python"
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
ensurePortablePython
interpreter="$(_portable_python)"
else
@@ -385,11 +384,6 @@ function installRequirements() {
requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
fi
# This is needed to build wheels that e.g. depends on Python.h
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
fi
for reqFile in ${requirementFiles[@]}; do
if [ -f "${reqFile}" ]; then
echo "starting requirements install for ${reqFile}"

View File

@@ -18,7 +18,7 @@ import backend_pb2_grpc
import grpc
from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from diffusers.utils import load_image, export_to_video
@@ -72,6 +72,13 @@ def is_float(s):
except ValueError:
return False
def is_int(s):
try:
int(s)
return True
except ValueError:
return False
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
# Credits to https://github.com/neggles
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -177,10 +184,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
key, value = opt.split(":")
# if value is a number, convert it to the appropriate type
if is_float(value):
if value.is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif is_int(value):
value = int(value)
self.options[key] = value
# From options, extract if present "torch_dtype" and set it to the appropriate type
@@ -328,32 +334,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
torch_dtype=torch.bfloat16)
self.pipe.vae.to(torch.bfloat16)
self.pipe.text_encoder.to(torch.bfloat16)
elif request.PipelineType == "WanPipeline":
# WAN2.2 pipeline requires special VAE handling
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
torch_dtype=torch.float32
)
self.pipe = WanPipeline.from_pretrained(
request.Model,
vae=vae,
torch_dtype=torchType
)
self.txt2vid = True # WAN2.2 is a text-to-video pipeline
elif request.PipelineType == "WanImageToVideoPipeline":
# WAN2.2 image-to-video pipeline
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
torch_dtype=torch.float32
)
self.pipe = WanImageToVideoPipeline.from_pretrained(
request.Model,
vae=vae,
torch_dtype=torchType
)
self.img2vid = True # WAN2.2 image-to-video pipeline
if CLIPSKIP and request.CLIPSkip != 0:
self.clip_skip = request.CLIPSkip
@@ -495,24 +475,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"num_inference_steps": steps,
}
# Handle image source: prioritize RefImages over request.src
image_src = None
if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
# Use the first reference image if available
image_src = request.ref_images[0]
print(f"Using reference image: {image_src}", file=sys.stderr)
elif request.src != "":
# Fall back to request.src if no ref_images
image_src = request.src
print(f"Using source image: {image_src}", file=sys.stderr)
else:
print("No image source provided", file=sys.stderr)
if image_src and not self.controlnet and not self.img2vid:
image = Image.open(image_src)
if request.src != "" and not self.controlnet and not self.img2vid:
image = Image.open(request.src)
options["image"] = image
elif self.controlnet and image_src:
pose_image = load_image(image_src)
elif self.controlnet and request.src:
pose_image = load_image(request.src)
options["image"] = pose_image
if CLIPSKIP and self.clip_skip != 0:
@@ -554,11 +521,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if self.img2vid:
# Load the conditioning image
if image_src:
image = load_image(image_src)
else:
# Fallback to request.src for img2vid if no ref_images
image = load_image(request.src)
image = load_image(request.src)
image = image.resize((1024, 576))
generator = torch.manual_seed(request.seed)
@@ -595,96 +558,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Media generated", success=True)
def GenerateVideo(self, request, context):
try:
prompt = request.prompt
if not prompt:
return backend_pb2.Result(success=False, message="No prompt provided for video generation")
# Set default values from request or use defaults
num_frames = request.num_frames if request.num_frames > 0 else 81
fps = request.fps if request.fps > 0 else 16
cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
num_inference_steps = request.step if request.step > 0 else 40
# Prepare generation parameters
kwargs = {
"prompt": prompt,
"negative_prompt": request.negative_prompt if request.negative_prompt else "",
"height": request.height if request.height > 0 else 720,
"width": request.width if request.width > 0 else 1280,
"num_frames": num_frames,
"guidance_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
}
# Add custom options from self.options (including guidance_scale_2 if specified)
kwargs.update(self.options)
# Set seed if provided
if request.seed > 0:
kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
# Handle start and end images for video generation
if request.start_image:
kwargs["start_image"] = load_image(request.start_image)
if request.end_image:
kwargs["end_image"] = load_image(request.end_image)
print(f"Generating video with {kwargs=}", file=sys.stderr)
# Generate video frames based on pipeline type
if self.PipelineType == "WanPipeline":
# WAN2.2 text-to-video generation
output = self.pipe(**kwargs)
frames = output.frames[0] # WAN2.2 returns frames in this format
elif self.PipelineType == "WanImageToVideoPipeline":
# WAN2.2 image-to-video generation
if request.start_image:
# Load and resize the input image according to WAN2.2 requirements
image = load_image(request.start_image)
# Use request dimensions or defaults, but respect WAN2.2 constraints
request_height = request.height if request.height > 0 else 480
request_width = request.width if request.width > 0 else 832
max_area = request_height * request_width
aspect_ratio = image.height / image.width
mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
image = image.resize((width, height))
kwargs["image"] = image
kwargs["height"] = height
kwargs["width"] = width
output = self.pipe(**kwargs)
frames = output.frames[0]
elif self.img2vid:
# Generic image-to-video generation
if request.start_image:
image = load_image(request.start_image)
image = image.resize((request.width if request.width > 0 else 1024,
request.height if request.height > 0 else 576))
kwargs["image"] = image
output = self.pipe(**kwargs)
frames = output.frames[0]
elif self.txt2vid:
# Generic text-to-video generation
output = self.pipe(**kwargs)
frames = output.frames[0]
else:
return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
# Export video
export_to_video(frames, request.dst, fps=fps)
return backend_pb2.Result(message="Video generated successfully", success=True)
except Exception as err:
print(f"Error generating video: {err}", file=sys.stderr)
traceback.print_exc()
return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),

View File

@@ -8,5 +8,4 @@ compel
peft
sentencepiece
torch==2.7.1
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,12 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.7.1+cu118
torchvision==0.22.1+cu118
git+https://github.com/huggingface/diffusers
opencv-python
transformers
torchvision==0.22.1
accelerate
compel
peft
sentencepiece
torch==2.7.1
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,12 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.7.1
torchvision==0.22.1
git+https://github.com/huggingface/diffusers
opencv-python
transformers
torchvision
accelerate
compel
peft
sentencepiece
torch
ftfy
optimum-quanto
optimum-quanto

View File

@@ -8,5 +8,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -12,5 +12,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -8,5 +8,4 @@ peft
optimum-quanto
numpy<2
sentencepiece
torchvision
ftfy
torchvision

View File

@@ -7,5 +7,4 @@ accelerate
compel
peft
sentencepiece
optimum-quanto
ftfy
optimum-quanto

View File

@@ -1,23 +0,0 @@
.PHONY: mlx-audio
mlx-audio:
bash install.sh
.PHONY: run
run: mlx-audio
@echo "Running mlx-audio..."
bash run.sh
@echo "mlx run."
.PHONY: test
test: mlx-audio
@echo "Testing mlx-audio..."
bash test.sh
@echo "mlx tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

View File

@@ -1,459 +0,0 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
import shutil
import glob
from typing import List
import time
import tempfile
import backend_pb2
import backend_pb2_grpc
import grpc
from mlx_audio.tts.utils import load_model
import soundfile as sf
import numpy as np
import uuid
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
A gRPC servicer that implements the Backend service defined in backend.proto.
This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
"""
def _is_float(self, s):
"""Check if a string can be converted to float."""
try:
float(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
Args:
request: The health check request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The health check reply.
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
async def LoadModel(self, request, context):
"""
Loads a TTS model using MLX-Audio.
Args:
request: The load model request.
context: The gRPC context.
Returns:
backend_pb2.Result: The load model result.
"""
try:
print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
# Parse options like in the kokoro backend
options = request.Options
self.options = {}
# The options are a list of strings in this form optname:optvalue
# We store all the options in a dict for later use
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
self.options[key] = value
print(f"Options: {self.options}", file=sys.stderr)
# Load the model using MLX-Audio's load_model function
try:
self.tts_model = load_model(request.Model)
self.model_path = request.Model
print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
except Exception as model_err:
print(f"Error loading TTS model: {model_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")
except Exception as err:
print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")
print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)
def TTS(self, request, context):
"""
Generates TTS audio from text using MLX-Audio.
Args:
request: A TTSRequest object containing text, model, destination, voice, and language.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object indicating success or failure.
"""
try:
# Check if model is loaded
if not hasattr(self, 'tts_model') or self.tts_model is None:
return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")
print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)
# Handle speed parameter based on model type
speed_value = self._handle_speed_parameter(request, self.model_path)
# Map language names to codes if needed
lang_code = self._map_language_code(request.language, request.voice)
# Prepare generation parameters
gen_params = {
"text": request.text,
"speed": speed_value,
"verbose": False,
}
# Add model-specific parameters
if request.voice and request.voice.strip():
gen_params["voice"] = request.voice
# Check if model supports language codes (primarily Kokoro)
if "kokoro" in self.model_path.lower():
gen_params["lang_code"] = lang_code
# Add pitch and gender for Spark models
if "spark" in self.model_path.lower():
gen_params["pitch"] = 1.0 # Default to moderate
gen_params["gender"] = "female" # Default to female
print(f"Generation parameters: {gen_params}", file=sys.stderr)
# Generate audio using the loaded model
try:
results = self.tts_model.generate(**gen_params)
except Exception as gen_err:
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")
# Process the generated audio segments
audio_arrays = []
for segment in results:
audio_arrays.append(segment.audio)
# If no segments, return error
if not audio_arrays:
print("No audio segments generated", file=sys.stderr)
return backend_pb2.Result(success=False, message="No audio generated")
# Concatenate all segments
cat_audio = np.concatenate(audio_arrays, axis=0)
# Generate output filename and path
if request.dst:
output_path = request.dst
else:
unique_id = str(uuid.uuid4())
filename = f"tts_{unique_id}.wav"
output_path = filename
# Write the audio as a WAV
try:
sf.write(output_path, cat_audio, 24000)
print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)
# Verify the file exists and has content
if not os.path.exists(output_path):
print(f"File was not created at {output_path}", file=sys.stderr)
return backend_pb2.Result(success=False, message="Failed to create audio file")
file_size = os.path.getsize(output_path)
if file_size == 0:
print("File was created but is empty", file=sys.stderr)
return backend_pb2.Result(success=False, message="Generated audio file is empty")
print(f"Audio file size: {file_size} bytes", file=sys.stderr)
except Exception as write_err:
print(f"Error writing audio file: {write_err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")
return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")
except Exception as e:
print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")
async def Predict(self, request, context):
"""
Generates TTS audio based on the given prompt using MLX-Audio TTS.
This is a fallback method for compatibility with the Predict endpoint.
Args:
request: The predict request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The predict result.
"""
try:
# Check if model is loaded
if not hasattr(self, 'tts_model') or self.tts_model is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("TTS model not loaded. Please call LoadModel first.")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# For TTS, we expect the prompt to contain the text to synthesize
if not request.Prompt:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details("Prompt is required for TTS generation")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# Handle speed parameter based on model type
speed_value = self._handle_speed_parameter(request, self.model_path)
# Map language names to codes if needed
lang_code = self._map_language_code(None, None) # Use defaults for Predict
# Prepare generation parameters
gen_params = {
"text": request.Prompt,
"speed": speed_value,
"verbose": False,
}
# Add model-specific parameters
if hasattr(self, 'options') and 'voice' in self.options:
gen_params["voice"] = self.options['voice']
# Check if model supports language codes (primarily Kokoro)
if "kokoro" in self.model_path.lower():
gen_params["lang_code"] = lang_code
print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)
# Generate audio using the loaded model
try:
results = self.tts_model.generate(**gen_params)
except Exception as gen_err:
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"TTS generation failed: {gen_err}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
# Process the generated audio segments
audio_arrays = []
for segment in results:
audio_arrays.append(segment.audio)
# If no segments, return error
if not audio_arrays:
print("No audio segments generated", file=sys.stderr)
return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))
# Concatenate all segments
cat_audio = np.concatenate(audio_arrays, axis=0)
duration = len(cat_audio) / 24000 # Assuming 24kHz sample rate
# Return success message with audio information
response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
except Exception as e:
print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"TTS generation failed: {str(e)}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
def _handle_speed_parameter(self, request, model_path):
"""
Handle speed parameter based on model type.
Args:
request: The TTSRequest object.
model_path: The model path to determine model type.
Returns:
float: The processed speed value.
"""
# Get speed from options if available
speed = 1.0
if hasattr(self, 'options') and 'speed' in self.options:
speed = self.options['speed']
# Handle speed parameter based on model type
if "spark" in model_path.lower():
# Spark actually expects float values that map to speed descriptions
speed_map = {
"very_low": 0.0,
"low": 0.5,
"moderate": 1.0,
"high": 1.5,
"very_high": 2.0,
}
if isinstance(speed, str) and speed in speed_map:
speed_value = speed_map[speed]
else:
# Try to use as float, default to 1.0 (moderate) if invalid
try:
speed_value = float(speed)
if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
speed_value = 1.0 # Default to moderate
except:
speed_value = 1.0 # Default to moderate
else:
# Other models use float speed values
try:
speed_value = float(speed)
if speed_value < 0.5 or speed_value > 2.0:
speed_value = 1.0 # Default to 1.0 if out of range
except ValueError:
speed_value = 1.0 # Default to 1.0 if invalid
return speed_value
def _map_language_code(self, language, voice):
"""
Map language names to codes if needed.
Args:
language: The language parameter from the request.
voice: The voice parameter from the request.
Returns:
str: The language code.
"""
if not language:
# Default to voice[0] if not found
return voice[0] if voice else "a"
# Map language names to codes if needed
language_map = {
"american_english": "a",
"british_english": "b",
"spanish": "e",
"french": "f",
"hindi": "h",
"italian": "i",
"portuguese": "p",
"japanese": "j",
"mandarin_chinese": "z",
# Also accept direct language codes
"a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
}
return language_map.get(language.lower(), language)
def _build_generation_params(self, request, default_speed=1.0):
"""
Build generation parameters from request attributes and options for MLX-Audio TTS.
Args:
request: The gRPC request.
default_speed: Default speed if not specified.
Returns:
dict: Generation parameters for MLX-Audio
"""
# Initialize generation parameters for MLX-Audio TTS
generation_params = {
'speed': default_speed,
'voice': 'af_heart', # Default voice
'lang_code': 'a', # Default language code
}
# Extract parameters from request attributes
if hasattr(request, 'Temperature') and request.Temperature > 0:
# Temperature could be mapped to speed variation
generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5
# Override with options if available
if hasattr(self, 'options'):
# Speed from options
if 'speed' in self.options:
generation_params['speed'] = self.options['speed']
# Voice from options
if 'voice' in self.options:
generation_params['voice'] = self.options['voice']
# Language code from options
if 'lang_code' in self.options:
generation_params['lang_code'] = self.options['lang_code']
# Model-specific parameters
param_option_mapping = {
'temp': 'speed',
'temperature': 'speed',
'top_p': 'speed', # Map top_p to speed variation
}
for option_key, param_key in param_option_mapping.items():
if option_key in self.options:
if param_key == 'speed':
# Ensure speed is within reasonable bounds
speed_val = float(self.options[option_key])
if 0.5 <= speed_val <= 2.0:
generation_params[param_key] = speed_val
return generation_params
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
])
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address
server.add_insecure_port(address)
# Gracefully shutdown the server on SIGTERM or SIGINT
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.ensure_future(server.stop(5))
)
# Start the server
await server.start()
print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
# Wait for the server to be terminated
await server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
asyncio.run(serve(args.addr))

View File

@@ -1,14 +0,0 @@
#!/bin/bash
set -e
USE_PIP=true
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
installRequirements

View File

@@ -1 +0,0 @@
git+https://github.com/Blaizzy/mlx-audio

View File

@@ -1,7 +0,0 @@
grpcio==1.71.0
protobuf
certifi
setuptools
mlx-audio
soundfile
numpy

View File

@@ -1,11 +0,0 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
startBackend $@

View File

@@ -1,142 +0,0 @@
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc
import grpc
import unittest
import subprocess
import time
import grpc
import backend_pb2_grpc
import backend_pb2
class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service.
This class contains methods to test the startup and shutdown of the gRPC service.
"""
def setUp(self):
self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
time.sleep(10)
def tearDown(self) -> None:
self.service.terminate()
self.service.wait()
def test_server_startup(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_load_model(self):
"""
This method tests if the TTS model is loaded successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully")
except Exception as err:
print(err)
self.fail("LoadModel service failed")
finally:
self.tearDown()
def test_tts_generation(self):
"""
This method tests if TTS audio is generated successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
# Test TTS generation
tts_req = backend_pb2.TTSRequest(
text="Hello, this is a test of the MLX-Audio TTS system.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_heart",
language="a"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("TTS service failed")
finally:
self.tearDown()
def test_tts_with_options(self):
"""
This method tests if TTS works with various options and parameters
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(
Model="mlx-community/Kokoro-82M-4bit",
Options=["voice:af_soft", "speed:1.2", "lang_code:b"]
))
self.assertTrue(response.success)
# Test TTS generation with different voice and language
tts_req = backend_pb2.TTSRequest(
text="Hello, this is a test with British English accent.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_soft",
language="b"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("TTS with options service failed")
finally:
self.tearDown()
def test_tts_multilingual(self):
"""
This method tests if TTS works with different languages
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
self.assertTrue(response.success)
# Test Spanish TTS
tts_req = backend_pb2.TTSRequest(
text="Hola, esto es una prueba del sistema TTS MLX-Audio.",
model="mlx-community/Kokoro-82M-4bit",
voice="af_heart",
language="e"
)
tts_resp = stub.TTS(tts_req)
self.assertTrue(tts_resp.success)
self.assertIn("TTS audio generated successfully", tts_resp.message)
except Exception as err:
print(err)
self.fail("Multilingual TTS service failed")
finally:
self.tearDown()

View File

@@ -1,12 +0,0 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
runUnittests

View File

@@ -40,6 +40,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except ValueError:
return False
def _is_int(self, s):
"""Check if a string can be converted to int."""
try:
int(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
@@ -81,10 +89,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif self._is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"

View File

@@ -38,6 +38,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except ValueError:
return False
def _is_int(self, s):
"""Check if a string can be converted to int."""
try:
int(s)
return True
except ValueError:
return False
def Health(self, request, context):
"""
Returns a health check message.
@@ -79,10 +87,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Convert numeric values to appropriate types
if self._is_float(value):
if float(value).is_integer():
value = int(value)
else:
value = float(value)
value = float(value)
elif self._is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"

View File

@@ -1,3 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.7.1
llvmlite==0.43.0
numba==0.60.0

View File

@@ -0,0 +1,9 @@
torch==2.7.1
accelerate
llvmlite==0.43.0
numba==0.60.0
transformers
bitsandbytes
outetts
sentence-transformers==5.1.0
protobuf==6.32.0

View File

@@ -1,16 +0,0 @@
package main
import (
_ "embed"
"fyne.io/fyne/v2"
)
//go:embed logo.png
var logoData []byte
// resourceIconPng is the LocalAI logo icon
var resourceIconPng = &fyne.StaticResource{
StaticName: "logo.png",
StaticContent: logoData,
}

View File

@@ -1,866 +0,0 @@
package launcher
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
)
// Config represents the launcher configuration
type Config struct {
ModelsPath string `json:"models_path"`
BackendsPath string `json:"backends_path"`
Address string `json:"address"`
AutoStart bool `json:"auto_start"`
StartOnBoot bool `json:"start_on_boot"`
LogLevel string `json:"log_level"`
EnvironmentVars map[string]string `json:"environment_vars"`
ShowWelcome *bool `json:"show_welcome"`
}
// Launcher represents the main launcher application
type Launcher struct {
// Core components
releaseManager *ReleaseManager
config *Config
ui *LauncherUI
systray *SystrayManager
ctx context.Context
window fyne.Window
app fyne.App
// Process management
localaiCmd *exec.Cmd
isRunning bool
logBuffer *strings.Builder
logMutex sync.RWMutex
statusChannel chan string
// Logging
logFile *os.File
logPath string
// UI state
lastUpdateCheck time.Time
}
// NewLauncher creates a new launcher instance
func NewLauncher(ui *LauncherUI, window fyne.Window, app fyne.App) *Launcher {
return &Launcher{
releaseManager: NewReleaseManager(),
config: &Config{},
logBuffer: &strings.Builder{},
statusChannel: make(chan string, 100),
ctx: context.Background(),
ui: ui,
window: window,
app: app,
}
}
// setupLogging sets up log file for LocalAI process output
func (l *Launcher) setupLogging() error {
// Create logs directory in data folder
dataPath := l.GetDataPath()
logsDir := filepath.Join(dataPath, "logs")
if err := os.MkdirAll(logsDir, 0755); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
// Create log file with timestamp
timestamp := time.Now().Format("2006-01-02_15-04-05")
l.logPath = filepath.Join(logsDir, fmt.Sprintf("localai_%s.log", timestamp))
logFile, err := os.Create(l.logPath)
if err != nil {
return fmt.Errorf("failed to create log file: %w", err)
}
l.logFile = logFile
return nil
}
// Initialize sets up the launcher
func (l *Launcher) Initialize() error {
if l.app == nil {
return fmt.Errorf("app is nil")
}
log.Printf("Initializing launcher...")
// Setup logging
if err := l.setupLogging(); err != nil {
return fmt.Errorf("failed to setup logging: %w", err)
}
// Load configuration
log.Printf("Loading configuration...")
if err := l.loadConfig(); err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
log.Printf("Configuration loaded, current state: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
// Clean up any partial downloads
log.Printf("Cleaning up partial downloads...")
if err := l.releaseManager.CleanupPartialDownloads(); err != nil {
log.Printf("Warning: failed to cleanup partial downloads: %v", err)
}
if l.config.StartOnBoot {
l.StartLocalAI()
}
// Set default paths if not configured (only if not already loaded from config)
if l.config.ModelsPath == "" {
homeDir, _ := os.UserHomeDir()
l.config.ModelsPath = filepath.Join(homeDir, ".localai", "models")
log.Printf("Setting default ModelsPath: %s", l.config.ModelsPath)
}
if l.config.BackendsPath == "" {
homeDir, _ := os.UserHomeDir()
l.config.BackendsPath = filepath.Join(homeDir, ".localai", "backends")
log.Printf("Setting default BackendsPath: %s", l.config.BackendsPath)
}
if l.config.Address == "" {
l.config.Address = "127.0.0.1:8080"
log.Printf("Setting default Address: %s", l.config.Address)
}
if l.config.LogLevel == "" {
l.config.LogLevel = "info"
log.Printf("Setting default LogLevel: %s", l.config.LogLevel)
}
if l.config.EnvironmentVars == nil {
l.config.EnvironmentVars = make(map[string]string)
log.Printf("Initializing empty EnvironmentVars map")
}
// Set default welcome window preference
if l.config.ShowWelcome == nil {
true := true
l.config.ShowWelcome = &true
log.Printf("Setting default ShowWelcome: true")
}
// Create directories
os.MkdirAll(l.config.ModelsPath, 0755)
os.MkdirAll(l.config.BackendsPath, 0755)
// Save the configuration with default values
if err := l.saveConfig(); err != nil {
log.Printf("Warning: failed to save default configuration: %v", err)
}
// System tray is now handled in main.go using Fyne's built-in approach
// Check if LocalAI is installed
if !l.releaseManager.IsLocalAIInstalled() {
log.Printf("No LocalAI installation found")
fyne.Do(func() {
l.updateStatus("No LocalAI installation found")
if l.ui != nil {
// Show dialog offering to download LocalAI
l.showDownloadLocalAIDialog()
}
})
}
// Check for updates periodically
go l.periodicUpdateCheck()
return nil
}
// StartLocalAI starts the LocalAI server
func (l *Launcher) StartLocalAI() error {
if l.isRunning {
return fmt.Errorf("LocalAI is already running")
}
// Verify binary integrity before starting
if err := l.releaseManager.VerifyInstalledBinary(); err != nil {
// Binary is corrupted, remove it and offer to reinstall
binaryPath := l.releaseManager.GetBinaryPath()
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
return fmt.Errorf("LocalAI binary is corrupted: %v. Please reinstall LocalAI", err)
}
binaryPath := l.releaseManager.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return fmt.Errorf("LocalAI binary not found. Please download a release first")
}
// Build command arguments
args := []string{
"run",
"--models-path", l.config.ModelsPath,
"--backends-path", l.config.BackendsPath,
"--address", l.config.Address,
"--log-level", l.config.LogLevel,
}
l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...)
// Apply environment variables
if len(l.config.EnvironmentVars) > 0 {
env := os.Environ()
for key, value := range l.config.EnvironmentVars {
env = append(env, fmt.Sprintf("%s=%s", key, value))
}
l.localaiCmd.Env = env
}
// Setup logging
stdout, err := l.localaiCmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderr, err := l.localaiCmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
// Start the process
if err := l.localaiCmd.Start(); err != nil {
return fmt.Errorf("failed to start LocalAI: %w", err)
}
l.isRunning = true
fyne.Do(func() {
l.updateStatus("LocalAI is starting...")
l.updateRunningState(true)
})
// Start log monitoring
go l.monitorLogs(stdout, "STDOUT")
go l.monitorLogs(stderr, "STDERR")
// Monitor process with startup timeout
go func() {
// Wait for process to start or fail
err := l.localaiCmd.Wait()
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
if err != nil {
l.updateStatus(fmt.Sprintf("LocalAI stopped with error: %v", err))
} else {
l.updateStatus("LocalAI stopped")
}
})
}()
// Add startup timeout detection
go func() {
time.Sleep(10 * time.Second) // Wait 10 seconds for startup
if l.isRunning {
// Check if process is still alive
if l.localaiCmd.Process != nil {
if err := l.localaiCmd.Process.Signal(syscall.Signal(0)); err != nil {
// Process is dead, mark as not running
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
l.updateStatus("LocalAI failed to start properly")
})
}
}
}
}()
return nil
}
// StopLocalAI stops the LocalAI server
func (l *Launcher) StopLocalAI() error {
if !l.isRunning || l.localaiCmd == nil {
return fmt.Errorf("LocalAI is not running")
}
// Gracefully terminate the process
if err := l.localaiCmd.Process.Signal(os.Interrupt); err != nil {
// If graceful termination fails, force kill
if killErr := l.localaiCmd.Process.Kill(); killErr != nil {
return fmt.Errorf("failed to kill LocalAI process: %w", killErr)
}
}
l.isRunning = false
fyne.Do(func() {
l.updateRunningState(false)
l.updateStatus("LocalAI stopped")
})
return nil
}
// IsRunning returns whether LocalAI is currently running
func (l *Launcher) IsRunning() bool {
return l.isRunning
}
// Shutdown performs cleanup when the application is closing
func (l *Launcher) Shutdown() error {
log.Printf("Launcher shutting down, stopping LocalAI...")
// Stop LocalAI if it's running
if l.isRunning {
if err := l.StopLocalAI(); err != nil {
log.Printf("Error stopping LocalAI during shutdown: %v", err)
}
}
// Close log file if open
if l.logFile != nil {
if err := l.logFile.Close(); err != nil {
log.Printf("Error closing log file: %v", err)
}
l.logFile = nil
}
log.Printf("Launcher shutdown complete")
return nil
}
// GetLogs returns the current log buffer
func (l *Launcher) GetLogs() string {
l.logMutex.RLock()
defer l.logMutex.RUnlock()
return l.logBuffer.String()
}
// GetRecentLogs returns the most recent logs (last 50 lines) for better error display
func (l *Launcher) GetRecentLogs() string {
l.logMutex.RLock()
defer l.logMutex.RUnlock()
content := l.logBuffer.String()
lines := strings.Split(content, "\n")
// Get last 50 lines
if len(lines) > 50 {
lines = lines[len(lines)-50:]
}
return strings.Join(lines, "\n")
}
// GetConfig returns the current configuration
func (l *Launcher) GetConfig() *Config {
return l.config
}
// SetConfig updates the configuration
func (l *Launcher) SetConfig(config *Config) error {
l.config = config
return l.saveConfig()
}
func (l *Launcher) GetUI() *LauncherUI {
return l.ui
}
func (l *Launcher) SetSystray(systray *SystrayManager) {
l.systray = systray
}
// GetReleaseManager returns the release manager
func (l *Launcher) GetReleaseManager() *ReleaseManager {
return l.releaseManager
}
// GetWebUIURL returns the URL for the WebUI
func (l *Launcher) GetWebUIURL() string {
address := l.config.Address
if strings.HasPrefix(address, ":") {
address = "localhost" + address
}
if !strings.HasPrefix(address, "http") {
address = "http://" + address
}
return address
}
// GetDataPath returns the path where LocalAI data and logs are stored
func (l *Launcher) GetDataPath() string {
// LocalAI typically stores data in the current working directory or a models directory
// First check if models path is configured
if l.config != nil && l.config.ModelsPath != "" {
// Return the parent directory of models path
return filepath.Dir(l.config.ModelsPath)
}
// Fallback to home directory LocalAI folder
homeDir, err := os.UserHomeDir()
if err != nil {
return "."
}
return filepath.Join(homeDir, ".localai")
}
// CheckForUpdates checks if there are any available updates
func (l *Launcher) CheckForUpdates() (bool, string, error) {
log.Printf("CheckForUpdates: checking for available updates...")
available, version, err := l.releaseManager.IsUpdateAvailable()
if err != nil {
log.Printf("CheckForUpdates: error occurred: %v", err)
return false, "", err
}
log.Printf("CheckForUpdates: result - available=%v, version=%s", available, version)
l.lastUpdateCheck = time.Now()
return available, version, nil
}
// DownloadUpdate downloads the latest version
func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
return l.releaseManager.DownloadRelease(version, progressCallback)
}
// GetCurrentVersion returns the current installed version
func (l *Launcher) GetCurrentVersion() string {
return l.releaseManager.GetInstalledVersion()
}
// GetCurrentStatus returns the current status
func (l *Launcher) GetCurrentStatus() string {
select {
case status := <-l.statusChannel:
return status
default:
if l.isRunning {
return "LocalAI is running"
}
return "Ready"
}
}
// GetLastStatus returns the last known status without consuming from channel
func (l *Launcher) GetLastStatus() string {
if l.isRunning {
return "LocalAI is running"
}
// Check if LocalAI is installed
if !l.releaseManager.IsLocalAIInstalled() {
return "LocalAI not installed"
}
return "Ready"
}
func (l *Launcher) githubReleaseNotesURL(version string) (*url.URL, error) {
// Construct GitHub release URL
releaseURL := fmt.Sprintf("https://github.com/%s/%s/releases/tag/%s",
l.releaseManager.GitHubOwner,
l.releaseManager.GitHubRepo,
version)
// Convert string to *url.URL
return url.Parse(releaseURL)
}
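// For example, with the default "mudler"/"LocalAI" repository and an illustrative
// version "v3.4.0", this parses "https://github.com/mudler/LocalAI/releases/tag/v3.4.0".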
// showDownloadLocalAIDialog shows a dialog offering to download LocalAI
func (l *Launcher) showDownloadLocalAIDialog() {
if l.app == nil {
log.Printf("Cannot show download dialog: app is nil")
return
}
fyne.DoAndWait(func() {
// Create a standalone window for the download dialog
dialogWindow := l.app.NewWindow("LocalAI Installation Required")
dialogWindow.Resize(fyne.NewSize(500, 350))
dialogWindow.CenterOnScreen()
dialogWindow.SetCloseIntercept(func() {
dialogWindow.Close()
})
// Create the dialog content
titleLabel := widget.NewLabel("LocalAI Not Found")
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
messageLabel := widget.NewLabel("LocalAI is not installed on your system.\n\nWould you like to download and install the latest version?")
messageLabel.Wrapping = fyne.TextWrapWord
messageLabel.Alignment = fyne.TextAlignCenter
// Buttons
downloadButton := widget.NewButton("Download & Install", func() {
dialogWindow.Close()
l.downloadAndInstallLocalAI()
if l.systray != nil {
l.systray.recreateMenu()
}
})
downloadButton.Importance = widget.HighImportance
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
// Get latest release info and open release notes
go func() {
release, err := l.releaseManager.GetLatestRelease()
if err != nil {
log.Printf("Failed to get latest release info: %v", err)
return
}
releaseNotesURL, err := l.githubReleaseNotesURL(release.Version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
l.app.OpenURL(releaseNotesURL)
}()
})
skipButton := widget.NewButton("Skip for Now", func() {
dialogWindow.Close()
})
// Layout - put release notes button above the main action buttons
actionButtons := container.NewHBox(skipButton, downloadButton)
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
messageLabel,
widget.NewSeparator(),
releaseNotesButton,
widget.NewSeparator(),
actionButtons,
)
dialogWindow.SetContent(content)
dialogWindow.Show()
})
}
// downloadAndInstallLocalAI downloads and installs the latest LocalAI version
func (l *Launcher) downloadAndInstallLocalAI() {
if l.app == nil {
log.Printf("Cannot download LocalAI: app is nil")
return
}
// First check what the latest version is
go func() {
log.Printf("Checking for latest LocalAI version...")
available, version, err := l.CheckForUpdates()
if err != nil {
log.Printf("Failed to check for updates: %v", err)
l.showDownloadError("Failed to check for latest version", err.Error())
return
}
if !available {
log.Printf("No updates available, but LocalAI is not installed")
l.showDownloadError("No Version Available", "Could not determine the latest LocalAI version. Please check your internet connection and try again.")
return
}
log.Printf("Latest version available: %s", version)
// Show progress window with the specific version
l.showDownloadProgress(version, fmt.Sprintf("Downloading LocalAI %s...", version))
}()
}
// showDownloadError shows an error dialog for download failures
func (l *Launcher) showDownloadError(title, message string) {
fyne.DoAndWait(func() {
// Create error window
errorWindow := l.app.NewWindow("Download Error")
errorWindow.Resize(fyne.NewSize(400, 200))
errorWindow.CenterOnScreen()
errorWindow.SetCloseIntercept(func() {
errorWindow.Close()
})
// Error content
titleLabel := widget.NewLabel(title)
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
messageLabel := widget.NewLabel(message)
messageLabel.Wrapping = fyne.TextWrapWord
messageLabel.Alignment = fyne.TextAlignCenter
// Close button
closeButton := widget.NewButton("Close", func() {
errorWindow.Close()
})
// Layout
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
messageLabel,
widget.NewSeparator(),
closeButton,
)
errorWindow.SetContent(content)
errorWindow.Show()
})
}
// showDownloadProgress shows a standalone progress window for downloading LocalAI
func (l *Launcher) showDownloadProgress(version, title string) {
fyne.DoAndWait(func() {
// Create progress window
progressWindow := l.app.NewWindow("Downloading LocalAI")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
progressWindow.SetCloseIntercept(func() {
progressWindow.Close()
})
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := l.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
l.app.OpenURL(releaseNotesURL)
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(title),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := l.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show success dialog
dialog.ShowConfirm("Installation Complete",
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
func(close bool) {
progressWindow.Close()
// Update status and refresh systray menu
l.updateStatus("LocalAI installed successfully")
if l.systray != nil {
l.systray.recreateMenu()
}
}, progressWindow)
}
})
}()
})
}
// monitorLogs monitors the output of LocalAI and adds it to the log buffer
func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
timestamp := time.Now().Format("15:04:05")
logLine := fmt.Sprintf("[%s] %s: %s\n", timestamp, prefix, line)
l.logMutex.Lock()
l.logBuffer.WriteString(logLine)
// Keep log buffer size reasonable
if l.logBuffer.Len() > 100000 { // 100KB
content := l.logBuffer.String()
// Keep last 50KB
if len(content) > 50000 {
l.logBuffer.Reset()
l.logBuffer.WriteString(content[len(content)-50000:])
}
}
l.logMutex.Unlock()
// Write to log file if available
if l.logFile != nil {
if _, err := l.logFile.WriteString(logLine); err != nil {
log.Printf("Failed to write to log file: %v", err)
}
}
fyne.Do(func() {
// Notify UI of new log content
if l.ui != nil {
l.ui.OnLogUpdate(logLine)
}
// Check for startup completion
if strings.Contains(line, "API server listening") {
l.updateStatus("LocalAI is running")
}
})
}
}
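// Each captured line is prefixed with a timestamp and the caller-supplied stream name,
// e.g. a hypothetical entry: "[15:04:05] stdout: API server listening". The
// "API server listening" substring doubles as the startup-completion signal above.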
// updateStatus updates the status and notifies UI
func (l *Launcher) updateStatus(status string) {
select {
case l.statusChannel <- status:
default:
// Channel full, skip
}
if l.ui != nil {
l.ui.UpdateStatus(status)
}
if l.systray != nil {
l.systray.UpdateStatus(status)
}
}
// updateRunningState updates the running state in UI and systray
func (l *Launcher) updateRunningState(isRunning bool) {
if l.ui != nil {
l.ui.UpdateRunningState(isRunning)
}
if l.systray != nil {
l.systray.UpdateRunningState(isRunning)
}
}
// periodicUpdateCheck checks for updates periodically
func (l *Launcher) periodicUpdateCheck() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
available, version, err := l.CheckForUpdates()
if err == nil && available {
fyne.Do(func() {
l.updateStatus(fmt.Sprintf("Update available: %s", version))
if l.systray != nil {
l.systray.NotifyUpdateAvailable(version)
}
if l.ui != nil {
l.ui.NotifyUpdateAvailable(version)
}
})
}
case <-l.ctx.Done():
return
}
}
}
// loadConfig loads configuration from file
func (l *Launcher) loadConfig() error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
}
configPath := filepath.Join(homeDir, ".localai", "launcher.json")
log.Printf("Loading config from: %s", configPath)
if _, err := os.Stat(configPath); os.IsNotExist(err) {
log.Printf("Config file not found, creating default config")
// Create default config
return l.saveConfig()
}
// Load existing config
configData, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
log.Printf("Config file content: %s", string(configData))
log.Printf("loadConfig: about to unmarshal JSON data")
if err := json.Unmarshal(configData, l.config); err != nil {
return fmt.Errorf("failed to parse config file: %w", err)
}
log.Printf("loadConfig: JSON unmarshaled successfully")
log.Printf("Loaded config: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
log.Printf("Environment vars: %v", l.config.EnvironmentVars)
return nil
}
// saveConfig saves configuration to file
func (l *Launcher) saveConfig() error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
}
configDir := filepath.Join(homeDir, ".localai")
if err := os.MkdirAll(configDir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Marshal config to JSON
log.Printf("saveConfig: marshaling config with EnvironmentVars: %v", l.config.EnvironmentVars)
configData, err := json.MarshalIndent(l.config, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
log.Printf("saveConfig: JSON marshaled successfully, length: %d", len(configData))
configPath := filepath.Join(configDir, "launcher.json")
log.Printf("Saving config to: %s", configPath)
log.Printf("Config content: %s", string(configData))
if err := os.WriteFile(configPath, configData, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
log.Printf("Config saved successfully")
return nil
}
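// An illustrative launcher.json (the key names here are assumptions; the actual keys
// are defined by the Config struct's JSON tags):
//
//	{
//	  "models_path": "/home/user/.localai/models",
//	  "backends_path": "/home/user/.localai/backends",
//	  "address": "127.0.0.1:8080",
//	  "log_level": "info",
//	  "environment_vars": {}
//	}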

View File

@@ -1,13 +0,0 @@
package launcher_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLauncher(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Launcher Suite")
}

View File

@@ -1,213 +0,0 @@
package launcher_test
import (
"os"
"path/filepath"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"fyne.io/fyne/v2/app"
launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
var _ = Describe("Launcher", func() {
var (
launcherInstance *launcher.Launcher
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "launcher-test-*")
Expect(err).ToNot(HaveOccurred())
ui := launcher.NewLauncherUI()
app := app.NewWithID("com.localai.launcher")
launcherInstance = launcher.NewLauncher(ui, nil, app)
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
Describe("NewLauncher", func() {
It("should create a launcher with default configuration", func() {
Expect(launcherInstance.GetConfig()).ToNot(BeNil())
})
})
Describe("Initialize", func() {
It("should set default paths when not configured", func() {
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
config := launcherInstance.GetConfig()
Expect(config.ModelsPath).ToNot(BeEmpty())
Expect(config.BackendsPath).ToNot(BeEmpty())
})
It("should set default ShowWelcome to true", func() {
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
config := launcherInstance.GetConfig()
Expect(config.ShowWelcome).To(BeTrue())
Expect(config.Address).To(Equal("127.0.0.1:8080"))
Expect(config.LogLevel).To(Equal("info"))
})
It("should create models and backends directories", func() {
// Set custom paths for testing
config := launcherInstance.GetConfig()
config.ModelsPath = filepath.Join(tempDir, "models")
config.BackendsPath = filepath.Join(tempDir, "backends")
launcherInstance.SetConfig(config)
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
// Check if directories were created
_, err = os.Stat(config.ModelsPath)
Expect(err).ToNot(HaveOccurred())
_, err = os.Stat(config.BackendsPath)
Expect(err).ToNot(HaveOccurred())
})
})
Describe("Configuration", func() {
It("should get and set configuration", func() {
config := launcherInstance.GetConfig()
config.ModelsPath = "/test/models"
config.BackendsPath = "/test/backends"
config.Address = ":9090"
config.LogLevel = "debug"
err := launcherInstance.SetConfig(config)
Expect(err).ToNot(HaveOccurred())
retrievedConfig := launcherInstance.GetConfig()
Expect(retrievedConfig.ModelsPath).To(Equal("/test/models"))
Expect(retrievedConfig.BackendsPath).To(Equal("/test/backends"))
Expect(retrievedConfig.Address).To(Equal(":9090"))
Expect(retrievedConfig.LogLevel).To(Equal("debug"))
})
})
Describe("WebUI URL", func() {
It("should return correct WebUI URL for localhost", func() {
config := launcherInstance.GetConfig()
config.Address = ":8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://localhost:8080"))
})
It("should return correct WebUI URL for full address", func() {
config := launcherInstance.GetConfig()
config.Address = "127.0.0.1:8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://127.0.0.1:8080"))
})
It("should handle http prefix correctly", func() {
config := launcherInstance.GetConfig()
config.Address = "http://localhost:8080"
launcherInstance.SetConfig(config)
url := launcherInstance.GetWebUIURL()
Expect(url).To(Equal("http://localhost:8080"))
})
})
Describe("Process Management", func() {
It("should not be running initially", func() {
Expect(launcherInstance.IsRunning()).To(BeFalse())
})
It("should handle start when binary doesn't exist", func() {
err := launcherInstance.StartLocalAI()
Expect(err).To(HaveOccurred())
// Could be either "not found" or "permission denied" depending on test environment
errMsg := err.Error()
hasExpectedError := strings.Contains(errMsg, "LocalAI binary") ||
strings.Contains(errMsg, "permission denied")
Expect(hasExpectedError).To(BeTrue(), "Expected error about binary not found or permission denied, got: %s", errMsg)
})
It("should handle stop when not running", func() {
err := launcherInstance.StopLocalAI()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LocalAI is not running"))
})
})
Describe("Logs", func() {
It("should return empty logs initially", func() {
logs := launcherInstance.GetLogs()
Expect(logs).To(BeEmpty())
})
})
Describe("Version Management", func() {
It("should return empty version when no binary installed", func() {
version := launcherInstance.GetCurrentVersion()
Expect(version).To(BeEmpty()) // No binary installed in test environment
})
It("should handle update checks", func() {
// This test would require mocking HTTP responses
// For now, we'll just test that the method doesn't panic
_, _, err := launcherInstance.CheckForUpdates()
// We expect either success or a network error, not a panic
if err != nil {
// Network error is acceptable in tests
Expect(err.Error()).To(ContainSubstring("failed to fetch"))
}
})
})
})
var _ = Describe("Config", func() {
It("should have proper JSON tags", func() {
config := &launcher.Config{
ModelsPath: "/test/models",
BackendsPath: "/test/backends",
Address: ":8080",
AutoStart: true,
LogLevel: "info",
EnvironmentVars: map[string]string{"TEST": "value"},
}
Expect(config.ModelsPath).To(Equal("/test/models"))
Expect(config.BackendsPath).To(Equal("/test/backends"))
Expect(config.Address).To(Equal(":8080"))
Expect(config.AutoStart).To(BeTrue())
Expect(config.LogLevel).To(Equal("info"))
Expect(config.EnvironmentVars).To(HaveKeyWithValue("TEST", "value"))
})
It("should initialize environment variables map", func() {
config := &launcher.Config{}
Expect(config.EnvironmentVars).To(BeNil())
ui := launcher.NewLauncherUI()
app := app.NewWithID("com.localai.launcher")
launcherInstance := launcher.NewLauncher(ui, nil, app)
err := launcherInstance.Initialize()
Expect(err).ToNot(HaveOccurred())
retrievedConfig := launcherInstance.GetConfig()
Expect(retrievedConfig.EnvironmentVars).ToNot(BeNil())
Expect(retrievedConfig.EnvironmentVars).To(BeEmpty())
})
})

View File

@@ -1,502 +0,0 @@
package launcher
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/mudler/LocalAI/internal"
)
// Release represents a LocalAI release
type Release struct {
Version string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
PublishedAt time.Time `json:"published_at"`
Assets []Asset `json:"assets"`
}
// Asset represents a release asset
type Asset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
}
// ReleaseManager handles LocalAI release management
type ReleaseManager struct {
// GitHubOwner is the GitHub repository owner
GitHubOwner string
// GitHubRepo is the GitHub repository name
GitHubRepo string
// BinaryPath is where the LocalAI binary is stored locally
BinaryPath string
// CurrentVersion is the currently installed version
CurrentVersion string
// ChecksumsPath is where checksums are stored
ChecksumsPath string
// MetadataPath is where version metadata is stored
MetadataPath string
}
// NewReleaseManager creates a new release manager
func NewReleaseManager() *ReleaseManager {
homeDir, _ := os.UserHomeDir()
binaryPath := filepath.Join(homeDir, ".localai", "bin")
checksumsPath := filepath.Join(homeDir, ".localai", "checksums")
metadataPath := filepath.Join(homeDir, ".localai", "metadata")
return &ReleaseManager{
GitHubOwner: "mudler",
GitHubRepo: "LocalAI",
BinaryPath: binaryPath,
CurrentVersion: internal.PrintableVersion(),
ChecksumsPath: checksumsPath,
MetadataPath: metadataPath,
}
}
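// The resulting on-disk layout under the user's home directory:
//
//	~/.localai/bin/        downloaded local-ai binary
//	~/.localai/checksums/  persisted per-version checksum files
//	~/.localai/metadata/   installed-version.json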
// GetLatestRelease fetches the latest release information from GitHub
func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to fetch latest release: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
}
// Parse the JSON response properly
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
release := &Release{}
if err := json.Unmarshal(body, release); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
// Validate the release data
if release.Version == "" {
return nil, fmt.Errorf("no version found in release data")
}
return release, nil
}
// DownloadRelease downloads a specific version of LocalAI
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
// Ensure the binary directory exists
if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
return fmt.Errorf("failed to create binary directory: %w", err)
}
// Determine the binary name based on OS and architecture
binaryName := rm.GetBinaryName(version)
localPath := filepath.Join(rm.BinaryPath, "local-ai")
// Download the binary
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
return fmt.Errorf("failed to download binary: %w", err)
}
// Download and verify checksums
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
rm.GitHubOwner, rm.GitHubRepo, version, version)
checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
if err := rm.downloadFile(checksumURL, checksumPath, nil); err != nil {
return fmt.Errorf("failed to download checksums: %w", err)
}
// Verify the checksum
if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
return fmt.Errorf("checksum verification failed: %w", err)
}
// Save checksums persistently for future verification
if err := rm.saveChecksums(version, checksumPath, binaryName); err != nil {
log.Printf("Warning: failed to save checksums: %v", err)
}
// Make the binary executable
if err := os.Chmod(localPath, 0755); err != nil {
return fmt.Errorf("failed to make binary executable: %w", err)
}
return nil
}
// GetBinaryName returns the appropriate binary name for the current platform
func (rm *ReleaseManager) GetBinaryName(version string) string {
versionStr := strings.TrimPrefix(version, "v")
os := runtime.GOOS
arch := runtime.GOARCH
// Map Go arch names to the release naming convention
switch arch {
case "amd64":
arch = "amd64"
case "arm64":
arch = "arm64"
default:
arch = "amd64" // fallback
}
return fmt.Sprintf("local-ai-v%s-%s-%s", versionStr, os, arch)
}
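// For example, GetBinaryName("v3.4.0") on a linux/amd64 host returns
// "local-ai-v3.4.0-linux-amd64"; the "v" prefix is stripped first, so "3.4.0"
// produces the same result.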
// downloadFile downloads a file from a URL to a local path with optional progress callback
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
}
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
// Create a progress reader if callback is provided
var reader io.Reader = resp.Body
if progressCallback != nil && resp.ContentLength > 0 {
reader = &progressReader{
Reader: resp.Body,
Total: resp.ContentLength,
Callback: progressCallback,
}
}
_, err = io.Copy(out, reader)
return err
}
// saveChecksums saves checksums persistently for future verification
func (rm *ReleaseManager) saveChecksums(version, checksumPath, binaryName string) error {
// Ensure checksums directory exists
if err := os.MkdirAll(rm.ChecksumsPath, 0755); err != nil {
return fmt.Errorf("failed to create checksums directory: %w", err)
}
// Read the downloaded checksums file
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return fmt.Errorf("failed to read checksums file: %w", err)
}
// Save to persistent location with version info
persistentPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
if err := os.WriteFile(persistentPath, checksumData, 0644); err != nil {
return fmt.Errorf("failed to write persistent checksums: %w", err)
}
// Also save a "latest" checksums file for the current version
latestPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
if err := os.WriteFile(latestPath, checksumData, 0644); err != nil {
return fmt.Errorf("failed to write latest checksums: %w", err)
}
// Save version metadata
if err := rm.saveVersionMetadata(version); err != nil {
log.Printf("Warning: failed to save version metadata: %v", err)
}
log.Printf("Checksums saved for version %s", version)
return nil
}
// saveVersionMetadata saves the installed version information
func (rm *ReleaseManager) saveVersionMetadata(version string) error {
// Ensure metadata directory exists
if err := os.MkdirAll(rm.MetadataPath, 0755); err != nil {
return fmt.Errorf("failed to create metadata directory: %w", err)
}
// Create metadata structure
metadata := struct {
Version string `json:"version"`
InstalledAt time.Time `json:"installed_at"`
BinaryPath string `json:"binary_path"`
}{
Version: version,
InstalledAt: time.Now(),
BinaryPath: rm.GetBinaryPath(),
}
// Marshal to JSON
metadataData, err := json.MarshalIndent(metadata, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
// Save metadata file
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
if err := os.WriteFile(metadataPath, metadataData, 0644); err != nil {
return fmt.Errorf("failed to write metadata file: %w", err)
}
log.Printf("Version metadata saved: %s", version)
return nil
}
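// An illustrative installed-version.json (the values shown are hypothetical):
//
//	{
//	  "version": "v3.4.0",
//	  "installed_at": "2025-08-24T16:48:33Z",
//	  "binary_path": "/home/user/.localai/bin/local-ai"
//	}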
// progressReader wraps an io.Reader to provide download progress
type progressReader struct {
io.Reader
Total int64
Current int64
Callback func(float64)
}
func (pr *progressReader) Read(p []byte) (int, error) {
n, err := pr.Reader.Read(p)
pr.Current += int64(n)
if pr.Callback != nil {
progress := float64(pr.Current) / float64(pr.Total)
pr.Callback(progress)
}
return n, err
}
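// Note: progress is reported as a 0.0-1.0 fraction of Total; downloadFile above only
// installs this wrapper when the response advertises a positive Content-Length, so
// Total is non-zero whenever Callback fires.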
// VerifyChecksum verifies the downloaded file against the provided checksums
func (rm *ReleaseManager) VerifyChecksum(filePath, checksumPath, binaryName string) error {
// Calculate the SHA256 of the downloaded file
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file for checksum: %w", err)
}
defer file.Close()
hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return fmt.Errorf("failed to calculate checksum: %w", err)
}
calculatedHash := hex.EncodeToString(hasher.Sum(nil))
// Read the checksums file
checksumFile, err := os.Open(checksumPath)
if err != nil {
return fmt.Errorf("failed to open checksums file: %w", err)
}
defer checksumFile.Close()
scanner := bufio.NewScanner(checksumFile)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, binaryName) {
parts := strings.Fields(line)
if len(parts) >= 2 {
expectedHash := parts[0]
if calculatedHash == expectedHash {
return nil // Checksum verified
}
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, calculatedHash)
}
}
}
return fmt.Errorf("checksum not found for %s", binaryName)
}
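// The checksums file is expected to hold one "<sha256>  <asset-name>" pair per line,
// e.g. (hash truncated for illustration):
//
//	9a8b7c6d...  local-ai-v3.4.0-linux-amd64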
// GetInstalledVersion returns the currently installed version
func (rm *ReleaseManager) GetInstalledVersion() string {
// Check whether the LocalAI binary exists before trying to determine its version
binaryPath := rm.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return "" // No version installed
}
// try to get version from metadata
if version := rm.loadVersionMetadata(); version != "" {
return version
}
// Try to run the binary to get the version (fallback method)
version, err := exec.Command(binaryPath, "--version").Output()
if err != nil {
// If binary exists but --version fails, try to determine from filename or other means
log.Printf("Binary exists but --version failed: %v", err)
return ""
}
// TrimSpace already removes the trailing newline from the --version output
return strings.TrimSpace(string(version))
}
// loadVersionMetadata loads the installed version from metadata file
func (rm *ReleaseManager) loadVersionMetadata() string {
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
// Check if metadata file exists
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
return ""
}
// Read metadata file
metadataData, err := os.ReadFile(metadataPath)
if err != nil {
log.Printf("Failed to read metadata file: %v", err)
return ""
}
// Parse metadata
var metadata struct {
Version string `json:"version"`
InstalledAt time.Time `json:"installed_at"`
BinaryPath string `json:"binary_path"`
}
if err := json.Unmarshal(metadataData, &metadata); err != nil {
log.Printf("Failed to parse metadata file: %v", err)
return ""
}
// Verify that the binary path in metadata matches current binary path
if metadata.BinaryPath != rm.GetBinaryPath() {
log.Printf("Binary path mismatch in metadata, ignoring")
return ""
}
log.Printf("Loaded version from metadata: %s (installed at %s)", metadata.Version, metadata.InstalledAt.Format("2006-01-02 15:04:05"))
return metadata.Version
}
// GetBinaryPath returns the path to the LocalAI binary
func (rm *ReleaseManager) GetBinaryPath() string {
return filepath.Join(rm.BinaryPath, "local-ai")
}
// IsUpdateAvailable checks if an update is available
func (rm *ReleaseManager) IsUpdateAvailable() (bool, string, error) {
log.Printf("IsUpdateAvailable: checking for updates...")
latest, err := rm.GetLatestRelease()
if err != nil {
log.Printf("IsUpdateAvailable: failed to get latest release: %v", err)
return false, "", err
}
log.Printf("IsUpdateAvailable: latest release version: %s", latest.Version)
current := rm.GetInstalledVersion()
log.Printf("IsUpdateAvailable: current installed version: %s", current)
if current == "" {
// No version installed, offer to download latest
log.Printf("IsUpdateAvailable: no version installed, offering latest: %s", latest.Version)
return true, latest.Version, nil
}
updateAvailable := latest.Version != current
log.Printf("IsUpdateAvailable: update available: %v (latest: %s, current: %s)", updateAvailable, latest.Version, current)
return updateAvailable, latest.Version, nil
}
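// Note: the comparison is a plain string inequality against the GitHub tag, so any
// installed version that differs from the latest release (including a locally newer
// build) is reported as having an update available.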
// IsLocalAIInstalled checks if LocalAI binary exists and is valid
func (rm *ReleaseManager) IsLocalAIInstalled() bool {
binaryPath := rm.GetBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
return false
}
// Verify the binary integrity
if err := rm.VerifyInstalledBinary(); err != nil {
log.Printf("Binary integrity check failed: %v", err)
// Remove corrupted binary
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
return false
}
return true
}
// VerifyInstalledBinary verifies the installed binary against saved checksums
func (rm *ReleaseManager) VerifyInstalledBinary() error {
binaryPath := rm.GetBinaryPath()
// Check if we have saved checksums
latestChecksumsPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
if _, err := os.Stat(latestChecksumsPath); os.IsNotExist(err) {
return fmt.Errorf("no saved checksums found")
}
// Get the binary name for the current version from metadata
currentVersion := rm.loadVersionMetadata()
if currentVersion == "" {
return fmt.Errorf("cannot determine current version from metadata")
}
binaryName := rm.GetBinaryName(currentVersion)
// Verify against saved checksums
return rm.VerifyChecksum(binaryPath, latestChecksumsPath, binaryName)
}
// CleanupPartialDownloads removes any partial or corrupted downloads
func (rm *ReleaseManager) CleanupPartialDownloads() error {
binaryPath := rm.GetBinaryPath()
// Check if binary exists but is corrupted
if _, err := os.Stat(binaryPath); err == nil {
// Binary exists, verify it
if verifyErr := rm.VerifyInstalledBinary(); verifyErr != nil {
log.Printf("Found corrupted binary, removing: %v", verifyErr)
if removeErr := os.Remove(binaryPath); removeErr != nil {
log.Printf("Failed to remove corrupted binary: %v", removeErr)
}
// Clear metadata since binary is corrupted
rm.clearVersionMetadata()
}
}
// Clean up any temporary checksum files
tempChecksumsPath := filepath.Join(rm.BinaryPath, "checksums.txt")
if _, err := os.Stat(tempChecksumsPath); err == nil {
if removeErr := os.Remove(tempChecksumsPath); removeErr != nil {
log.Printf("Failed to remove temporary checksums: %v", removeErr)
}
}
return nil
}
// clearVersionMetadata clears the version metadata (used when binary is corrupted or removed)
func (rm *ReleaseManager) clearVersionMetadata() {
metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
log.Printf("Failed to clear version metadata: %v", err)
} else {
log.Printf("Version metadata cleared")
}
}

View File

@@ -1,178 +0,0 @@
package launcher_test
import (
"os"
"path/filepath"
"runtime"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
var _ = Describe("ReleaseManager", func() {
var (
rm *launcher.ReleaseManager
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "launcher-test-*")
Expect(err).ToNot(HaveOccurred())
rm = launcher.NewReleaseManager()
// Override binary path for testing
rm.BinaryPath = tempDir
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
Describe("NewReleaseManager", func() {
It("should create a release manager with correct defaults", func() {
newRM := launcher.NewReleaseManager()
Expect(newRM.GitHubOwner).To(Equal("mudler"))
Expect(newRM.GitHubRepo).To(Equal("LocalAI"))
Expect(newRM.BinaryPath).To(ContainSubstring(".localai"))
})
})
Describe("GetBinaryName", func() {
It("should return correct binary name for current platform", func() {
binaryName := rm.GetBinaryName("v3.4.0")
expectedOS := runtime.GOOS
expectedArch := runtime.GOARCH
expected := "local-ai-v3.4.0-" + expectedOS + "-" + expectedArch
Expect(binaryName).To(Equal(expected))
})
It("should handle version with and without 'v' prefix", func() {
withV := rm.GetBinaryName("v3.4.0")
withoutV := rm.GetBinaryName("3.4.0")
// Both should produce the same result
Expect(withV).To(Equal(withoutV))
})
})
Describe("GetBinaryPath", func() {
It("should return the correct binary path", func() {
path := rm.GetBinaryPath()
expected := filepath.Join(tempDir, "local-ai")
Expect(path).To(Equal(expected))
})
})
Describe("GetInstalledVersion", func() {
It("should return empty when no binary exists", func() {
version := rm.GetInstalledVersion()
Expect(version).To(BeEmpty()) // No binary installed in test
})
It("should return empty version when binary exists but no metadata", func() {
// Create a fake binary for testing
err := os.MkdirAll(rm.BinaryPath, 0755)
Expect(err).ToNot(HaveOccurred())
binaryPath := rm.GetBinaryPath()
err = os.WriteFile(binaryPath, []byte("fake binary"), 0755)
Expect(err).ToNot(HaveOccurred())
version := rm.GetInstalledVersion()
Expect(version).To(BeEmpty())
})
})
Context("with mocked responses", func() {
// Note: In a real implementation, we'd mock HTTP responses
// For now, we'll test the structure and error handling
Describe("GetLatestRelease", func() {
It("should handle network errors gracefully", func() {
// This test would require mocking HTTP client
// For demonstration, we're just testing the method exists
_, err := rm.GetLatestRelease()
// We expect either success or a network error, not a panic
// In a real test, we'd mock the HTTP response
if err != nil {
Expect(err.Error()).To(ContainSubstring("failed to fetch"))
}
})
})
Describe("DownloadRelease", func() {
It("should create binary directory if it doesn't exist", func() {
// Remove the temp directory to test creation
os.RemoveAll(tempDir)
// This will fail due to network, but should create the directory
rm.DownloadRelease("v3.4.0", nil)
// Check if directory was created
_, err := os.Stat(tempDir)
Expect(err).ToNot(HaveOccurred())
})
})
})
Describe("VerifyChecksum functionality", func() {
var (
testFile string
checksumFile string
)
BeforeEach(func() {
testFile = filepath.Join(tempDir, "test-binary")
checksumFile = filepath.Join(tempDir, "checksums.txt")
})
It("should verify checksums correctly", func() {
// Create a test file with known content
testContent := []byte("test content for checksum")
err := os.WriteFile(testFile, testContent, 0644)
Expect(err).ToNot(HaveOccurred())
// Calculate expected SHA256
// This is a simplified test - in practice we'd use the actual checksum
checksumContent := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 test-binary\n"
err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
Expect(err).ToNot(HaveOccurred())
// Test checksum verification
// Note: This will fail because our content doesn't match the empty string hash
// In a real test, we'd calculate the actual hash
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
// We expect this to fail since we're using a dummy checksum
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("checksum mismatch"))
})
It("should handle missing checksum file", func() {
// Create test file but no checksum file
err := os.WriteFile(testFile, []byte("test"), 0644)
Expect(err).ToNot(HaveOccurred())
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to open checksums file"))
})
It("should handle missing binary in checksums", func() {
// Create files but checksum doesn't contain our binary
err := os.WriteFile(testFile, []byte("test"), 0644)
Expect(err).ToNot(HaveOccurred())
checksumContent := "hash other-binary\n"
err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
Expect(err).ToNot(HaveOccurred())
err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("checksum not found"))
})
})
})

View File

@@ -1,523 +0,0 @@
package launcher
import (
"fmt"
"log"
"net/url"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/driver/desktop"
"fyne.io/fyne/v2/widget"
)
// SystrayManager manages the system tray functionality
type SystrayManager struct {
launcher *Launcher
window fyne.Window
app fyne.App
desk desktop.App
// Menu items that need dynamic updates
startStopItem *fyne.MenuItem
hasUpdateAvailable bool
latestVersion string
icon *fyne.StaticResource
}
// NewSystrayManager creates a new systray manager
func NewSystrayManager(launcher *Launcher, window fyne.Window, desktop desktop.App, app fyne.App, icon *fyne.StaticResource) *SystrayManager {
sm := &SystrayManager{
launcher: launcher,
window: window,
app: app,
desk: desktop,
icon: icon,
}
sm.setupMenu(desktop)
return sm
}
// setupMenu sets up the system tray menu
func (sm *SystrayManager) setupMenu(desk desktop.App) {
sm.desk = desk
// Create the start/stop toggle item
sm.startStopItem = fyne.NewMenuItem("Start LocalAI", func() {
sm.toggleLocalAI()
})
desk.SetSystemTrayIcon(sm.icon)
// Initialize the menu state using recreateMenu
sm.recreateMenu()
}
// toggleLocalAI starts or stops LocalAI based on current state
func (sm *SystrayManager) toggleLocalAI() {
if sm.launcher.IsRunning() {
go func() {
if err := sm.launcher.StopLocalAI(); err != nil {
log.Printf("Failed to stop LocalAI: %v", err)
sm.showErrorDialog("Failed to Stop LocalAI", err.Error())
}
}()
} else {
go func() {
if err := sm.launcher.StartLocalAI(); err != nil {
log.Printf("Failed to start LocalAI: %v", err)
sm.showStartupErrorDialog(err)
}
}()
}
}
// openWebUI opens the LocalAI WebUI in the default browser
func (sm *SystrayManager) openWebUI() {
if !sm.launcher.IsRunning() {
return // LocalAI is not running
}
webURL := sm.launcher.GetWebUIURL()
if parsedURL, err := url.Parse(webURL); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// openDocumentation opens the LocalAI documentation
func (sm *SystrayManager) openDocumentation() {
if parsedURL, err := url.Parse("https://localai.io"); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// updateStartStopItem updates the start/stop menu item based on current state
func (sm *SystrayManager) updateStartStopItem() {
// Since Fyne menu items can't change text dynamically, we recreate the menu
sm.recreateMenu()
}
// recreateMenu recreates the entire menu with updated state
func (sm *SystrayManager) recreateMenu() {
if sm.desk == nil {
return
}
// Determine the action based on LocalAI installation and running state
var actionItem *fyne.MenuItem
if !sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
// LocalAI not installed - show install option
actionItem = fyne.NewMenuItem("📥 Install Latest Version", func() {
sm.launcher.showDownloadLocalAIDialog()
})
} else if sm.launcher.IsRunning() {
// LocalAI is running - show stop option
actionItem = fyne.NewMenuItem("🛑 Stop LocalAI", func() {
sm.toggleLocalAI()
})
} else {
// LocalAI is installed but not running - show start option
actionItem = fyne.NewMenuItem("▶️ Start LocalAI", func() {
sm.toggleLocalAI()
})
}
menuItems := []*fyne.MenuItem{}
// Add status at the top (clickable for details)
status := sm.launcher.GetLastStatus()
statusText := sm.truncateText(status, 30)
statusItem := fyne.NewMenuItem("📊 Status: "+statusText, func() {
sm.showStatusDetails(status, "")
})
menuItems = append(menuItems, statusItem)
// Only show version if LocalAI is installed
if sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
version := sm.launcher.GetCurrentVersion()
versionText := sm.truncateText(version, 25)
versionItem := fyne.NewMenuItem("🔧 Version: "+versionText, func() {
sm.showStatusDetails(status, version)
})
menuItems = append(menuItems, versionItem)
}
menuItems = append(menuItems, fyne.NewMenuItemSeparator())
// Add update notification if available
if sm.hasUpdateAvailable {
updateItem := fyne.NewMenuItem("🔔 New version available ("+sm.latestVersion+")", func() {
sm.downloadUpdate()
})
menuItems = append(menuItems, updateItem)
menuItems = append(menuItems, fyne.NewMenuItemSeparator())
}
// Core actions
menuItems = append(menuItems,
actionItem,
)
// Only show WebUI option if LocalAI is installed
if sm.launcher.GetReleaseManager().IsLocalAIInstalled() && sm.launcher.IsRunning() {
menuItems = append(menuItems,
fyne.NewMenuItem("Open WebUI", func() {
sm.openWebUI()
}),
)
}
menuItems = append(menuItems,
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Check for Updates", func() {
sm.checkForUpdates()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Settings", func() {
sm.showSettings()
}),
fyne.NewMenuItem("Show Welcome Window", func() {
sm.showWelcomeWindow()
}),
fyne.NewMenuItem("Open Data Folder", func() {
sm.openDataFolder()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Documentation", func() {
sm.openDocumentation()
}),
fyne.NewMenuItemSeparator(),
fyne.NewMenuItem("Quit", func() {
// Perform cleanup before quitting
if err := sm.launcher.Shutdown(); err != nil {
log.Printf("Error during shutdown: %v", err)
}
sm.app.Quit()
}),
)
menu := fyne.NewMenu("LocalAI", menuItems...)
sm.desk.SetSystemTrayMenu(menu)
}
// UpdateRunningState updates the systray based on running state
func (sm *SystrayManager) UpdateRunningState(isRunning bool) {
sm.updateStartStopItem()
}
// UpdateStatus updates the systray menu to reflect status changes
func (sm *SystrayManager) UpdateStatus(status string) {
sm.recreateMenu()
}
// checkForUpdates checks for available updates
func (sm *SystrayManager) checkForUpdates() {
go func() {
log.Printf("Checking for updates...")
available, version, err := sm.launcher.CheckForUpdates()
if err != nil {
log.Printf("Failed to check for updates: %v", err)
return
}
log.Printf("Update check result: available=%v, version=%s", available, version)
if available {
sm.hasUpdateAvailable = true
sm.latestVersion = version
sm.recreateMenu()
}
}()
}
// downloadUpdate downloads the latest update
func (sm *SystrayManager) downloadUpdate() {
if !sm.hasUpdateAvailable {
return
}
// Show progress window
sm.showDownloadProgress(sm.latestVersion)
}
// showSettings shows the settings window
func (sm *SystrayManager) showSettings() {
sm.window.Show()
sm.window.RequestFocus()
}
// showWelcomeWindow shows the welcome window
func (sm *SystrayManager) showWelcomeWindow() {
if sm.launcher.GetUI() != nil {
sm.launcher.GetUI().ShowWelcomeWindow()
}
}
// openDataFolder opens the data folder in file manager
func (sm *SystrayManager) openDataFolder() {
dataPath := sm.launcher.GetDataPath()
if parsedURL, err := url.Parse("file://" + dataPath); err == nil {
sm.app.OpenURL(parsedURL)
}
}
// NotifyUpdateAvailable sets update notification in systray
func (sm *SystrayManager) NotifyUpdateAvailable(version string) {
sm.hasUpdateAvailable = true
sm.latestVersion = version
sm.recreateMenu()
}
// truncateText truncates text to specified length and adds ellipsis if needed
func (sm *SystrayManager) truncateText(text string, maxLength int) string {
if len(text) <= maxLength {
return text
}
// Guard against very small limits so the slice below cannot go out of range
if maxLength <= 3 {
return text[:maxLength]
}
return text[:maxLength-3] + "..."
}
// showStatusDetails shows a detailed status window with full information
func (sm *SystrayManager) showStatusDetails(status, version string) {
fyne.DoAndWait(func() {
// Create status details window
statusWindow := sm.app.NewWindow("LocalAI Status Details")
statusWindow.Resize(fyne.NewSize(500, 400))
statusWindow.CenterOnScreen()
// Status information
statusLabel := widget.NewLabel("Current Status:")
statusValue := widget.NewLabel(status)
statusValue.Wrapping = fyne.TextWrapWord
// Version information (only show if version exists)
var versionContainer fyne.CanvasObject
if version != "" {
versionLabel := widget.NewLabel("Installed Version:")
versionValue := widget.NewLabel(version)
versionValue.Wrapping = fyne.TextWrapWord
versionContainer = container.NewVBox(versionLabel, versionValue)
}
// Running state
runningLabel := widget.NewLabel("Running State:")
runningValue := widget.NewLabel("")
if sm.launcher.IsRunning() {
runningValue.SetText("🟢 Running")
} else {
runningValue.SetText("🔴 Stopped")
}
// WebUI URL
webuiLabel := widget.NewLabel("WebUI URL:")
webuiValue := widget.NewLabel(sm.launcher.GetWebUIURL())
webuiValue.Wrapping = fyne.TextWrapWord
// Recent logs (last 20 lines)
logsLabel := widget.NewLabel("Recent Logs:")
logsText := widget.NewMultiLineEntry()
logsText.SetText(sm.launcher.GetRecentLogs())
logsText.Wrapping = fyne.TextWrapWord
logsText.Disable() // Make it read-only
// Buttons
closeButton := widget.NewButton("Close", func() {
statusWindow.Close()
})
refreshButton := widget.NewButton("Refresh", func() {
// Refresh the status information
statusValue.SetText(sm.launcher.GetLastStatus())
// Note: Version refresh is not implemented for simplicity
// The version will be updated when the status details window is reopened
if sm.launcher.IsRunning() {
runningValue.SetText("🟢 Running")
} else {
runningValue.SetText("🔴 Stopped")
}
logsText.SetText(sm.launcher.GetRecentLogs())
})
openWebUIButton := widget.NewButton("Open WebUI", func() {
sm.openWebUI()
})
// Layout
buttons := container.NewHBox(closeButton, refreshButton, openWebUIButton)
// Build info container dynamically
infoItems := []fyne.CanvasObject{
statusLabel, statusValue,
widget.NewSeparator(),
}
// Add version section if it exists
if versionContainer != nil {
infoItems = append(infoItems, versionContainer, widget.NewSeparator())
}
infoItems = append(infoItems,
runningLabel, runningValue,
widget.NewSeparator(),
webuiLabel, webuiValue,
)
infoContainer := container.NewVBox(infoItems...)
content := container.NewVBox(
infoContainer,
widget.NewSeparator(),
logsLabel,
logsText,
widget.NewSeparator(),
buttons,
)
statusWindow.SetContent(content)
statusWindow.Show()
})
}
// showErrorDialog shows a simple error dialog
func (sm *SystrayManager) showErrorDialog(title, message string) {
fyne.DoAndWait(func() {
dialog.ShowError(fmt.Errorf("%s", message), sm.window)
})
}
// showStartupErrorDialog shows a detailed error dialog with process logs
func (sm *SystrayManager) showStartupErrorDialog(err error) {
fyne.DoAndWait(func() {
// Get the recent process logs (more useful for debugging)
logs := sm.launcher.GetRecentLogs()
// Create error window
errorWindow := sm.app.NewWindow("LocalAI Startup Failed")
errorWindow.Resize(fyne.NewSize(600, 500))
errorWindow.CenterOnScreen()
// Error message
errorLabel := widget.NewLabel(fmt.Sprintf("Failed to start LocalAI:\n%s", err.Error()))
errorLabel.Wrapping = fyne.TextWrapWord
// Logs display
logsLabel := widget.NewLabel("Process Logs:")
logsText := widget.NewMultiLineEntry()
logsText.SetText(logs)
logsText.Wrapping = fyne.TextWrapWord
logsText.Disable() // Make it read-only
// Buttons
closeButton := widget.NewButton("Close", func() {
errorWindow.Close()
})
retryButton := widget.NewButton("Retry", func() {
errorWindow.Close()
// Try to start again
go func() {
if retryErr := sm.launcher.StartLocalAI(); retryErr != nil {
sm.showStartupErrorDialog(retryErr)
}
}()
})
openLogsButton := widget.NewButton("Open Logs Folder", func() {
sm.openDataFolder()
})
// Layout
buttons := container.NewHBox(closeButton, retryButton, openLogsButton)
content := container.NewVBox(
errorLabel,
widget.NewSeparator(),
logsLabel,
logsText,
widget.NewSeparator(),
buttons,
)
errorWindow.SetContent(content)
errorWindow.Show()
})
}
// showDownloadProgress shows a progress window for downloading updates
func (sm *SystrayManager) showDownloadProgress(version string) {
// Create a new window for download progress
progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
sm.app.OpenURL(releaseNotesURL)
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := sm.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show restart dialog
dialog.ShowConfirm("Update Downloaded",
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
func(restart bool) {
if restart {
sm.app.Quit()
}
progressWindow.Close()
}, progressWindow)
}
})
// Update systray menu
if err == nil {
sm.hasUpdateAvailable = false
sm.latestVersion = ""
sm.recreateMenu()
}
}()
}

View File

@@ -1,795 +0,0 @@
package launcher
import (
"fmt"
"log"
"net/url"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
)
// EnvVar represents an environment variable
type EnvVar struct {
Key string
Value string
}
// LauncherUI handles the user interface
type LauncherUI struct {
// Status display
statusLabel *widget.Label
versionLabel *widget.Label
// Control buttons
startStopButton *widget.Button
webUIButton *widget.Button
updateButton *widget.Button
downloadButton *widget.Button
// Configuration
modelsPathEntry *widget.Entry
backendsPathEntry *widget.Entry
addressEntry *widget.Entry
logLevelSelect *widget.Select
startOnBootCheck *widget.Check
// Environment Variables
envVarsData []EnvVar
newEnvKeyEntry *widget.Entry
newEnvValueEntry *widget.Entry
updateEnvironmentDisplay func()
// Logs
logText *widget.Entry
// Progress
progressBar *widget.ProgressBar
// Update management
latestVersion string
// Reference to launcher
launcher *Launcher
}
// NewLauncherUI creates a new UI instance
func NewLauncherUI() *LauncherUI {
return &LauncherUI{
statusLabel: widget.NewLabel("Initializing..."),
versionLabel: widget.NewLabel("Version: Unknown"),
startStopButton: widget.NewButton("Start LocalAI", nil),
webUIButton: widget.NewButton("Open WebUI", nil),
updateButton: widget.NewButton("Check for Updates", nil),
modelsPathEntry: widget.NewEntry(),
backendsPathEntry: widget.NewEntry(),
addressEntry: widget.NewEntry(),
logLevelSelect: widget.NewSelect([]string{"error", "warn", "info", "debug", "trace"}, nil),
startOnBootCheck: widget.NewCheck("Start LocalAI on system boot", nil),
logText: widget.NewMultiLineEntry(),
progressBar: widget.NewProgressBar(),
envVarsData: []EnvVar{}, // Initialize the environment variables slice
}
}
// CreateMainUI creates the main UI layout
func (ui *LauncherUI) CreateMainUI(launcher *Launcher) *fyne.Container {
ui.launcher = launcher
ui.setupBindings()
// Main tab with status and controls
// Configuration is now the main content
configTab := ui.createConfigTab()
// Create a simple container instead of tabs since we only have settings
tabs := container.NewVBox(
widget.NewCard("LocalAI Launcher Settings", "", configTab),
)
return tabs
}
// createConfigTab creates the configuration tab
func (ui *LauncherUI) createConfigTab() *fyne.Container {
// Path configuration
pathsCard := widget.NewCard("Paths", "", container.NewGridWithColumns(2,
widget.NewLabel("Models Path:"),
ui.modelsPathEntry,
widget.NewLabel("Backends Path:"),
ui.backendsPathEntry,
))
// Server configuration
serverCard := widget.NewCard("Server", "", container.NewVBox(
container.NewGridWithColumns(2,
widget.NewLabel("Address:"),
ui.addressEntry,
widget.NewLabel("Log Level:"),
ui.logLevelSelect,
),
ui.startOnBootCheck,
))
// Save button
saveButton := widget.NewButton("Save Configuration", func() {
ui.saveConfiguration()
})
// Environment Variables section
envCard := ui.createEnvironmentSection()
return container.NewVBox(
pathsCard,
serverCard,
envCard,
saveButton,
)
}
// createEnvironmentSection creates the environment variables section for the config tab
func (ui *LauncherUI) createEnvironmentSection() *fyne.Container {
// Initialize environment variables widgets
ui.newEnvKeyEntry = widget.NewEntry()
ui.newEnvKeyEntry.SetPlaceHolder("Environment Variable Name")
ui.newEnvValueEntry = widget.NewEntry()
ui.newEnvValueEntry.SetPlaceHolder("Environment Variable Value")
// Add button
addButton := widget.NewButton("Add Environment Variable", func() {
ui.addEnvironmentVariable()
})
// Environment variables list with delete buttons
ui.envVarsData = []EnvVar{}
// Create container for environment variables
envVarsContainer := container.NewVBox()
// Update function to rebuild the environment variables display
ui.updateEnvironmentDisplay = func() {
envVarsContainer.Objects = nil
for i, envVar := range ui.envVarsData {
index := i // Capture index for closure
// Create row with label and delete button
envLabel := widget.NewLabel(fmt.Sprintf("%s = %s", envVar.Key, envVar.Value))
deleteBtn := widget.NewButton("Delete", func() {
ui.confirmDeleteEnvironmentVariable(index)
})
deleteBtn.Importance = widget.DangerImportance
row := container.NewBorder(nil, nil, nil, deleteBtn, envLabel)
envVarsContainer.Add(row)
}
envVarsContainer.Refresh()
}
// Create a scrollable container for the environment variables
envScroll := container.NewScroll(envVarsContainer)
envScroll.SetMinSize(fyne.NewSize(400, 150))
// Input section for adding new environment variables
inputSection := container.NewVBox(
container.NewGridWithColumns(2,
ui.newEnvKeyEntry,
ui.newEnvValueEntry,
),
addButton,
)
// Environment variables card
envCard := widget.NewCard("Environment Variables", "", container.NewVBox(
inputSection,
widget.NewSeparator(),
envScroll,
))
return container.NewVBox(envCard)
}
// addEnvironmentVariable adds a new environment variable
func (ui *LauncherUI) addEnvironmentVariable() {
key := ui.newEnvKeyEntry.Text
value := ui.newEnvValueEntry.Text
log.Printf("addEnvironmentVariable: attempting to add %s=%s", key, value)
log.Printf("addEnvironmentVariable: current ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
if key == "" {
log.Printf("addEnvironmentVariable: key is empty, showing error")
dialog.ShowError(fmt.Errorf("environment variable name cannot be empty"), ui.launcher.window)
return
}
// Check if key already exists
for _, envVar := range ui.envVarsData {
if envVar.Key == key {
log.Printf("addEnvironmentVariable: key %s already exists, showing error", key)
dialog.ShowError(fmt.Errorf("environment variable '%s' already exists", key), ui.launcher.window)
return
}
}
log.Printf("addEnvironmentVariable: adding new env var %s=%s", key, value)
ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
log.Printf("addEnvironmentVariable: after adding, ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
fyne.Do(func() {
if ui.updateEnvironmentDisplay != nil {
ui.updateEnvironmentDisplay()
}
// Clear input fields
ui.newEnvKeyEntry.SetText("")
ui.newEnvValueEntry.SetText("")
})
log.Printf("addEnvironmentVariable: calling saveEnvironmentVariables")
// Save to configuration
ui.saveEnvironmentVariables()
}
// removeEnvironmentVariable removes an environment variable by index
func (ui *LauncherUI) removeEnvironmentVariable(index int) {
if index >= 0 && index < len(ui.envVarsData) {
ui.envVarsData = append(ui.envVarsData[:index], ui.envVarsData[index+1:]...)
fyne.Do(func() {
if ui.updateEnvironmentDisplay != nil {
ui.updateEnvironmentDisplay()
}
})
ui.saveEnvironmentVariables()
}
}
// saveEnvironmentVariables saves environment variables to the configuration
func (ui *LauncherUI) saveEnvironmentVariables() {
if ui.launcher == nil {
log.Printf("saveEnvironmentVariables: launcher is nil")
return
}
config := ui.launcher.GetConfig()
log.Printf("saveEnvironmentVariables: before - Environment vars: %v", config.EnvironmentVars)
config.EnvironmentVars = make(map[string]string)
for _, envVar := range ui.envVarsData {
config.EnvironmentVars[envVar.Key] = envVar.Value
log.Printf("saveEnvironmentVariables: adding %s=%s", envVar.Key, envVar.Value)
}
log.Printf("saveEnvironmentVariables: after - Environment vars: %v", config.EnvironmentVars)
log.Printf("saveEnvironmentVariables: calling SetConfig with %d environment variables", len(config.EnvironmentVars))
err := ui.launcher.SetConfig(config)
if err != nil {
log.Printf("saveEnvironmentVariables: failed to save config: %v", err)
} else {
log.Printf("saveEnvironmentVariables: config saved successfully")
}
}
// confirmDeleteEnvironmentVariable shows confirmation dialog for deleting an environment variable
func (ui *LauncherUI) confirmDeleteEnvironmentVariable(index int) {
if index >= 0 && index < len(ui.envVarsData) {
envVar := ui.envVarsData[index]
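// The captured index refers to the slice as it is now; removeEnvironmentVariable re-validates the bounds before deleting.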
dialog.ShowConfirm("Remove Environment Variable",
fmt.Sprintf("Remove environment variable '%s'?", envVar.Key),
func(remove bool) {
if remove {
ui.removeEnvironmentVariable(index)
}
}, ui.launcher.window)
}
}
// setupBindings sets up event handlers for UI elements
func (ui *LauncherUI) setupBindings() {
// Start/Stop button
ui.startStopButton.OnTapped = func() {
if ui.launcher.IsRunning() {
ui.stopLocalAI()
} else {
ui.startLocalAI()
}
}
// WebUI button
ui.webUIButton.OnTapped = func() {
ui.openWebUI()
}
ui.webUIButton.Disable() // Disabled until LocalAI is running
// Update button
ui.updateButton.OnTapped = func() {
ui.checkForUpdates()
}
// Log level selection
ui.logLevelSelect.OnChanged = func(selected string) {
if ui.launcher != nil {
config := ui.launcher.GetConfig()
config.LogLevel = selected
ui.launcher.SetConfig(config)
}
}
}
// startLocalAI starts the LocalAI service
func (ui *LauncherUI) startLocalAI() {
fyne.Do(func() {
ui.startStopButton.Disable()
})
ui.UpdateStatus("Starting LocalAI...")
go func() {
err := ui.launcher.StartLocalAI()
if err != nil {
ui.UpdateStatus("Failed to start: " + err.Error())
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.Do(func() {
ui.startStopButton.SetText("Stop LocalAI")
ui.webUIButton.Enable()
})
}
fyne.Do(func() {
ui.startStopButton.Enable()
})
}()
}
// stopLocalAI stops the LocalAI service
func (ui *LauncherUI) stopLocalAI() {
fyne.Do(func() {
ui.startStopButton.Disable()
})
ui.UpdateStatus("Stopping LocalAI...")
go func() {
err := ui.launcher.StopLocalAI()
if err != nil {
ui.UpdateStatus("Failed to stop: " + err.Error())
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.Do(func() {
ui.startStopButton.SetText("Start LocalAI")
ui.webUIButton.Disable()
})
}
fyne.Do(func() {
ui.startStopButton.Enable()
})
}()
}
// openWebUI opens the LocalAI WebUI in the default browser
func (ui *LauncherUI) openWebUI() {
webURL := ui.launcher.GetWebUIURL()
parsedURL, err := url.Parse(webURL)
if err != nil {
dialog.ShowError(err, ui.launcher.window)
return
}
// Open URL in default browser
if err := fyne.CurrentApp().OpenURL(parsedURL); err != nil {
dialog.ShowError(err, ui.launcher.window)
}
}
// saveConfiguration saves the current configuration
func (ui *LauncherUI) saveConfiguration() {
log.Printf("saveConfiguration: starting to save configuration")
config := ui.launcher.GetConfig()
log.Printf("saveConfiguration: current config Environment vars: %v", config.EnvironmentVars)
log.Printf("saveConfiguration: ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
config.ModelsPath = ui.modelsPathEntry.Text
config.BackendsPath = ui.backendsPathEntry.Text
config.Address = ui.addressEntry.Text
config.LogLevel = ui.logLevelSelect.Selected
config.StartOnBoot = ui.startOnBootCheck.Checked
// Ensure environment variables are included in the configuration
config.EnvironmentVars = make(map[string]string)
for _, envVar := range ui.envVarsData {
config.EnvironmentVars[envVar.Key] = envVar.Value
log.Printf("saveConfiguration: adding env var %s=%s", envVar.Key, envVar.Value)
}
log.Printf("saveConfiguration: final config Environment vars: %v", config.EnvironmentVars)
err := ui.launcher.SetConfig(config)
if err != nil {
log.Printf("saveConfiguration: failed to save config: %v", err)
dialog.ShowError(err, ui.launcher.window)
} else {
log.Printf("saveConfiguration: config saved successfully")
dialog.ShowInformation("Configuration", "Configuration saved successfully", ui.launcher.window)
}
}
// checkForUpdates checks for available updates
func (ui *LauncherUI) checkForUpdates() {
fyne.Do(func() {
ui.updateButton.Disable()
})
ui.UpdateStatus("Checking for updates...")
go func() {
available, version, err := ui.launcher.CheckForUpdates()
if err != nil {
ui.UpdateStatus("Failed to check updates: " + err.Error())
fyne.DoAndWait(func() {
dialog.ShowError(err, ui.launcher.window)
})
} else if available {
ui.latestVersion = version // Store the latest version
ui.UpdateStatus("Update available: " + version)
fyne.Do(func() {
if ui.downloadButton != nil {
ui.downloadButton.Enable()
}
})
ui.NotifyUpdateAvailable(version)
} else {
ui.UpdateStatus("No updates available")
fyne.DoAndWait(func() {
dialog.ShowInformation("Updates", "You are running the latest version", ui.launcher.window)
})
}
fyne.Do(func() {
ui.updateButton.Enable()
})
}()
}
// downloadUpdate downloads the latest update
func (ui *LauncherUI) downloadUpdate() {
// Use stored version or check for updates
version := ui.latestVersion
if version == "" {
_, v, err := ui.launcher.CheckForUpdates()
if err != nil {
dialog.ShowError(err, ui.launcher.window)
return
}
version = v
ui.latestVersion = version
}
if version == "" {
dialog.ShowError(fmt.Errorf("no version information available"), ui.launcher.window)
return
}
// Disable buttons during download
if ui.downloadButton != nil {
fyne.Do(func() {
ui.downloadButton.Disable()
})
}
fyne.Do(func() {
ui.progressBar.Show()
ui.progressBar.SetValue(0)
})
ui.UpdateStatus("Downloading update " + version + "...")
go func() {
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
ui.progressBar.SetValue(progress)
})
// Update status with percentage
percentage := int(progress * 100)
ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
})
fyne.Do(func() {
ui.progressBar.Hide()
})
// Re-enable buttons after download
if ui.downloadButton != nil {
fyne.Do(func() {
ui.downloadButton.Enable()
})
}
if err != nil {
fyne.DoAndWait(func() {
ui.UpdateStatus("Failed to download update: " + err.Error())
dialog.ShowError(err, ui.launcher.window)
})
} else {
fyne.DoAndWait(func() {
ui.UpdateStatus("Update downloaded successfully")
dialog.ShowInformation("Update", "Update downloaded successfully. Please restart the launcher to use the new version.", ui.launcher.window)
})
}
}()
}
// UpdateStatus updates the status label
func (ui *LauncherUI) UpdateStatus(status string) {
if ui.statusLabel != nil {
fyne.Do(func() {
ui.statusLabel.SetText(status)
})
}
}
// OnLogUpdate handles new log content
func (ui *LauncherUI) OnLogUpdate(logLine string) {
if ui.logText != nil {
fyne.Do(func() {
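// Append the incoming line to the existing log buffer; note that this buffer grows without bound over long sessions.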
currentText := ui.logText.Text
ui.logText.SetText(currentText + logLine)
// Auto-scroll to bottom by moving the cursor to the last line
lastRow := 0
for _, ch := range ui.logText.Text {
if ch == '\n' {
lastRow++
}
}
ui.logText.CursorRow = lastRow
})
}
}
// NotifyUpdateAvailable shows an update notification
func (ui *LauncherUI) NotifyUpdateAvailable(version string) {
if ui.launcher != nil && ui.launcher.window != nil {
fyne.DoAndWait(func() {
dialog.ShowConfirm("Update Available",
"A new version ("+version+") is available. Would you like to download it?",
func(confirmed bool) {
if confirmed {
ui.downloadUpdate()
}
}, ui.launcher.window)
})
}
}
// LoadConfiguration loads the current configuration into UI elements
func (ui *LauncherUI) LoadConfiguration() {
if ui.launcher == nil {
log.Printf("UI LoadConfiguration: launcher is nil")
return
}
config := ui.launcher.GetConfig()
log.Printf("UI LoadConfiguration: loading config - ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
config.ModelsPath, config.BackendsPath, config.Address, config.LogLevel)
log.Printf("UI LoadConfiguration: Environment vars: %v", config.EnvironmentVars)
ui.modelsPathEntry.SetText(config.ModelsPath)
ui.backendsPathEntry.SetText(config.BackendsPath)
ui.addressEntry.SetText(config.Address)
ui.logLevelSelect.SetSelected(config.LogLevel)
ui.startOnBootCheck.SetChecked(config.StartOnBoot)
// Load environment variables
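// Note: Go map iteration order is not deterministic, so the displayed order of variables may vary between loads.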
ui.envVarsData = []EnvVar{}
for key, value := range config.EnvironmentVars {
ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
}
if ui.updateEnvironmentDisplay != nil {
fyne.Do(func() {
ui.updateEnvironmentDisplay()
})
}
// Update version display
version := ui.launcher.GetCurrentVersion()
ui.versionLabel.SetText("Version: " + version)
log.Printf("UI LoadConfiguration: configuration loaded successfully")
}
// showDownloadProgress shows a progress window for downloading LocalAI
func (ui *LauncherUI) showDownloadProgress(version, title string) {
fyne.DoAndWait(func() {
// Create progress window using the launcher's app
progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
progressWindow.Resize(fyne.NewSize(400, 250))
progressWindow.CenterOnScreen()
// Progress bar
progressBar := widget.NewProgressBar()
progressBar.SetValue(0)
// Status label
statusLabel := widget.NewLabel("Preparing download...")
// Release notes button
releaseNotesButton := widget.NewButton("View Release Notes", func() {
releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
if err != nil {
log.Printf("Failed to parse URL: %v", err)
return
}
if err := ui.launcher.app.OpenURL(releaseNotesURL); err != nil {
log.Printf("Failed to open release notes: %v", err)
}
})
// Progress container
progressContainer := container.NewVBox(
widget.NewLabel(title),
progressBar,
statusLabel,
widget.NewSeparator(),
releaseNotesButton,
)
progressWindow.SetContent(progressContainer)
progressWindow.Show()
// Start download in background
go func() {
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
// Update progress bar
fyne.Do(func() {
progressBar.SetValue(progress)
percentage := int(progress * 100)
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
})
})
// Handle completion
fyne.Do(func() {
if err != nil {
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
// Show error dialog
dialog.ShowError(err, progressWindow)
} else {
statusLabel.SetText("Download completed successfully!")
progressBar.SetValue(1.0)
// Show success dialog
dialog.ShowConfirm("Installation Complete",
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
func(close bool) {
progressWindow.Close()
// Update status
ui.UpdateStatus("LocalAI installed successfully")
}, progressWindow)
}
})
}()
})
}
// UpdateRunningState updates UI based on LocalAI running state
func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
fyne.Do(func() {
if isRunning {
ui.startStopButton.SetText("Stop LocalAI")
ui.webUIButton.Enable()
} else {
ui.startStopButton.SetText("Start LocalAI")
ui.webUIButton.Disable()
}
})
}
// ShowWelcomeWindow displays the welcome window with helpful information
func (ui *LauncherUI) ShowWelcomeWindow() {
if ui.launcher == nil || ui.launcher.window == nil {
log.Printf("Cannot show welcome window: launcher or window is nil")
return
}
fyne.DoAndWait(func() {
// Create welcome window
welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher")
welcomeWindow.Resize(fyne.NewSize(600, 500))
welcomeWindow.CenterOnScreen()
welcomeWindow.SetCloseIntercept(func() {
welcomeWindow.Close()
})
// Title
titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!")
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
titleLabel.Alignment = fyne.TextAlignCenter
// Welcome message
welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system.
What you can do:
• Start and stop LocalAI server
• Configure models and backends paths
• Set environment variables
• Check for updates automatically
• Access LocalAI WebUI when running
Getting Started:
1. Configure your models and backends paths
2. Click "Start LocalAI" to begin
3. Use "Open WebUI" to access the interface
4. Check the system tray for quick access`
welcomeLabel := widget.NewLabel(welcomeText)
welcomeLabel.Wrapping = fyne.TextWrapWord
// Useful links section
linksTitle := widget.NewLabel("Useful Links:")
linksTitle.TextStyle = fyne.TextStyle{Bold: true}
// Create link buttons
docsButton := widget.NewButton("📚 Documentation", func() {
ui.openURL("https://localai.io/docs/")
})
githubButton := widget.NewButton("🐙 GitHub Repository", func() {
ui.openURL("https://github.com/mudler/LocalAI")
})
modelsButton := widget.NewButton("🤖 Model Gallery", func() {
ui.openURL("https://localai.io/models/")
})
communityButton := widget.NewButton("💬 Community", func() {
ui.openURL("https://discord.gg/XgwjKptP7Z")
})
// Checkbox to disable welcome window
dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) {
if ui.launcher != nil {
config := ui.launcher.GetConfig()
v := !checked
config.ShowWelcome = &v
ui.launcher.SetConfig(config)
}
})
config := ui.launcher.GetConfig()
if config.ShowWelcome != nil {
// The checkbox is the inverse of ShowWelcome: it is checked when the welcome window should not be shown again
dontShowAgainCheck.SetChecked(!*config.ShowWelcome)
}
// Close button
closeButton := widget.NewButton("Get Started", func() {
welcomeWindow.Close()
})
closeButton.Importance = widget.HighImportance
// Layout
linksContainer := container.NewVBox(
linksTitle,
docsButton,
githubButton,
modelsButton,
communityButton,
)
content := container.NewVBox(
titleLabel,
widget.NewSeparator(),
welcomeLabel,
widget.NewSeparator(),
linksContainer,
widget.NewSeparator(),
dontShowAgainCheck,
widget.NewSeparator(),
closeButton,
)
welcomeWindow.SetContent(content)
welcomeWindow.Show()
})
}
// openURL opens a URL in the default browser
func (ui *LauncherUI) openURL(urlString string) {
parsedURL, err := url.Parse(urlString)
if err != nil {
log.Printf("Failed to parse URL %s: %v", urlString, err)
return
}
if err := fyne.CurrentApp().OpenURL(parsedURL); err != nil {
log.Printf("Failed to open URL %s: %v", urlString, err)
}
}

View File

Binary file not shown.


View File

@@ -1,92 +0,0 @@
package main
import (
"log"
"os"
"os/signal"
"syscall"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/driver/desktop"
coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
)
func main() {
// Create the application with unique ID
myApp := app.NewWithID("com.localai.launcher")
myApp.SetIcon(resourceIconPng)
myWindow := myApp.NewWindow("LocalAI Launcher")
myWindow.Resize(fyne.NewSize(800, 600))
// Create the launcher UI
ui := coreLauncher.NewLauncherUI()
// Initialize the launcher with UI context
launcher := coreLauncher.NewLauncher(ui, myWindow, myApp)
// Setup the UI
content := ui.CreateMainUI(launcher)
myWindow.SetContent(content)
// Setup window close behavior - minimize to tray instead of closing
myWindow.SetCloseIntercept(func() {
myWindow.Hide()
})
// Setup system tray using Fyne's built-in approach
if desk, ok := myApp.(desktop.App); ok {
// Create a dynamic systray manager
systray := coreLauncher.NewSystrayManager(launcher, myWindow, desk, myApp, resourceIconPng)
launcher.SetSystray(systray)
}
// Setup signal handling for graceful shutdown
setupSignalHandling(launcher)
// Initialize the launcher state
go func() {
if err := launcher.Initialize(); err != nil {
log.Printf("Failed to initialize launcher: %v", err)
if launcher.GetUI() != nil {
launcher.GetUI().UpdateStatus("Failed to initialize: " + err.Error())
}
} else {
// Load configuration into UI
launcher.GetUI().LoadConfiguration()
launcher.GetUI().UpdateStatus("Ready")
// Show welcome window if configured to do so
config := launcher.GetConfig()
if config.ShowWelcome == nil || *config.ShowWelcome {
launcher.GetUI().ShowWelcomeWindow()
}
}
}()
// Run the application in background (window only shown when "Settings" is clicked)
myApp.Run()
}
// setupSignalHandling sets up signal handlers for graceful shutdown
func setupSignalHandling(launcher *coreLauncher.Launcher) {
// Create a channel to receive OS signals
sigChan := make(chan os.Signal, 1)
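// Buffered with capacity 1 so the notifier never blocks if a signal arrives before we start receiving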
// Register for interrupt and terminate signals
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Handle signals in a separate goroutine
go func() {
sig := <-sigChan
log.Printf("Received signal %v, shutting down gracefully...", sig)
// Perform cleanup
if err := launcher.Shutdown(); err != nil {
log.Printf("Error during shutdown: %v", err)
}
// Exit the application
os.Exit(0)
}()
}

View File

@@ -56,12 +56,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
@@ -87,13 +87,13 @@ func New(opts ...config.AppOption) (*Application, error) {
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
if err := services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
return nil, err
}
}

View File

@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err

View File

@@ -78,12 +78,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
b = c.Batch
}
flashAttention := "auto"
if c.FlashAttention != nil {
flashAttention = *c.FlashAttention
}
f16 := false
if c.F16 != nil {
f16 = *c.F16
@@ -172,7 +166,7 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
MMProj: c.MMProj,
FlashAttention: flashAttention,
FlashAttention: c.FlashAttention,
CacheTypeKey: c.CacheTypeK,
CacheTypeValue: c.CacheTypeV,
NoKVOffload: c.NoKVOffloading,

View File

@@ -12,7 +12,7 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
@@ -34,7 +34,6 @@ func ModelTranscription(audio, language string, translate bool, diarize bool, ml
Dst: audio,
Language: language,
Translate: translate,
Diarize: diarize,
Threads: uint32(*modelConfig.Threads),
})
if err != nil {

View File

@@ -7,7 +7,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(
@@ -22,18 +22,12 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
_, err := inferenceModel.GenerateVideo(
appConfig.Context,
&proto.GenerateVideoRequest{
Height: height,
Width: width,
Prompt: prompt,
NegativePrompt: negativePrompt,
StartImage: startImage,
EndImage: endImage,
NumFrames: numFrames,
Fps: fps,
Seed: seed,
CfgScale: cfgScale,
Step: step,
Dst: dst,
Height: height,
Width: width,
Prompt: prompt,
StartImage: startImage,
EndImage: endImage,
Dst: dst,
})
return err
}

View File

@@ -6,7 +6,6 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/core/gallery"
@@ -101,8 +100,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -5,7 +5,6 @@ import (
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http"
)
@@ -46,7 +45,5 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)
signals.Handler(nil)
return appHTTP.Listen(e.Address)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/p2p"
)
@@ -20,7 +19,5 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
signals.Handler(nil)
return fs.Start(context.Background())
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
"github.com/schollz/progressbar/v3"
@@ -126,8 +125,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -10,11 +10,9 @@ import (
"github.com/mudler/LocalAI/core/application"
cli_api "github.com/mudler/LocalAI/core/cli/api"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -75,16 +73,9 @@ type RunCMD struct {
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
Version bool
}
func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.Version {
fmt.Println(internal.Version)
return nil
}
os.MkdirAll(r.BackendsPath, 0750)
os.MkdirAll(r.ModelsPath, 0750)
@@ -225,8 +216,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
return err
}
// Catch signals from the OS requesting us to exit, and stop all backends
signals.Handler(app.ModelLoader())
return appHTTP.Listen(r.Address)
}

View File

@@ -1,25 +0,0 @@
package signals
import (
"os"
"os/signal"
"syscall"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
func Handler(m *model.ModelLoader) {
// Catch signals from the OS requesting us to exit, and stop all backends
go func(m *model.ModelLoader) {
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
<-c
if m != nil {
if err := m.StopAllGRPC(); err != nil {
log.Error().Err(err).Msg("error while stopping all grpc backends")
}
}
os.Exit(1)
}(m)
}

View File

@@ -20,7 +20,6 @@ type TranscriptCMD struct {
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"c" help:"Translate the transcription to english"`
Diarize bool `short:"d" help:"Mark speaker turns"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
}
@@ -57,7 +56,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
}
}()
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
if err != nil {
return err
}

View File

@@ -2,7 +2,6 @@ package worker
type WorkerFlags struct {
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
}

View File

@@ -1,7 +1,6 @@
package worker
import (
"encoding/json"
"errors"
"fmt"
"os"
@@ -10,10 +9,7 @@ import (
"syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -24,10 +20,9 @@ type LLamaCPP struct {
const (
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
llamaCPPGalleryName = "llama-cpp"
)
func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
log.Warn().Msgf("Failed listing system backends: %s", err)
@@ -35,19 +30,9 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
}
log.Debug().Msgf("System backends: %v", backends)
backend, ok := backends.Get(llamaCPPGalleryName)
backend, ok := backends.Get("llama-cpp")
if !ok {
ml := model.NewModelLoader(systemState, true)
var gals []config.Gallery
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err
}
return "", errors.New("llama-cpp backend not found, install it first")
}
backendPath := filepath.Dir(backend.RunFile)
@@ -76,7 +61,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
if err != nil {
return err
}
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
return err
}
@@ -84,9 +69,6 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
args := strings.Split(r.ExtraLLamaCPPArgs, " ")
args = append([]string{grpcProcess}, args...)
signals.Handler(nil)
return syscall.Exec(
grpcProcess,
args,

View File

@@ -9,7 +9,6 @@ import (
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/signals"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/pkg/system"
"github.com/phayes/freeport"
@@ -70,7 +69,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
for {
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
return
@@ -107,8 +106,6 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
}
}
signals.Handler(nil)
for {
time.Sleep(1 * time.Second)
}

View File

@@ -164,10 +164,10 @@ type LLMConfig struct {
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj" json:"mmproj"`
FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
FlashAttention bool `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
ModelType string `yaml:"type" json:"type"`

View File

@@ -1,5 +1,3 @@
// Package gallery provides installation and registration utilities for LocalAI backends,
// including meta-backend resolution based on system capabilities.
package gallery
import (
@@ -7,7 +5,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/mudler/LocalAI/core/config"
@@ -23,12 +20,6 @@ const (
runFile = "run.sh"
)
// backendCandidate represents an installed concrete backend option for a given alias
type backendCandidate struct {
name string
runFile string
}
// readBackendMetadata reads the metadata JSON file for a backend
func readBackendMetadata(backendPath string) (*BackendMetadata, error) {
metadataPath := filepath.Join(backendPath, metadataFile)
@@ -67,8 +58,8 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
return nil
}
// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
// Installs a model from the gallery
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -108,7 +99,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(systemState, bestBackend, downloadStatus); err != nil {
return err
}
@@ -133,10 +124,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
return InstallBackend(systemState, backend, downloadStatus)
}
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(systemState *system.SystemState, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -194,7 +185,7 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
return fmt.Errorf("failed to write metadata for backend %q: %v", name, err)
}
return RegisterBackends(systemState, modelLoader)
return nil
}
func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
@@ -291,18 +282,23 @@ func (b SystemBackends) GetAll() []SystemBackend {
}
func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
// Gather backends from system and user paths, then resolve alias conflicts by capability.
potentialBackends, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return nil, err
}
backends := make(SystemBackends)
// System-provided backends
if systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath); err == nil {
systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath)
if err == nil {
// system backends are special, they are provided by the system and not managed by LocalAI
for _, systemBackend := range systemBackends {
if systemBackend.IsDir() {
run := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
if _, err := os.Stat(run); err == nil {
systemBackendRunFile := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
if _, err := os.Stat(systemBackendRunFile); err == nil {
backends[systemBackend.Name()] = SystemBackend{
Name: systemBackend.Name(),
RunFile: run,
RunFile: filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile),
IsMeta: false,
IsSystem: true,
Metadata: nil,
@@ -311,103 +307,63 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
}
}
} else {
log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
log.Warn().Err(err).Msg("Failed to read system backends, but that's ok, we will just use the backends managed by LocalAI")
}
// User-managed backends and alias collection
entries, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return nil, err
}
for _, potentialBackend := range potentialBackends {
if potentialBackend.IsDir() {
potentialBackendRunFile := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), runFile)
aliasGroups := make(map[string][]backendCandidate)
metaMap := make(map[string]*BackendMetadata)
var metadata *BackendMetadata
for _, e := range entries {
if !e.IsDir() {
continue
}
dir := e.Name()
run := filepath.Join(systemState.Backend.BackendsPath, dir, runFile)
var metadata *BackendMetadata
metadataPath := filepath.Join(systemState.Backend.BackendsPath, dir, metadataFile)
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
metadata = &BackendMetadata{Name: dir}
} else {
m, rerr := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, dir))
if rerr != nil {
return nil, rerr
}
if m == nil {
metadata = &BackendMetadata{Name: dir}
// If metadata file does not exist, we just use the directory name
// and we do not fill the other metadata (such as potential backend Aliases)
metadataFilePath := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), metadataFile)
if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
metadata = &BackendMetadata{
Name: potentialBackend.Name(),
}
} else {
metadata = m
}
}
metaMap[dir] = metadata
// Concrete backend entry
if _, err := os.Stat(run); err == nil {
backends[dir] = SystemBackend{
Name: dir,
RunFile: run,
IsMeta: false,
Metadata: metadata,
}
}
// Alias candidates
if metadata.Alias != "" {
aliasGroups[metadata.Alias] = append(aliasGroups[metadata.Alias], backendCandidate{name: dir, runFile: run})
}
// Meta backends indirection
if metadata.MetaBackendFor != "" {
backends[metadata.Name] = SystemBackend{
Name: metadata.Name,
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
IsMeta: true,
Metadata: metadata,
}
}
}
// Resolve aliases using system capability preferences
tokens := systemState.BackendPreferenceTokens()
for alias, cands := range aliasGroups {
chosen := backendCandidate{}
// Try preference tokens
for _, t := range tokens {
for _, c := range cands {
if strings.Contains(strings.ToLower(c.name), t) && c.runFile != "" {
chosen = c
break
// Check for alias in metadata
metadata, err = readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name()))
if err != nil {
return nil, err
}
}
if chosen.runFile != "" {
break
}
}
// Fallback: first runnable
if chosen.runFile == "" {
for _, c := range cands {
if c.runFile != "" {
chosen = c
break
if !backends.Exists(potentialBackend.Name()) {
// We don't want to override aliases if already set, and if we are meta backend
if _, err := os.Stat(potentialBackendRunFile); err == nil {
backends[potentialBackend.Name()] = SystemBackend{
Name: potentialBackend.Name(),
RunFile: potentialBackendRunFile,
IsMeta: false,
Metadata: metadata,
}
}
}
if metadata == nil {
continue
}
if metadata.Alias != "" {
backends[metadata.Alias] = SystemBackend{
Name: metadata.Alias,
RunFile: potentialBackendRunFile,
IsMeta: false,
Metadata: metadata,
}
}
if metadata.MetaBackendFor != "" {
backends[metadata.Name] = SystemBackend{
Name: metadata.Name,
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
IsMeta: true,
Metadata: metadata,
}
}
}
if chosen.runFile == "" {
continue
}
md := metaMap[chosen.name]
backends[alias] = SystemBackend{
Name: alias,
RunFile: chosen.runFile,
IsMeta: false,
Metadata: md,
}
}

View File

@@ -7,7 +7,6 @@ import (
"runtime"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -18,79 +17,10 @@ const (
testImage = "quay.io/mudler/tests:localai-backend-test"
)
var _ = Describe("Runtime capability-based backend selection", func() {
var tempDir string
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "gallery-caps-*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
os.RemoveAll(tempDir)
})
It("ListSystemBackends prefers optimal alias candidate", func() {
// Arrange two installed backends sharing the same alias
must := func(err error) { Expect(err).NotTo(HaveOccurred()) }
cpuDir := filepath.Join(tempDir, "cpu-llama-cpp")
must(os.MkdirAll(cpuDir, 0o750))
cpuMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cpu-llama-cpp"}
b, _ := json.Marshal(cpuMeta)
must(os.WriteFile(filepath.Join(cpuDir, "metadata.json"), b, 0o644))
must(os.WriteFile(filepath.Join(cpuDir, "run.sh"), []byte(""), 0o755))
cudaDir := filepath.Join(tempDir, "cuda12-llama-cpp")
must(os.MkdirAll(cudaDir, 0o750))
cudaMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cuda12-llama-cpp"}
b, _ = json.Marshal(cudaMeta)
must(os.WriteFile(filepath.Join(cudaDir, "metadata.json"), b, 0o644))
must(os.WriteFile(filepath.Join(cudaDir, "run.sh"), []byte(""), 0o755))
// Default system: alias should point to CPU
sysDefault, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expect(aliasBack.RunFile).To(Equal(filepath.Join(cpuDir, "run.sh")))
// concrete entries remain
_, ok = backs.Get("cpu-llama-cpp")
Expect(ok).To(BeTrue())
_, ok = backs.Get("cuda12-llama-cpp")
Expect(ok).To(BeTrue())
// NVIDIA system: alias should point to CUDA
// Force capability to nvidia to make the test deterministic on platforms like darwin/arm64 (which default to metal)
os.Setenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY", "nvidia")
defer os.Unsetenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY")
sysNvidia, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expect(aliasBack.RunFile).To(Equal(filepath.Join(cudaDir, "run.sh")))
})
})
var _ = Describe("Gallery Backends", func() {
var (
tempDir string
galleries []config.Gallery
ml *model.ModelLoader
systemState *system.SystemState
tempDir string
galleries []config.Gallery
)
BeforeEach(func() {
@@ -105,9 +35,6 @@ var _ = Describe("Gallery Backends", func() {
URL: "https://gist.githubusercontent.com/mudler/71d5376bc2aa168873fa519fa9f4bd56/raw/0557f9c640c159fa8e4eab29e8d98df6a3d6e80f/backend-gallery.yaml",
},
}
systemState, err = system.GetSystemState(system.WithBackendPath(tempDir))
Expect(err).NotTo(HaveOccurred())
ml = model.NewModelLoader(systemState, true)
})
AfterEach(func() {
@@ -116,13 +43,21 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -298,7 +233,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +313,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +397,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +496,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -593,7 +528,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +561,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +582,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

View File

@@ -11,7 +11,6 @@ import (
"github.com/mudler/LocalAI/core/config"
lconfig "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/utils"
@@ -74,7 +73,6 @@ type PromptTemplate struct {
func InstallModelFromGallery(
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
applyModel := func(model *GalleryModel) error {
@@ -133,7 +131,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}

View File

@@ -88,7 +88,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))

View File

@@ -70,24 +70,6 @@ func infoButton(m *gallery.GalleryModel) elem.Node {
)
}
func getConfigButton(galleryName string) elem.Node {
return elem.Button(
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right ml-2 inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
"hx-swap": "outerHTML",
"hx-post": "browse/config/model/" + galleryName,
},
elem.I(
attrs.Props{
"class": "fa-solid fa-download pr-2",
},
),
elem.Text("Get Config"),
)
}
func deleteButton(galleryID string) elem.Node {
return elem.Button(
attrs.Props{

View File

@@ -339,12 +339,7 @@ func modelActionItems(m *gallery.GalleryModel, processTracker ProcessTracker, ga
reInstallButton(m.ID()),
deleteButton(m.ID()),
)),
// otherwise, show the install button, and the get config button
elem.Node(elem.Div(
attrs.Props{},
getConfigButton(m.ID()),
installButton(m.ID()),
)),
installButton(m.ID()),
),
),
),

View File

@@ -226,33 +226,3 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
return c.JSON(response)
}
}
// ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
}
// Preload the models
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to preload models: " + err.Error(),
}
return c.Status(500).JSON(response)
}
// Return success response
response := ModelResponse{
Success: true,
Message: "Model configurations reloaded successfully",
}
return c.Status(fiber.StatusOK).JSON(response)
}
}

View File

@@ -61,7 +61,7 @@ func downloadFile(url string) (string, error) {
*/
// VideoEndpoint
// @Summary Creates a video given a prompt.
// @Param request body schema.VideoRequest true "query params"
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
@@ -166,23 +166,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
baseURL := c.BaseURL()
fn, err := backend.VideoGeneration(
height,
width,
input.Prompt,
input.NegativePrompt,
src,
input.EndImage,
output,
input.NumFrames,
input.FPS,
input.Seed,
input.CFGScale,
input.Step,
ml,
*config,
appConfig,
)
fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig)
if err != nil {
return err
}

View File

@@ -31,7 +31,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
var id, textContentToReturn string
var created int
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
responses <- initialMessage
_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
@@ -65,19 +65,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return true
})
close(responses)
return err
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
result := ""
_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here.
return true
})
if err != nil {
return err
}
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
@@ -98,7 +95,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return err
return
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -172,7 +169,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
close(responses)
return err
}
return func(c *fiber.Ctx) error {
@@ -227,11 +223,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
if err != nil {
return err
}
switch d.Type {
case "json_object":
if d.Type == "json_object" {
input.Grammar = functions.JSONBNF
case "json_schema":
} else if d.Type == "json_schema" {
d := schema.JsonSchemaRequest{}
dat, err := json.Marshal(config.ResponseFormatMap)
if err != nil {
@@ -332,69 +326,31 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
c.Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse)
ended := make(chan error, 1)
go func() {
if !shouldUseFn {
ended <- process(predInput, input, config, ml, responses, extraUsage)
} else {
ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
}
}()
if !shouldUseFn {
go process(predInput, input, config, ml, responses, extraUsage)
} else {
go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
}
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
return
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
}
w.Flush()
}
finishReason := "stop"
@@ -422,9 +378,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
log.Debug().Msgf("Stream ended")
}))
return nil
// no streaming mode
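The hunks above move the streaming path to a producer/consumer shape: the generation goroutine pushes OpenAIResponse chunks into a channel and closes it when done, so the body writer can simply range over the channel instead of selecting on a separate "ended" error channel. A minimal, self-contained sketch of that pattern (plain net/http, not LocalAI's fiber/fasthttp StreamWriter):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

type chunk struct {
	Content string `json:"content"`
}

func produce(out chan<- chunk) {
	defer close(out) // closing the channel is what terminates the consumer loop below
	for _, tok := range []string{"Hello", ", ", "world"} {
		out <- chunk{Content: tok}
	}
}

func handler(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	flusher, _ := w.(http.Flusher)

	responses := make(chan chunk)
	go produce(responses)

	for ev := range responses { // mirrors `for ev := range responses` in the diff
		data, _ := json.Marshal(ev)
		fmt.Fprintf(w, "data: %s\n\n", data)
		if flusher != nil {
			flusher.Flush()
		}
	}
	fmt.Fprint(w, "data: [DONE]\n\n")
}

func main() {
	http.HandleFunc("/stream", handler)
	http.ListenAndServe(":8080", nil)
}

One trade-off visible in the removed lines: without the separate ended channel, a backend error no longer becomes a terminal "Internal error" chunk; the stream simply ends with data: [DONE].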

View File

@@ -30,7 +30,7 @@ import (
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
created := int(time.Now().Unix())
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -59,9 +59,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
responses <- resp
return true
}
_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
close(responses)
return err
}
return func(c *fiber.Ctx) error {
@@ -122,37 +121,18 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
responses := make(chan schema.OpenAIResponse)
ended := make(chan error)
go func() {
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}()
go process(id, predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
w.Flush()
break LOOP
}
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
@@ -173,7 +153,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return <-ended
return nil
}
var result []schema.Choice
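A small aside on the streaming loop above: json.Encoder.Encode appends a trailing newline to the encoded value, so writing the buffer with "data: %v\n" already yields the blank line that terminates a server-sent event. A runnable illustration:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"os"
)

func main() {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.Encode(map[string]string{"content": "hi"}) // writes the JSON object plus "\n"
	fmt.Fprintf(os.Stdout, "data: %v\n", buf.String())
	// Prints:
	// data: {"content":"hi"}
	// (followed by a blank line, i.e. a complete SSE event)
}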

View File

@@ -31,7 +31,6 @@ import (
const (
localSampleRate = 16000
remoteSampleRate = 24000
vadModel = "silero-vad-ggml"
)
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
@@ -234,7 +233,7 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
// TODO: The API has no way to configure the VAD model or other models that make up a pipeline to fake any-to-any
// So possibly we could have a way to configure a composite model that can be used in situations where any-to-any is expected
pipeline := config.Pipeline{
VAD: vadModel,
VAD: "silero-vad",
Transcription: session.InputAudioTranscription.Model,
}
@@ -569,7 +568,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model {
pipeline := config.Pipeline{
VAD: vadModel,
VAD: "silero-vad",
Transcription: trUpd.Model,
}
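Illustration only: the two hunks above inline the VAD backend name where the removed vadModel constant was referenced, which also changes the referenced name from "silero-vad-ggml" to "silero-vad". A tiny sketch with a stand-in pipeline type (the real one is LocalAI's config.Pipeline):

package main

import "fmt"

type pipeline struct {
	VAD           string
	Transcription string
}

func main() {
	// "whisper-1" is a placeholder transcription model, not taken from this diff.
	p := pipeline{VAD: "silero-vad", Transcription: "whisper-1"}
	fmt.Printf("%+v\n", p)
}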

View File

@@ -36,8 +36,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
return fiber.ErrBadRequest
}
diarize := c.FormValue("diarize", "false") != "false"
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
@@ -69,7 +67,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, ml, *config, appConfig)
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig)
if err != nil {
return err
}
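On the diarize flag added above: the expression c.FormValue("diarize", "false") != "false" treats any value other than the literal "false" as true, which is looser than strconv.ParseBool. A minimal sketch of the difference (stand-alone, not LocalAI code):

package main

import (
	"fmt"
	"strconv"
)

// formFlag mimics the semantics used in the endpoint: anything but "false" enables the flag.
func formFlag(v string) bool { return v != "false" }

func main() {
	for _, v := range []string{"false", "true", "1", "yes"} {
		strict, err := strconv.ParseBool(v)
		fmt.Printf("value=%q  loose=%v  strict=%v  err=%v\n", v, formFlag(v), strict, err)
	}
}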

View File

@@ -59,9 +59,6 @@ func RegisterLocalAIRoutes(router *fiber.App,
// Custom model edit endpoint
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
// Reload models endpoint
router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
}
router.Post("/v1/detection",

Some files were not shown because too many files have changed in this diff.