Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-03 03:02:38 -05:00)

Compare commits: feat/trans...v3.6.0 (168 commits)
Commits in this range (the Author and Date columns were not captured; only the SHA1 values survive):

8fb95686af, 4132085c01, c14f1ffcfd, 07cca4b69a, dd927c36f6, 052f42e926, 30d43588ab, d21ec22f74,
04fecd634a, 33c14198db, 967c2727e3, f41f30ad92, e77340e8a5, d51a3090f7, 1bf3bc932c, 564a47da4e,
c37ee93ff2, f4b65db4e7, f5fa8e6649, 570e39bdcf, 2ebe37b671, dca685f784, 84ebf2a2c9, ce5662ba90,
9878f27813, f2b9452ec4, 585da99c52, fd4f432079, 238c68c57b, 04fbf5cb82, c85d559919, b5efc4f89e,
3f9c09a4c5, 4a84660475, 737248256e, 0ae334fc62, 36c373b7c9, 6afcb932b7, 357bf571a3, e74ade9ebb,
f7f26b8efa, 75eb98f8bd, c337e7baf7, 660bd45be8, c27da0a0f6, ac043ed9ba, 2e0d66a1c8, 41a0f361eb,
d3c5c02837, ae3d8fb0c4, 902e47f0b0, 50bb78fd24, 542f07ab2d, 77c5acb9db, 44bbf4d778, 633c12f93d,
6f24135f1d, b72aa7b4fa, e94e725479, e4ac7b14a3, ddb39c73f2, 264b09fb1e, 36dd45df51, e5599f87b8,
e89b5cc0e3, 10bf1084cc, b08ae559b3, aa7cb7e18c, eadd3d4e46, 2a18206033, 39798d734e, d0e99562af,
6410c99bf2, 55766d269b, ffa0ad1eac, 623789a29e, 2b9a3d32c9, f8b71dc5d0, 1d3331b5cb, 2c0b9c6349,
3c6c976755, ebbcba342a, 0de75519dc, 37f5e4f5c1, ffa934b959, 59311d8b1e, d9e25af7b5, e4f8b63b40,
1364ae9be6, cfd6a9150d, cd352d0c5f, 8d47309695, 5f6fc02a55, 0b528458d8, caab380c5d, 8a3a362504,
07238eb743, e905e90dd7, 08432d49e5, e51e2aacb9, 9c3d85fc28, 007ca647a7, 59af928379, dbc2bb561b,
c72c85dcac, ef984901e6, 9911ec84a3, 1956681d4c, 326f6e5ccb, 302958efd6, 3dc86b247d, 5ec724af06,
1f1e156bf0, df625e366a, 9e6685ac9c, 90c818aa71, 034b9b691b, ba52822e5c, eb30f6c090, caba098959,
3c75ea1e0e, c5f911812f, d82922786a, d9e9bb4c0e, 657027bec6, 2f5635308d, 63b5338dbd, 3150174962,
4330fdce33, fef8583144, d4d6a56a4f, 2900a601a0, 43e0437db6, 976c159fdb, 969922ffec, 739573e41b,
dbdf2908ad, 317f8641dc, 54ff70e451, 723f01c87e, 79a41a5e07, d0b6aa3f7d, ad99399c6e, e6ebfd3ba1,
ead00a28b9, 9621edb4c5, 7ce92f0646, 6a4ab3c1e0, 83b85494c1, df6a80b38d, 21faa4114b, e35ad56602,
3be8b2d8e1, 900745bb4d, 15a7fc7e9a, 03dddec538, 3d34386712, 1b3f66018b, 4381e892b8, 3c3f477854,
f8a8cf3e95, 0fc88b3cdf, 4993df81c3, 599bc88c6c, 1a0d06f3db, 5e1a8b3621, 960e51e527, 195aa22e77
(ignore-file hunk; filename not captured in this view)

@@ -6,6 +6,10 @@ models
backends
examples/chatbot-ui/models
backend/go/image/stablediffusion-ggml/build/
backend/go/*/build
backend/go/*/.cache
backend/go/*/sources
backend/go/*/package
examples/rwkv/models
examples/**/models
Dockerfile*
.github/workflows/backend.yml (vendored): 132 lines changed

@@ -2,7 +2,6 @@
name: 'build backend container images'

on:
  pull_request:
  push:
    branches:
      - master

@@ -64,18 +63,6 @@ jobs:
  backend: "llama-cpp"
  dockerfile: "./backend/Dockerfile.llama-cpp"
  context: "./"
- build-type: ''
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-cpu-transformers'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  skip-drivers: 'true'
  backend: "transformers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- build-type: 'cublas'
  cuda-major-version: "11"
  cuda-minor-version: "7"

@@ -124,6 +111,18 @@ jobs:
  backend: "diffusers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- build-type: ''
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-cpu-chatterbox'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  skip-drivers: 'true'
  backend: "chatterbox"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
# CUDA 11 additional backends
- build-type: 'cublas'
  cuda-major-version: "11"

@@ -243,7 +242,7 @@ jobs:
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "diffusers"
  backend: "diffusers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
# CUDA 12 additional backends

@@ -490,6 +489,18 @@ jobs:
  backend: "diffusers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- build-type: 'l4t'
  cuda-major-version: "12"
  cuda-minor-version: "0"
  platforms: 'linux/arm64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-nvidia-l4t-kokoro'
  runs-on: 'ubuntu-24.04-arm'
  base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
  skip-drivers: 'true'
  backend: "kokoro"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
# SYCL additional backends
- build-type: 'intel'
  cuda-major-version: ""

@@ -776,7 +787,7 @@ jobs:
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-hipblas-whisper'
  tag-suffix: '-gpu-rocm-hipblas-whisper'
  base-image: "rocm/dev-ubuntu-22.04:6.4.3"
  runs-on: 'ubuntu-latest'
  skip-drivers: 'false'

@@ -871,7 +882,7 @@ jobs:
  backend: "rfdetr"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- build-type: 'cublas'
- build-type: 'l4t'
  cuda-major-version: "12"
  cuda-minor-version: "0"
  platforms: 'linux/arm64'

@@ -944,6 +955,18 @@ jobs:
  backend: "exllama2"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- build-type: 'l4t'
  cuda-major-version: "12"
  cuda-minor-version: "0"
  platforms: 'linux/arm64'
  skip-drivers: 'true'
  tag-latest: 'auto'
  tag-suffix: '-gpu-nvidia-l4t-arm64-chatterbox'
  base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
  runs-on: 'ubuntu-24.04-arm'
  backend: "chatterbox"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
# runs out of space on the runner
# - build-type: 'hipblas'
#   cuda-major-version: ""

@@ -970,54 +993,41 @@ jobs:
  backend: "kitten-tts"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
transformers-darwin:
backend-jobs-darwin:
  uses: ./.github/workflows/backend_build_darwin.yml
  strategy:
    matrix:
      include:
        - backend: "diffusers"
          tag-suffix: "-metal-darwin-arm64-diffusers"
          build-type: "mps"
        - backend: "mlx"
          tag-suffix: "-metal-darwin-arm64-mlx"
          build-type: "mps"
        - backend: "chatterbox"
          tag-suffix: "-metal-darwin-arm64-chatterbox"
          build-type: "mps"
        - backend: "mlx-vlm"
          tag-suffix: "-metal-darwin-arm64-mlx-vlm"
          build-type: "mps"
        - backend: "mlx-audio"
          tag-suffix: "-metal-darwin-arm64-mlx-audio"
          build-type: "mps"
        - backend: "stablediffusion-ggml"
          tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
          build-type: "metal"
          lang: "go"
        - backend: "whisper"
          tag-suffix: "-metal-darwin-arm64-whisper"
          build-type: "metal"
          lang: "go"
  with:
    backend: "transformers"
    build-type: "mps"
    backend: ${{ matrix.backend }}
    build-type: ${{ matrix.build-type }}
    go-version: "1.24.x"
    tag-suffix: "-metal-darwin-arm64-transformers"
    use-pip: true
    runs-on: "macOS-14"
  secrets:
    dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
    dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
    quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
    quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
diffusers-darwin:
  uses: ./.github/workflows/backend_build_darwin.yml
  with:
    backend: "diffusers"
    build-type: "mps"
    go-version: "1.24.x"
    tag-suffix: "-metal-darwin-arm64-diffusers"
    use-pip: true
    runs-on: "macOS-14"
  secrets:
    dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
    dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
    quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
    quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-darwin:
  uses: ./.github/workflows/backend_build_darwin.yml
  with:
    backend: "mlx"
    build-type: "mps"
    go-version: "1.24.x"
    tag-suffix: "-metal-darwin-arm64-mlx"
    runs-on: "macOS-14"
  secrets:
    dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
    dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
    quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
    quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
mlx-vlm-darwin:
  uses: ./.github/workflows/backend_build_darwin.yml
  with:
    backend: "mlx-vlm"
    build-type: "mps"
    go-version: "1.24.x"
    tag-suffix: "-metal-darwin-arm64-mlx-vlm"
    tag-suffix: ${{ matrix.tag-suffix }}
    lang: ${{ matrix.lang || 'python' }}
    use-pip: ${{ matrix.backend == 'diffusers' }}
    runs-on: "macOS-14"
  secrets:
    dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
.github/workflows/backend_build_darwin.yml (vendored): 30 lines changed

@@ -16,6 +16,10 @@ on:
      description: 'Use pip to install dependencies'
      default: false
      type: boolean
    lang:
      description: 'Programming language (e.g. go)'
      default: 'python'
      type: string
    go-version:
      description: 'Go version to use'
      default: '1.24.x'

@@ -49,26 +53,26 @@ jobs:
        uses: actions/checkout@v5
        with:
          submodules: true

      - name: Setup Go ${{ matrix.go-version }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go-version }}
          cache: false

      # You can test your matrix by printing the current Go version
      - name: Display Go version
        run: go version

      - name: Dependencies
        run: |
          brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm

      - name: Build ${{ inputs.backend }}-darwin
        run: |
          make protogen-go
          BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-python-backend
          BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend

      - name: Upload ${{ inputs.backend }}.tar
        uses: actions/upload-artifact@v4
        with:

@@ -85,20 +89,20 @@ jobs:
        with:
          name: ${{ inputs.backend }}-tar
          path: .

      - name: Install crane
        run: |
          curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
          sudo mv crane /usr/local/bin/

      - name: Log in to DockerHub
        run: |
          echo "${{ secrets.dockerPassword }}" | crane auth login docker.io -u "${{ secrets.dockerUsername }}" --password-stdin

      - name: Log in to quay.io
        run: |
          echo "${{ secrets.quayPassword }}" | crane auth login quay.io -u "${{ secrets.quayUsername }}" --password-stdin

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5

@@ -112,7 +116,7 @@ jobs:
          flavor: |
            latest=auto
            suffix=${{ inputs.tag-suffix }},onlatest=true

      - name: Docker meta
        id: quaymeta
        uses: docker/metadata-action@v5

@@ -126,13 +130,13 @@ jobs:
          flavor: |
            latest=auto
            suffix=${{ inputs.tag-suffix }},onlatest=true

      - name: Push Docker image (DockerHub)
        run: |
          for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do
            crane push ${{ inputs.backend }}.tar $tag
          done

      - name: Push Docker image (Quay)
        run: |
          for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do
.github/workflows/backend_pr.yml (vendored): 20 lines changed

@@ -12,7 +12,9 @@ jobs:
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
      matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }}
      has-backends: ${{ steps.set-matrix.outputs.has-backends }}
      has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v5

@@ -56,3 +58,21 @@ jobs:
    strategy:
      fail-fast: true
      matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
  backend-jobs-darwin:
    needs: generate-matrix
    uses: ./.github/workflows/backend_build_darwin.yml
    if: needs.generate-matrix.outputs.has-backends-darwin == 'true'
    with:
      backend: ${{ matrix.backend }}
      build-type: ${{ matrix.build-type }}
      go-version: "1.24.x"
      tag-suffix: ${{ matrix.tag-suffix }}
      lang: ${{ matrix.lang || 'python' }}
      use-pip: ${{ matrix.backend == 'diffusers' }}
      runs-on: "macOS-14"
    secrets:
      quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
      quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
    strategy:
      fail-fast: true
      matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix-darwin) }}
.github/workflows/build-test.yaml (vendored): 44 lines changed

@@ -21,3 +21,47 @@ jobs:
      - name: Run GoReleaser
        run: |
          make dev-dist
  launcher-build-darwin:
    runs-on: macos-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 1.23
      - name: Build launcher for macOS ARM64
        run: |
          make build-launcher-darwin
          ls -liah dist
      - name: Upload macOS launcher artifacts
        uses: actions/upload-artifact@v4
        with:
          name: launcher-macos
          path: dist/
          retention-days: 30

  launcher-build-linux:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 1.23
      - name: Build launcher for Linux
        run: |
          sudo apt-get update
          sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
          make build-launcher-linux
      - name: Upload Linux launcher artifacts
        uses: actions/upload-artifact@v4
        with:
          name: launcher-linux
          path: local-ai-launcher-linux.tar.xz
          retention-days: 30
.github/workflows/labeler.yml (vendored): 2 lines changed

@@ -9,4 +9,4 @@ jobs:
    pull-requests: write
  runs-on: ubuntu-latest
  steps:
    - uses: actions/labeler@v5
    - uses: actions/labeler@v6
.github/workflows/localaibot_automerge.yml (vendored): 3 lines changed

@@ -6,7 +6,8 @@ permissions:
  contents: write
  pull-requests: write
  packages: read

  issues: write # for Homebrew/actions/post-comment
  actions: write # to dispatch publish workflow
jobs:
  dependabot:
    runs-on: ubuntu-latest
.github/workflows/release.yaml (vendored): 40 lines changed

@@ -23,4 +23,42 @@ jobs:
          version: v2.11.0
          args: release --clean
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  launcher-build-darwin:
    runs-on: macos-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 1.23
      - name: Build launcher for macOS ARM64
        run: |
          make build-launcher-darwin
      - name: Upload DMG to Release
        uses: softprops/action-gh-release@v2
        with:
          files: ./dist/LocalAI.dmg
  launcher-build-linux:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 1.23
      - name: Build launcher for Linux
        run: |
          sudo apt-get update
          sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
          make build-launcher-linux
      - name: Upload Linux launcher artifacts
        uses: softprops/action-gh-release@v2
        with:
          files: ./local-ai-launcher-linux.tar.xz
.github/workflows/secscan.yaml (vendored): 2 lines changed

@@ -18,7 +18,7 @@ jobs:
        if: ${{ github.actor != 'dependabot[bot]' }}
      - name: Run Gosec Security Scanner
        if: ${{ github.actor != 'dependabot[bot]' }}
        uses: securego/gosec@v2.22.8
        uses: securego/gosec@v2.22.9
        with:
          # we let the report trigger content trigger a failure using the GitHub Security features.
          args: '-no-fail -fmt sarif -out results.sarif ./...'
.github/workflows/stalebot.yml (vendored): 2 lines changed

@@ -10,7 +10,7 @@ jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9
      - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v9
        with:
          stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
          stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
.gitignore (vendored): 2 lines changed

@@ -24,7 +24,7 @@ go-bert

# LocalAI build binary
LocalAI
local-ai
/local-ai
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder
(GoReleaser configuration hunk; filename not captured in this view)

@@ -8,7 +8,7 @@ source:
  enabled: true
  name_template: '{{ .ProjectName }}-{{ .Tag }}-source'
builds:
  -
  - main: ./cmd/local-ai
    env:
      - CGO_ENABLED=0
    ldflags:
Dockerfile: 14 lines changed

@@ -78,6 +78,16 @@ RUN <<EOT bash
  fi
EOT

# https://github.com/NVIDIA/Isaac-GR00T/issues/343
RUN <<EOT bash
  if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
    wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
    dpkg -i cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
    cp /var/cudss-local-tegra-repo-ubuntu2204-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
    apt-get update && apt-get -y install cudss
  fi
EOT

# If we are building with clblas support, we need the libraries for the builds
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
    apt-get update && \

@@ -100,6 +110,10 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
    ldconfig \
  ; fi

RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
    ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
  ; fi

RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"

# Cuda
Makefile: 52 lines changed

@@ -2,6 +2,7 @@ GOCMD=go
GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
LAUNCHER_BINARY_NAME=local-ai-launcher

GORELEASER?=

@@ -90,7 +91,17 @@ build: protogen-go install-go-tools ## Build the project
	$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
	$(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET})
	rm -rf $(BINARY_NAME) || true
	CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
	CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai

build-launcher: ## Build the launcher application
	$(info ${GREEN}I local-ai launcher build info:${RESET})
	$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
	$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
	$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
	rm -rf $(LAUNCHER_BINARY_NAME) || true
	CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(LAUNCHER_BINARY_NAME) ./cmd/launcher

build-all: build build-launcher ## Build both server and launcher

dev-dist:
	$(GORELEASER) build --snapshot --clean

@@ -106,8 +117,8 @@ run: ## run local-ai
	CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./

test-models/testmodel.ggml:
	mkdir test-models
	mkdir test-dir
	mkdir -p test-models
	mkdir -p test-dir
	wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
	wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
	wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert

@@ -358,6 +369,9 @@ backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
backends/kokoro: docker-build-kokoro docker-save-kokoro build
	./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"

backends/chatterbox: docker-build-chatterbox docker-save-chatterbox build
	./local-ai backends install "ocifile://$(abspath ./backend-images/chatterbox.tar)"

backends/llama-cpp-darwin: build
	bash ./scripts/build/llama-cpp-darwin.sh
	./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"

@@ -365,6 +379,9 @@ backends/llama-cpp-darwin: build
build-darwin-python-backend: build
	bash ./scripts/build/python-darwin.sh

build-darwin-go-backend: build
	bash ./scripts/build/golang-darwin.sh

backends/mlx:
	BACKEND=mlx $(MAKE) build-darwin-python-backend
	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)"

@@ -377,6 +394,14 @@ backends/mlx-vlm:
	BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"

backends/mlx-audio:
	BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"

backends/stablediffusion-ggml-darwin:
	BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
	./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"

backend-images:
	mkdir -p backend-images

@@ -404,6 +429,9 @@ docker-build-kitten-tts:
docker-save-kitten-tts: backend-images
	docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar

docker-save-chatterbox: backend-images
	docker save local-ai-backend:chatterbox -o backend-images/chatterbox.tar

docker-build-kokoro:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend

@@ -471,7 +499,7 @@ docker-build-bark:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .

docker-build-chatterbox:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox ./backend

docker-build-exllama2:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .

@@ -507,3 +535,19 @@ docs-clean:
.PHONY: docs
docs: docs/static/gallery.html
	cd docs && hugo serve

########################################################
## Platform-specific builds
########################################################

## fyne cross-platform build
build-launcher-darwin: build-launcher
	go run github.com/tiagomelo/macos-dmg-creator/cmd/createdmg@latest \
		--appName "LocalAI" \
		--appBinaryPath "$(LAUNCHER_BINARY_NAME)" \
		--bundleIdentifier "com.localai.launcher" \
		--iconPath "core/http/static/logo.png" \
		--outputDir "dist/"

build-launcher-linux:
	cd cmd/launcher && go run fyne.io/tools/cmd/fyne@latest package -os linux -icon ../../core/http/static/logo.png --executable $(LAUNCHER_BINARY_NAME)-linux && mv launcher.tar.xz ../../$(LAUNCHER_BINARY_NAME)-linux.tar.xz
README.md: 65 lines changed

@@ -43,7 +43,7 @@
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
>
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🥽 Demo](https://demo.localai.io) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
[](https://t.me/localaiofficial_bot)

[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)

@@ -110,6 +110,12 @@ curl https://localai.io/install.sh | sh

For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).

### macOS Download:

<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a>

Or run with docker:

### CPU only image:

@@ -233,6 +239,60 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
- 🔊 Voice activity detection (Silero-VAD support)
- 🌍 Integrated WebUI!

## 🧩 Supported Backends & Acceleration

LocalAI supports a comprehensive range of AI backends with multiple acceleration options:

### Text Generation & Language Models
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **llama.cpp** | LLM inference in C/C++ | CUDA 11/12, ROCm, Intel SYCL, Vulkan, Metal, CPU |
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12, ROCm, Intel |
| **transformers** | HuggingFace transformers framework | CUDA 11/12, ROCm, Intel, CPU |
| **exllama2** | GPTQ inference library | CUDA 12 |
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |

### Audio & Speech Processing
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12, ROCm, Intel SYCL, Vulkan, CPU |
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12, ROCm, Intel, CPU |
| **bark** | Text-to-audio generation | CUDA 12, ROCm, Intel |
| **bark-cpp** | C++ implementation of Bark | CUDA, Metal, CPU |
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12, ROCm, Intel, CPU |
| **kokoro** | Lightweight TTS model | CUDA 12, ROCm, Intel, CPU |
| **chatterbox** | Production-grade TTS | CUDA 11/12, CPU |
| **piper** | Fast neural TTS system | CPU |
| **kitten-tts** | Kitten TTS models | CPU |
| **silero-vad** | Voice Activity Detection | CPU |

### Image & Video Generation
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12, Intel SYCL, Vulkan, CPU |
| **diffusers** | HuggingFace diffusion models | CUDA 11/12, ROCm, Intel, Metal, CPU |

### Specialized AI Tasks
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **rfdetr** | Real-time object detection | CUDA 12, Intel, CPU |
| **rerankers** | Document reranking API | CUDA 11/12, ROCm, Intel, CPU |
| **local-store** | Vector database | CPU |
| **huggingface** | HuggingFace API integration | API-based |

### Hardware Acceleration Matrix

| Acceleration Type | Supported Backends | Hardware Support |
|-------------------|-------------------|------------------|
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark | AMD Graphics |
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
| **NVIDIA Jetson** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI |
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |

### 🔗 Community and integrations

@@ -247,6 +307,9 @@ WebUIs:
Model galleries
- https://github.com/go-skynet/model-gallery

Voice:
- https://github.com/richiejp/VoxInput

Other:
- Helm chart https://github.com/go-skynet/helm-charts
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
Three model configuration files (filenames not captured in this view) receive the same update from MiniCPM-V 2.6 to MiniCPM-V 4.5:

@@ -2,10 +2,10 @@ context_size: 4096
f16: true
backend: llama-cpp
mmap: true
mmproj: minicpm-v-2_6-mmproj-f16.gguf
mmproj: minicpm-v-4_5-mmproj-f16.gguf
name: gpt-4o
parameters:
  model: minicpm-v-2_6-Q4_K_M.gguf
  model: minicpm-v-4_5-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>

@@ -42,9 +42,9 @@ template:
  <|im_start|>assistant

download_files:
- filename: minicpm-v-2_6-Q4_K_M.gguf
  sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
  sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
- filename: minicpm-v-4_5-Q4_K_M.gguf
  sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
  sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8

@@ -2,10 +2,10 @@ context_size: 4096
backend: llama-cpp
f16: true
mmap: true
mmproj: minicpm-v-2_6-mmproj-f16.gguf
mmproj: minicpm-v-4_5-mmproj-f16.gguf
name: gpt-4o
parameters:
  model: minicpm-v-2_6-Q4_K_M.gguf
  model: minicpm-v-4_5-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>

@@ -42,9 +42,9 @@ template:
  <|im_start|>assistant

download_files:
- filename: minicpm-v-2_6-Q4_K_M.gguf
  sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
  sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
- filename: minicpm-v-4_5-Q4_K_M.gguf
  sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
  sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8

@@ -2,10 +2,10 @@ context_size: 4096
backend: llama-cpp
f16: true
mmap: true
mmproj: minicpm-v-2_6-mmproj-f16.gguf
mmproj: minicpm-v-4_5-mmproj-f16.gguf
name: gpt-4o
parameters:
  model: minicpm-v-2_6-Q4_K_M.gguf
  model: minicpm-v-4_5-Q4_K_M.gguf
stopwords:
- <|im_end|>
- <dummy32000>

@@ -43,9 +43,9 @@ template:

download_files:
- filename: minicpm-v-2_6-Q4_K_M.gguf
  sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-2_6-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
  sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
- filename: minicpm-v-4_5-Q4_K_M.gguf
  sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
- filename: minicpm-v-4_5-mmproj-f16.gguf
  uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
  sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
backend/README.md (new file, 213 lines)

@@ -0,0 +1,213 @@
# LocalAI Backend Architecture

This directory contains the core backend infrastructure for LocalAI, including the gRPC protocol definition, multi-language Dockerfiles, and language-specific backend implementations.

## Overview

LocalAI uses a unified gRPC-based architecture that allows different programming languages to implement AI backends while maintaining consistent interfaces and capabilities. The backend system supports multiple hardware acceleration targets and provides a standardized way to integrate various AI models and frameworks.

## Architecture Components

### 1. Protocol Definition (`backend.proto`)

The `backend.proto` file defines the gRPC service interface that all backends must implement. This ensures consistency across different language implementations and provides a contract for communication between LocalAI core and backend services.

#### Core Services

- **Text Generation**: `Predict`, `PredictStream` for LLM inference
- **Embeddings**: `Embedding` for text vectorization
- **Image Generation**: `GenerateImage` for stable diffusion and image models
- **Audio Processing**: `AudioTranscription`, `TTS`, `SoundGeneration`
- **Video Generation**: `GenerateVideo` for video synthesis
- **Object Detection**: `Detect` for computer vision tasks
- **Vector Storage**: `StoresSet`, `StoresGet`, `StoresFind` for RAG operations
- **Reranking**: `Rerank` for document relevance scoring
- **Voice Activity Detection**: `VAD` for audio segmentation

#### Key Message Types

- **`PredictOptions`**: Comprehensive configuration for text generation
- **`ModelOptions`**: Model loading and configuration parameters
- **`Result`**: Standardized response format
- **`StatusResponse`**: Backend health and memory usage information

### 2. Multi-Language Dockerfiles

The backend system provides language-specific Dockerfiles that handle the build environment and dependencies for different programming languages:

- `Dockerfile.python`
- `Dockerfile.golang`
- `Dockerfile.llama-cpp`

### 3. Language-Specific Implementations

#### Python Backends (`python/`)
- **transformers**: Hugging Face Transformers framework
- **vllm**: High-performance LLM inference
- **mlx**: Apple Silicon optimization
- **diffusers**: Stable Diffusion models
- **Audio**: bark, coqui, faster-whisper, kitten-tts
- **Vision**: mlx-vlm, rfdetr
- **Specialized**: rerankers, chatterbox, kokoro

#### Go Backends (`go/`)
- **whisper**: OpenAI Whisper speech recognition in Go, backed by the whisper.cpp (GGML) C++ implementation
- **stablediffusion-ggml**: Stable Diffusion in Go, backed by a GGML C++ implementation
- **huggingface**: Hugging Face model integration
- **piper**: Text-to-speech synthesis in Go, with C bindings to rhasspy/piper
- **bark-cpp**: Bark TTS models in Go, with C++ bindings
- **local-store**: Vector storage backend

#### C++ Backends (`cpp/`)
- **llama-cpp**: Llama.cpp integration
- **grpc**: gRPC utilities and helpers

## Hardware Acceleration Support

### CUDA (NVIDIA)
- **Versions**: CUDA 11.x, 12.x
- **Features**: cuBLAS, cuDNN, TensorRT optimization
- **Targets**: x86_64, ARM64 (Jetson)

### ROCm (AMD)
- **Features**: HIP, rocBLAS, MIOpen
- **Targets**: AMD GPUs with ROCm support

### Intel
- **Features**: oneAPI, Intel Extension for PyTorch
- **Targets**: Intel GPUs, XPUs, CPUs

### Vulkan
- **Features**: Cross-platform GPU acceleration
- **Targets**: Windows, Linux, Android, macOS

### Apple Silicon
- **Features**: MLX framework, Metal Performance Shaders
- **Targets**: M1/M2/M3 Macs

## Backend Registry (`index.yaml`)

The `index.yaml` file serves as a central registry for all available backends, providing:

- **Metadata**: Name, description, license, icons
- **Capabilities**: Hardware targets and optimization profiles
- **Tags**: Categorization for discovery
- **URLs**: Source code and documentation links

## Building Backends

### Prerequisites
- Docker with multi-architecture support
- Appropriate hardware drivers (CUDA, ROCm, etc.)
- Build tools (make, cmake, compilers)

### Build Commands

Example build commands with Docker:

```bash
# Build a Python backend
docker build -f backend/Dockerfile.python \
  --build-arg BACKEND=transformers \
  --build-arg BUILD_TYPE=cublas12 \
  --build-arg CUDA_MAJOR_VERSION=12 \
  --build-arg CUDA_MINOR_VERSION=0 \
  -t localai-backend-transformers .

# Build a Go backend
docker build -f backend/Dockerfile.golang \
  --build-arg BACKEND=whisper \
  --build-arg BUILD_TYPE=cpu \
  -t localai-backend-whisper .

# Build a C++ backend
docker build -f backend/Dockerfile.llama-cpp \
  --build-arg BACKEND=llama-cpp \
  --build-arg BUILD_TYPE=cublas12 \
  -t localai-backend-llama-cpp .
```

For ARM64/macOS builds, Docker cannot be used; use the Makefile in the respective backend instead.

### Build Types

- **`cpu`**: CPU-only optimization
- **`cublas11`**: CUDA 11.x with cuBLAS
- **`cublas12`**: CUDA 12.x with cuBLAS
- **`hipblas`**: ROCm with rocBLAS
- **`intel`**: Intel oneAPI optimization
- **`vulkan`**: Vulkan-based acceleration
- **`metal`**: Apple Metal optimization

## Backend Development

### Creating a New Backend

1. **Choose Language**: Select Python, Go, or C++ based on requirements
2. **Implement Interface**: Implement the gRPC service defined in `backend.proto`
3. **Add Dependencies**: Create appropriate requirements files
4. **Configure Build**: Set up the Dockerfile and build scripts
5. **Register Backend**: Add an entry to `index.yaml`
6. **Test Integration**: Verify gRPC communication and functionality

### Backend Structure

```
backend-name/
├── backend.py/go/cpp    # Main implementation
├── requirements.txt     # Dependencies
├── Dockerfile           # Build configuration
├── install.sh           # Installation script
├── run.sh               # Execution script
├── test.sh              # Test script
└── README.md            # Backend documentation
```

### Required gRPC Methods

At minimum, backends must implement:
- `Health()` - Service health check
- `LoadModel()` - Model loading and initialization
- `Predict()` - Main inference endpoint
- `Status()` - Backend status and metrics
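As a rough illustration, the sketch below shows what a minimal Go backend implementing these methods could look like. It assumes `pb` is the package generated from `backend.proto` (e.g. via `make protogen-go`); the import path, service name, message field names, and exact signatures are illustrative assumptions, not the authoritative interface, and `Status()` is omitted for brevity.

```go
// Minimal sketch of a Go backend; NOT the canonical implementation.
// Assumption: `pb` is the Go package generated from backend.proto, and the
// service exposes Health/LoadModel/Predict as listed above.
package main

import (
	"context"
	"log"
	"net"

	"google.golang.org/grpc"

	pb "example.com/localai/backend/proto" // assumed path to generated stubs
)

type backend struct {
	pb.UnimplementedBackendServer // assumed generated embedding for forward compatibility
}

// Health reports service liveness.
func (b *backend) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
	return &pb.Reply{Message: []byte("OK")}, nil
}

// LoadModel loads and initializes the model; ModelPath is field 59 in the proto.
func (b *backend) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
	log.Printf("loading model from %s", in.GetModelPath())
	return &pb.Result{Success: true}, nil
}

// Predict runs a single inference request.
func (b *backend) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
	return &pb.Reply{Message: []byte("generated text goes here")}, nil
}

func main() {
	lis, err := net.Listen("tcp", "127.0.0.1:50051") // address illustrative
	if err != nil {
		log.Fatal(err)
	}
	srv := grpc.NewServer()
	pb.RegisterBackendServer(srv, &backend{}) // assumed generated registrar
	log.Fatal(srv.Serve(lis))
}
```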
## Integration with LocalAI Core

Backends communicate with LocalAI core through gRPC:

1. **Service Discovery**: Core discovers available backends
2. **Model Loading**: Core requests model loading via `LoadModel`
3. **Inference**: Core sends requests via `Predict` or specialized endpoints
4. **Streaming**: Core handles streaming responses for real-time generation
5. **Monitoring**: Core tracks backend health and performance
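Under the same assumptions as the server sketch above, the core-side flow for steps 2 and 3 reduces to a short client sequence; everything from the `pb` package is assumed from `backend.proto`, and the `Prompt` field name in particular is illustrative:

```go
// Sketch of the LoadModel -> Predict flow as seen from the core side.
package main

import (
	"context"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "example.com/localai/backend/proto" // assumed path to generated stubs
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	conn, err := grpc.Dial("127.0.0.1:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn) // assumed generated constructor

	// Step 2: request model loading (ModelPath is field 59 in the proto).
	if _, err := client.LoadModel(ctx, &pb.ModelOptions{ModelPath: "/models"}); err != nil {
		log.Fatal(err)
	}

	// Step 3: send an inference request.
	reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "Hello"})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("reply: %s", reply.Message)
}
```

For streaming generation (step 4), `PredictStream` would return a server-side stream to iterate over instead of a single reply.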
## Performance Optimization

### Memory Management
- **Model Caching**: Efficient model loading and caching
- **Batch Processing**: Optimize for multiple concurrent requests
- **Memory Pinning**: GPU memory optimization for CUDA/ROCm

### Hardware Utilization
- **Multi-GPU**: Support for tensor parallelism
- **Mixed Precision**: FP16/BF16 for memory efficiency
- **Kernel Fusion**: Optimized CUDA/ROCm kernels

## Troubleshooting

### Common Issues

1. **gRPC Connection**: Verify the backend service is running and accessible (see the probe sketch below)
2. **Model Loading**: Check model paths and dependencies
3. **Hardware Detection**: Ensure appropriate drivers and libraries are installed
4. **Memory Issues**: Monitor GPU memory usage and model sizes
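For the first issue, a quick probe against the backend's `Health()` endpoint can rule out connectivity problems before digging into model or hardware errors. The sketch below reuses the assumed `pb` stubs from the earlier examples:

```go
// Connectivity probe for a running backend (sketch; pb stubs assumed as above).
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "example.com/localai/backend/proto" // assumed path to generated stubs
)

// probeBackend returns nil only if the backend answers Health() within the deadline.
func probeBackend(addr string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("dial %s: %w", addr, err)
	}
	defer conn.Close()

	// A successful Health call confirms the gRPC service is up and reachable.
	if _, err := pb.NewBackendClient(conn).Health(ctx, &pb.HealthMessage{}); err != nil {
		return fmt.Errorf("health check failed: %w", err)
	}
	return nil
}

func main() {
	if err := probeBackend("127.0.0.1:50051"); err != nil {
		log.Fatal(err)
	}
	log.Println("backend is healthy")
}
```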
## Contributing

When contributing to the backend system:

1. **Follow Protocol**: Implement the exact gRPC interface
2. **Add Tests**: Include comprehensive test coverage
3. **Document**: Provide clear usage examples
4. **Optimize**: Consider performance and resource usage
5. **Validate**: Test across different hardware targets
(hunks to the gRPC protocol definition, backend.proto)

@@ -242,7 +242,7 @@ message ModelOptions {

  string Type = 49;

  bool FlashAttention = 56;
  string FlashAttention = 56;
  bool NoKVOffload = 57;

  string ModelPath = 59;

@@ -276,6 +276,7 @@ message TranscriptRequest {
  string language = 3;
  uint32 threads = 4;
  bool translate = 5;
  bool diarize = 6;
}

message TranscriptResult {

@@ -305,22 +306,24 @@ message GenerateImageRequest {
  // Diffusers
  string EnableParameters = 10;
  int32 CLIPSkip = 11;

  // Reference images for models that support them (e.g., Flux Kontext)
  repeated string ref_images = 12;
}

message GenerateVideoRequest {
  string prompt = 1;
  string start_image = 2; // Path or base64 encoded image for the start frame
  string end_image = 3; // Path or base64 encoded image for the end frame
  int32 width = 4;
  int32 height = 5;
  int32 num_frames = 6; // Number of frames to generate
  int32 fps = 7; // Frames per second
  int32 seed = 8;
  float cfg_scale = 9; // Classifier-free guidance scale
  string dst = 10; // Output path for the generated video
  string negative_prompt = 2; // Negative prompt for video generation
  string start_image = 3; // Path or base64 encoded image for the start frame
  string end_image = 4; // Path or base64 encoded image for the end frame
  int32 width = 5;
  int32 height = 6;
  int32 num_frames = 7; // Number of frames to generate
  int32 fps = 8; // Frames per second
  int32 seed = 9;
  float cfg_scale = 10; // Classifier-free guidance scale
  int32 step = 11; // Number of inference steps
  string dst = 12; // Output path for the generated video
}

message TTSRequest {
(llama-cpp backend Makefile hunks; file header not captured in this view)

@@ -1,5 +1,5 @@

LLAMA_VERSION?=710dfc465a68f7443b87d9f792cffba00ed739fe
LLAMA_VERSION?=d64c8104f090b27b1f99e8da5995ffcfa6b726e2
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

CMAKE_ARGS?=

@@ -14,7 +14,7 @@ CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF

CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
	CMAKE_ARGS+=-DGGML_NATIVE=OFF
	CMAKE_ARGS+=-DGGML_NATIVE=OFF -DLLAMA_OPENSSL=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
(llama-cpp backend gRPC server hunks; file header not captured in this view)

@@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
    params.cpuparams.n_threads = request->threads();
    params.n_gpu_layers = request->ngpulayers();
    params.n_batch = request->nbatch();
    params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens thereby avoiding this error "input is too large to process. increase the physical batch size"
    // Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
    //params.n_parallel = 1;
    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");

@@ -304,7 +305,15 @@ static void params_parse(const backend::ModelOptions* request,
    }
    params.use_mlock = request->mlock();
    params.use_mmap = request->mmap();
    params.flash_attn = request->flashattention();

    if (request->flashattention() == "on" || request->flashattention() == "enabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
    } else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    } else if (request->flashattention() == "auto") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
    }

    params.no_kv_offload = request->nokvoffload();
    params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)

@@ -693,7 +702,7 @@ public:
    */

    // for the shape of input/content, see tokenize_input_prompts()
    json prompt = body.at("prompt");
    json prompt = body.at("embeddings");

    auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);

@@ -704,6 +713,7 @@ public:
        }
    }

    int embd_normalize = 2; // default to Euclidean/L2 norm
    // create and queue the task
    json responses = json::array();
    bool error = false;

@@ -717,9 +727,8 @@ public:
        task.index = i;
        task.prompt_tokens = std::move(tokenized_prompts[i]);

        // OAI-compat
        task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING;

        task.params.oaicompat = OAICOMPAT_TYPE_NONE;
        task.params.embd_normalize = embd_normalize;
        tasks.push_back(std::move(task));
    }

@@ -735,9 +744,8 @@ public:
            responses.push_back(res->to_json());
        }
    }, [&](const json & error_data) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
        error = true;
    }, [&]() {
        // NOTE: we should try to check when the writer is closed here
        return false;
    });

@@ -747,12 +755,36 @@ public:
        return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
    }

    std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>());
    // loop the vector and set the embeddings results
    for (int i = 0; i < embeddings.size(); i++) {
        embeddingResult->add_embeddings(embeddings[i]);
    std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;

    // Process the responses and extract embeddings
    for (const auto & response_elem : responses) {
        // Check if the response has an "embedding" field
        if (response_elem.contains("embedding")) {
            json embedding_data = json_value(response_elem, "embedding", json::array());

            if (embedding_data.is_array() && !embedding_data.empty()) {
                for (const auto & embedding_vector : embedding_data) {
                    if (embedding_vector.is_array()) {
                        for (const auto & embedding_value : embedding_vector) {
                            embeddingResult->add_embeddings(embedding_value.get<float>());
                        }
                    }
                }
            }
        } else {
            // Check if the response itself contains the embedding data directly
            if (response_elem.is_array()) {
                for (const auto & embedding_value : response_elem) {
                    embeddingResult->add_embeddings(embedding_value.get<float>());
                }
            }
        }
    }

    return grpc::Status::OK;
}

@@ -770,11 +802,6 @@ public:
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
    }

    // Tokenize the query
    auto tokenized_query = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, request->query(), /* add_special */ false, true);
    if (tokenized_query.size() != 1) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must contain only a single prompt");
    }
    // Create and queue the task
    json responses = json::array();
    bool error = false;

@@ -786,10 +813,9 @@ public:
        documents.push_back(request->documents(i));
    }

    auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
    tasks.reserve(tokenized_docs.size());
    for (size_t i = 0; i < tokenized_docs.size(); i++) {
        auto tmp = format_rerank(ctx_server.vocab, tokenized_query[0], tokenized_docs[i]);
    tasks.reserve(documents.size());
    for (size_t i = 0; i < documents.size(); i++) {
        auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, request->query(), documents[i]);
        server_task task = server_task(SERVER_TASK_TYPE_RERANK);
        task.id = ctx_server.queue_tasks.get_new_id();
        task.index = i;
backend/go/stablediffusion-ggml/.gitignore (new file, vendored, 6 lines)

@@ -0,0 +1,6 @@
package/
sources/
.cache/
build/
libgosd.so
stablediffusion-ggml
(stablediffusion-ggml CMakeLists.txt hunk; file header not captured in this view)

@@ -5,7 +5,11 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory(./sources/stablediffusion-ggml.cpp)

add_library(gosd MODULE gosd.cpp)
target_link_libraries(gosd PRIVATE stable-diffusion ggml stdc++fs)
target_link_libraries(gosd PRIVATE stable-diffusion ggml)

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    target_link_libraries(gosd PRIVATE stdc++fs)
endif()

target_include_directories(gosd PUBLIC
    stable-diffusion.cpp
(stablediffusion-ggml backend Makefile hunks; file header not captured in this view)

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=5900ef6605c6fbf7934239f795c13c97bc993853
STABLEDIFFUSION_GGML_VERSION?=0ebe6fe118f125665939b27c89f34ed38716bff8

CMAKE_ARGS+=-DGGML_MAX_NAME=128

@@ -29,8 +29,6 @@ else ifeq ($(BUILD_TYPE),clblas)
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
	CMAKE_ARGS+=-DSD_HIPBLAS=ON -DGGML_HIPBLAS=ON
# If it's OSX, DO NOT embed the metal library - -DGGML_METAL_EMBED_LIBRARY=ON requires further investigation
# But if it's OSX without metal, disable it here
else ifeq ($(BUILD_TYPE),vulkan)
	CMAKE_ARGS+=-DSD_VULKAN=ON -DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)

@@ -74,10 +72,10 @@ libgosd.so: sources/stablediffusion-ggml.cpp CMakeLists.txt gosd.cpp gosd.h
stablediffusion-ggml: main.go gosd.go libgosd.so
	CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml ./

package:
package: stablediffusion-ggml
	bash package.sh

build: stablediffusion-ggml package
build: package

clean:
	rm -rf libgosd.o build stablediffusion-ggml
	rm -rf libgosd.so build stablediffusion-ggml package sources
@@ -4,17 +4,11 @@
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include <filesystem>
#include "gosd.h"

// #include "preprocessing.hpp"
#include "flux.hpp"
#include "stable-diffusion.h"

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
@@ -29,7 +23,7 @@

// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
    "euler_a",
    "default",
    "euler",
    "heun",
    "dpm2",
@@ -41,19 +35,27 @@ const char* sample_method_str[] = {
    "lcm",
    "ddim_trailing",
    "tcd",
    "euler_a",
};

static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");

// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedule_str[] = {
const char* schedulers[] = {
    "default",
    "discrete",
    "karras",
    "exponential",
    "ays",
    "gits",
    "smoothstep",
};

static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch");

sd_ctx_t* sd_c;
// Moved from the context (load time) to generation time params
scheduler_t scheduler = scheduler_t::DEFAULT;

sample_method_t sample_method;

@@ -105,7 +107,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
    const char *clip_g_path = "";
    const char *t5xxl_path = "";
    const char *vae_path = "";
    const char *scheduler = "";
    const char *scheduler_str = "";
    const char *sampler = "";
    char *lora_dir = model_path;
    bool lora_dir_allocated = false;
@@ -133,7 +135,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
        vae_path = optval;
    }
    if (!strcmp(optname, "scheduler")) {
        scheduler = optval;
        scheduler_str = optval;
    }
    if (!strcmp(optname, "sampler")) {
        sampler = optval;
@@ -166,26 +168,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads
    }
    if (sample_method_found == -1) {
        fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
        sample_method_found = EULER_A;
        sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT;
    }
    sample_method = (sample_method_t)sample_method_found;

    int schedule_found = -1;
    for (int d = 0; d < SCHEDULE_COUNT; d++) {
        if (!strcmp(scheduler, schedule_str[d])) {
            schedule_found = d;
            fprintf (stderr, "Found scheduler: %s\n", scheduler);

        if (!strcmp(scheduler_str, schedulers[d])) {
            scheduler = (scheduler_t)d;
            fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
        }
    }

    if (schedule_found == -1) {
        fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
        schedule_found = DEFAULT;
    }

    schedule_t schedule = (schedule_t)schedule_found;

    fprintf (stderr, "Creating context\n");
    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);
@@ -199,13 +192,10 @@ int load_model(const char *model, char *model_path, char* options[], int threads
    ctx_params.control_net_path = "";
    ctx_params.lora_model_dir = lora_dir;
    ctx_params.embedding_dir = "";
    ctx_params.stacked_id_embed_dir = "";
    ctx_params.vae_decode_only = false;
    ctx_params.vae_tiling = false;
    ctx_params.free_params_immediately = false;
    ctx_params.n_threads = threads;
    ctx_params.rng_type = STD_DEFAULT_RNG;
    ctx_params.schedule = schedule;
    sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);

    if (sd_ctx == NULL) {
@@ -228,7 +218,49 @@ int load_model(const char *model, char *model_path, char* options[], int threads
    return 0;
}

int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled) {
    params->enabled = enabled;
}

void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y) {
    params->tile_size_x = tile_size_x;
    params->tile_size_y = tile_size_y;
}

void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y) {
    params->rel_size_x = rel_size_x;
    params->rel_size_y = rel_size_y;
}

void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap) {
    params->target_overlap = target_overlap;
}

sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params) {
    return &params->vae_tiling_params;
}

sd_img_gen_params_t* sd_img_gen_params_new(void) {
    sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc(sizeof(sd_img_gen_params_t));
    sd_img_gen_params_init(params);
    return params;
}

void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
    params->prompt = prompt;
    params->negative_prompt = negative_prompt;
}

void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
    params->width = width;
    params->height = height;
}

void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
    params->seed = seed;
}

int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {

    sd_image_t* results;

@@ -236,20 +268,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,

    fprintf (stderr, "Generating image\n");

    sd_img_gen_params_t p;
    sd_img_gen_params_init(&p);
    p->sample_params.guidance.txt_cfg = cfg_scale;
    p->sample_params.guidance.slg.layers = skip_layers.data();
    p->sample_params.guidance.slg.layer_count = skip_layers.size();
    p->sample_params.sample_method = sample_method;
    p->sample_params.sample_steps = steps;
    p->sample_params.scheduler = scheduler;

    p.prompt = text;
    p.negative_prompt = negativeText;
    p.guidance.txt_cfg = cfg_scale;
    p.guidance.slg.layers = skip_layers.data();
    p.guidance.slg.layer_count = skip_layers.size();
    p.width = width;
    p.height = height;
    p.sample_method = sample_method;
    p.sample_steps = steps;
    p.seed = seed;
    p.input_id_images_path = "";
    int width = p->width;
    int height = p->height;

    // Handle input image for img2img
    bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
@@ -298,13 +325,13 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
        input_image_buffer = resized_image_buffer;
    }

    p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
    p.strength = strength;
    p->init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
    p->strength = strength;
    fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
    } else {
        // No input image, use empty image for text-to-image
        p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
        p.strength = 0.0f;
        p->init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
        p->strength = 0.0f;
    }

    // Handle mask image for inpainting
@@ -344,12 +371,12 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
        mask_image_buffer = resized_mask_buffer;
    }

    p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
    p->mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
    fprintf(stderr, "Using inpainting with mask\n");
    } else {
        // No mask image, create default full mask
        default_mask_image_vec.resize(width * height, 255);
        p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
        p->mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
    }

    // Handle reference images
@@ -407,13 +434,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
    }

    if (!ref_images_vec.empty()) {
        p.ref_images = ref_images_vec.data();
        p.ref_images_count = ref_images_vec.size();
        p->ref_images = ref_images_vec.data();
        p->ref_images_count = ref_images_vec.size();
        fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
    }
    }

    results = generate_image(sd_c, &p);
    results = generate_image(sd_c, p);

    std::free(p);

    if (results == NULL) {
        fprintf (stderr, "NO results\n");
@@ -22,7 +22,18 @@ type SDGGML struct {

var (
    LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
    GenImage func(text, negativeText string, width, height, steps int, seed int64, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
    GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int

    TilingParamsSetEnabled func(params uintptr, enabled bool)
    TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
    TilingParamsSetRelSizes func(params uintptr, relSizeX float32, relSizeY float32)
    TilingParamsSetTargetOverlap func(params uintptr, targetOverlap float32)

    ImgGenParamsNew func() uintptr
    ImgGenParamsSetPrompts func(params uintptr, prompt string, negativePrompt string)
    ImgGenParamsSetDimensions func(params uintptr, width int, height int)
    ImgGenParamsSetSeed func(params uintptr, seed int64)
    ImgGenParamsGetVaeTilingParams func(params uintptr) uintptr
)

// Copied from Purego internal/strings
@@ -120,7 +131,15 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
    // Default strength for img2img (0.75 is a good default)
    strength := float32(0.75)

    ret := GenImage(t, negative, int(opts.Width), int(opts.Height), int(opts.Step), int64(opts.Seed), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
    // free'd by GenImage
    p := ImgGenParamsNew()
    ImgGenParamsSetPrompts(p, t, negative)
    ImgGenParamsSetDimensions(p, int(opts.Width), int(opts.Height))
    ImgGenParamsSetSeed(p, int64(opts.Seed))
    vaep := ImgGenParamsGetVaeTilingParams(p)
    TilingParamsSetEnabled(vaep, false)

    ret := GenImage(p, int(opts.Step), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
    if ret != 0 {
        return fmt.Errorf("inference failed")
    }
@@ -1,8 +1,23 @@
#include <cstdint>
#include "stable-diffusion.h"

#ifdef __cplusplus
extern "C" {
#endif

void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled);
void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y);
void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y);
void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap);
sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params);

sd_img_gen_params_t* sd_img_gen_params_new(void);
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt);
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height);
void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);

int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
#ifdef __cplusplus
}
#endif
@@ -11,14 +11,35 @@ var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

type LibFuncs struct {
    FuncPtr any
    Name    string
}

func main() {
    gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
    if err != nil {
        panic(err)
    }

    purego.RegisterLibFunc(&LoadModel, gosd, "load_model")
    purego.RegisterLibFunc(&GenImage, gosd, "gen_image")
    libFuncs := []LibFuncs{
        {&LoadModel, "load_model"},
        {&GenImage, "gen_image"},
        {&TilingParamsSetEnabled, "sd_tiling_params_set_enabled"},
        {&TilingParamsSetTileSizes, "sd_tiling_params_set_tile_sizes"},
        {&TilingParamsSetRelSizes, "sd_tiling_params_set_rel_sizes"},
        {&TilingParamsSetTargetOverlap, "sd_tiling_params_set_target_overlap"},

        {&ImgGenParamsNew, "sd_img_gen_params_new"},
        {&ImgGenParamsSetPrompts, "sd_img_gen_params_set_prompts"},
        {&ImgGenParamsSetDimensions, "sd_img_gen_params_set_dimensions"},
        {&ImgGenParamsSetSeed, "sd_img_gen_params_set_seed"},
        {&ImgGenParamsGetVaeTilingParams, "sd_img_gen_params_get_vae_tiling_params"},
    }

    for _, lf := range libFuncs {
        purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
    }

    flag.Parse()

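For context, the table-driven registration above relies on purego's `RegisterLibFunc`, which fills a Go function variable with a trampoline to the named C symbol. A minimal standalone sketch of the same pattern, assuming a hypothetical `libdemo.so` that exports `int add_ints(int, int)` (the library and symbol names here are illustrative, not part of the diff):

```go
package main

import (
	"fmt"

	"github.com/ebitengine/purego"
)

// Populated by RegisterLibFunc below; calling it invokes the C symbol.
var addInts func(a, b int32) int32

func main() {
	// The same Dlopen + RegisterLibFunc sequence main.go uses for libgosd.so.
	lib, err := purego.Dlopen("./libdemo.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		panic(err)
	}
	purego.RegisterLibFunc(&addInts, lib, "add_ints")

	fmt.Println(addInts(2, 3)) // prints 5 if the hypothetical library exports add_ints
}
```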
@@ -10,9 +10,9 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib

cp -avrf $CURDIR/libgosd.so $CURDIR/package/
cp -avrf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
cp -avf $CURDIR/libgosd.so $CURDIR/package/
cp -avf $CURDIR/stablediffusion-ggml $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/

# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -43,6 +43,8 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
    cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
    cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
    echo "Detected Darwin"
else
    echo "Error: Could not detect architecture"
    exit 1
backend/go/whisper/.gitignore (vendored, new file, 7 lines)
@@ -0,0 +1,7 @@
.cache/
sources/
build/
package/
whisper
libgowhisper.so
backend/go/whisper/CMakeLists.txt (new file, 16 lines)
@@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.12)
project(gowhisper LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

add_subdirectory(./sources/whisper.cpp)

add_library(gowhisper MODULE gowhisper.cpp)
target_link_libraries(gowhisper PRIVATE whisper ggml)

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    target_link_libraries(gosd PRIVATE stdc++fs)
endif()

set_property(TARGET gowhisper PROPERTY CXX_STANDARD 17)
set_target_properties(gowhisper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -1,110 +1,53 @@
GOCMD=go
CMAKE_ARGS?=
BUILD_TYPE?=
NATIVE?=false

BUILD_TYPE?=
CMAKE_ARGS?=
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)

# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=fc45bb86251f774ef817e89878bb4c2636c8a58f
WHISPER_CPP_VERSION?=7849aff7a2e1f4234aa31b01a1870906d5431959

export WHISPER_CMAKE_ARGS?=-DBUILD_SHARED_LIBS=OFF
export WHISPER_DIR=$(abspath ./sources/whisper.cpp)
export WHISPER_INCLUDE_PATH=$(WHISPER_DIR)/include:$(WHISPER_DIR)/ggml/include
export WHISPER_LIBRARY_PATH=$(WHISPER_DIR)/build/src/:$(WHISPER_DIR)/build/ggml/src
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

CGO_LDFLAGS_WHISPER?=
CGO_LDFLAGS_WHISPER+=-lggml
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
CUDA_LIBPATH?=/usr/local/cuda/lib64/

ONEAPI_VERSION?=2025.2

# IF native is false, we add -DGGML_NATIVE=OFF to CMAKE_ARGS
ifeq ($(NATIVE),false)
    CMAKE_ARGS+=-DGGML_NATIVE=OFF
    WHISPER_CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
    CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically

ifeq ($(BUILD_TYPE),cublas)
    CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH) -L$(CUDA_LIBPATH)/stubs/ -lcuda
    CMAKE_ARGS+=-DGGML_CUDA=ON
    CGO_LDFLAGS_WHISPER+=-lcufft -lggml-cuda
    export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-cuda/
# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
    CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblas)
    CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
    ROCM_HOME ?= /opt/rocm
    ROCM_PATH ?= /opt/rocm
    LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib
    export STABLE_BUILD_TYPE=
    export CXX=$(ROCM_HOME)/llvm/bin/clang++
    export CC=$(ROCM_HOME)/llvm/bin/clang
    # GPU_TARGETS ?= gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102
    # AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
    CMAKE_ARGS+=-DGGML_HIP=ON
    CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib -L$(CURRENT_MAKEFILE_DIR)/sources/whisper.cpp/build/ggml/src/ggml-hip/ -lggml-hip
    # CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
    CMAKE_ARGS+=-DGGML_HIPBLAS=ON
else ifeq ($(BUILD_TYPE),vulkan)
    CMAKE_ARGS+=-DGGML_VULKAN=1
    CGO_LDFLAGS_WHISPER+=-lggml-vulkan -lvulkan
    export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-vulkan/
    CMAKE_ARGS+=-DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)
    ifeq ($(BUILD_TYPE),)
        BUILD_TYPE=metal
    endif
    ifneq ($(BUILD_TYPE),metal)
        CMAKE_ARGS+=-DGGML_METAL=OFF
        CGO_LDFLAGS_WHISPER+=-lggml-blas
        export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-blas
    else
        CMAKE_ARGS+=-DGGML_METAL=ON
        CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
        CMAKE_ARGS+=-DGGML_METAL_USE_BF16=ON
        CMAKE_ARGS+=-DGGML_OPENMP=OFF
        CMAKE_ARGS+=-DWHISPER_BUILD_EXAMPLES=OFF
        CMAKE_ARGS+=-DWHISPER_BUILD_TESTS=OFF
        CMAKE_ARGS+=-DWHISPER_BUILD_SERVER=OFF
        CGO_LDFLAGS += -framework Accelerate
        CGO_LDFLAGS_WHISPER+=-lggml-metal -lggml-blas
        export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-metal/:$(WHISPER_DIR)/build/ggml/src/ggml-blas
    endif
    TARGET+=--target ggml-metal
endif

ifneq (,$(findstring sycl,$(BUILD_TYPE)))
    export CC=icx
    export CXX=icpx
    CGO_LDFLAGS_WHISPER += -fsycl -L${DNNLROOT}/lib -rpath ${ONEAPI_ROOT}/${ONEAPI_VERSION}/lib -ldnnl ${MKLROOT}/lib/intel64/libmkl_sycl.a -fiopenmp -fopenmp-targets=spir64 -lOpenCL -lggml-sycl
    CGO_LDFLAGS_WHISPER += $(shell pkg-config --libs mkl-static-lp64-gomp)
    CGO_CXXFLAGS_WHISPER += -fiopenmp -fopenmp-targets=spir64
    CGO_CXXFLAGS_WHISPER += $(shell pkg-config --cflags mkl-static-lp64-gomp )
    export WHISPER_LIBRARY_PATH:=$(WHISPER_LIBRARY_PATH):$(WHISPER_DIR)/build/ggml/src/ggml-sycl/
    CMAKE_ARGS+=-DGGML_SYCL=ON \
        -DCMAKE_C_COMPILER=icx \
        -DCMAKE_CXX_COMPILER=icpx \
        -DCMAKE_CXX_FLAGS="-fsycl"
endif

ifeq ($(BUILD_TYPE),sycl_f16)
    CMAKE_ARGS+=-DGGML_SYCL_F16=ON
    CMAKE_ARGS+=-DGGML_SYCL=ON \
        -DCMAKE_C_COMPILER=icx \
        -DCMAKE_CXX_COMPILER=icpx \
        -DGGML_SYCL_F16=ON
endif

ifneq ($(OS),Darwin)
    CGO_LDFLAGS_WHISPER+=-lgomp
ifeq ($(BUILD_TYPE),sycl_f32)
    CMAKE_ARGS+=-DGGML_SYCL=ON \
        -DCMAKE_C_COMPILER=icx \
        -DCMAKE_CXX_COMPILER=icpx
endif

## whisper
sources/whisper.cpp:
    mkdir -p sources/whisper.cpp
    cd sources/whisper.cpp && \
@@ -114,18 +57,21 @@ sources/whisper.cpp:
    git checkout $(WHISPER_CPP_VERSION) && \
    git submodule update --init --recursive --depth 1 --single-branch

sources/whisper.cpp/build/src/libwhisper.a: sources/whisper.cpp
    cd sources/whisper.cpp && cmake $(CMAKE_ARGS) $(WHISPER_CMAKE_ARGS) . -B ./build
    cd sources/whisper.cpp/build && cmake --build . --config Release
libgowhisper.so: sources/whisper.cpp CMakeLists.txt gowhisper.cpp gowhisper.h
    mkdir -p build && \
    cd build && \
    cmake .. $(CMAKE_ARGS) && \
    cmake --build . --config Release -j$(JOBS) && \
    cd .. && \
    mv build/libgowhisper.so ./

whisper: sources/whisper.cpp sources/whisper.cpp/build/src/libwhisper.a
    $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
    $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(CURDIR)/sources/whisper.cpp/bindings/go
    CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="${WHISPER_INCLUDE_PATH}" LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" LD_LIBRARY_PATH="${WHISPER_LIBRARY_PATH}" \
    CGO_CXXFLAGS="$(CGO_CXXFLAGS_WHISPER)" \
    $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o whisper ./
whisper: main.go gowhisper.go libgowhisper.so
    CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./

package:
package: whisper
    bash package.sh

build: whisper package
build: package

clean:
    rm -rf libgowhisper.o build whisper
backend/go/whisper/gowhisper.cpp (new file, 154 lines)
@@ -0,0 +1,154 @@
#include "gowhisper.h"
#include "ggml-backend.h"
#include "whisper.h"
#include <vector>

static struct whisper_vad_context *vctx;
static struct whisper_context *ctx;
static std::vector<float> flat_segs;

static void ggml_log_cb(enum ggml_log_level level, const char *log,
                        void *data) {
    const char *level_str;

    if (!log) {
        return;
    }

    switch (level) {
    case GGML_LOG_LEVEL_DEBUG:
        level_str = "DEBUG";
        break;
    case GGML_LOG_LEVEL_INFO:
        level_str = "INFO";
        break;
    case GGML_LOG_LEVEL_WARN:
        level_str = "WARN";
        break;
    case GGML_LOG_LEVEL_ERROR:
        level_str = "ERROR";
        break;
    default: /* Potential future-proofing */
        level_str = "?????";
        break;
    }

    fprintf(stderr, "[%-5s] ", level_str);
    fputs(log, stderr);
    fflush(stderr);
}

int load_model(const char *const model_path) {
    whisper_log_set(ggml_log_cb, nullptr);
    ggml_backend_load_all();

    struct whisper_context_params cparams = whisper_context_default_params();

    ctx = whisper_init_from_file_with_params(model_path, cparams);
    if (ctx == nullptr) {
        fprintf(stderr, "error: Also failed to init model as transcriber\n");
        return 1;
    }

    return 0;
}

int load_model_vad(const char *const model_path) {
    whisper_log_set(ggml_log_cb, nullptr);
    ggml_backend_load_all();

    struct whisper_vad_context_params vcparams =
        whisper_vad_default_context_params();

    // XXX: Overridden to false in upstream due to performance?
    // vcparams.use_gpu = true;

    vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
    if (vctx == nullptr) {
        fprintf(stderr, "error: Failed to init model as VAD\n");
        return 1;
    }

    return 0;
}

int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
        size_t *segs_out_len) {
    if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
        fprintf(stderr, "error: failed to detect speech\n");
        return 1;
    }

    struct whisper_vad_params params = whisper_vad_default_params();
    struct whisper_vad_segments *segs =
        whisper_vad_segments_from_probs(vctx, params);
    size_t segn = whisper_vad_segments_n_segments(segs);

    // fprintf(stderr, "Got segments %zd\n", segn);

    flat_segs.clear();

    for (int i = 0; i < segn; i++) {
        flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
        flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
    }

    // fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
    // segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
    // flat_segs.size());
    *segs_out = flat_segs.data();
    *segs_out_len = flat_segs.size();

    // fprintf(stderr, "freeing segs\n");
    whisper_vad_free_segments(segs);

    // fprintf(stderr, "returning\n");
    return 0;
}

int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
               float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
    whisper_full_params wparams =
        whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.n_threads = threads;
    if (*lang != '\0')
        wparams.language = lang;
    else {
        wparams.language = nullptr;
    }

    wparams.translate = translate;
    wparams.debug_mode = true;
    wparams.print_progress = true;
    wparams.tdrz_enable = tdrz;

    fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);

    if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
        fprintf(stderr, "error: transcription failed\n");
        return 1;
    }

    *segs_out_len = whisper_full_n_segments(ctx);

    return 0;
}

const char *get_segment_text(int i) {
    return whisper_full_get_segment_text(ctx, i);
}

int64_t get_segment_t0(int i) { return whisper_full_get_segment_t0(ctx, i); }

int64_t get_segment_t1(int i) { return whisper_full_get_segment_t1(ctx, i); }

int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }

int32_t get_token_id(int i, int j) {
    return whisper_full_get_token_id(ctx, i, j);
}

bool get_segment_speaker_turn_next(int i) {
    return whisper_full_get_segment_speaker_turn_next(ctx, i);
}
backend/go/whisper/gowhisper.go (new file, 161 lines)
@@ -0,0 +1,161 @@
package main

import (
    "fmt"
    "os"
    "path/filepath"
    "strings"
    "unsafe"

    "github.com/go-audio/wav"
    "github.com/mudler/LocalAI/pkg/grpc/base"
    pb "github.com/mudler/LocalAI/pkg/grpc/proto"
    "github.com/mudler/LocalAI/pkg/utils"
)

var (
    CppLoadModel func(modelPath string) int
    CppLoadModelVAD func(modelPath string) int
    CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
    CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
    CppGetSegmentText func(i int) string
    CppGetSegmentStart func(i int) int64
    CppGetSegmentEnd func(i int) int64
    CppNTokens func(i int) int
    CppGetTokenID func(i int, j int) int
    CppGetSegmentSpeakerTurnNext func(i int) bool
)

type Whisper struct {
    base.SingleThread
}

func (w *Whisper) Load(opts *pb.ModelOptions) error {
    vadOnly := false

    for _, oo := range opts.Options {
        if oo == "vad_only" {
            vadOnly = true
        } else {
            fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
        }
    }

    if vadOnly {
        if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
            return fmt.Errorf("Failed to load Whisper VAD model")
        }

        return nil
    }

    if ret := CppLoadModel(opts.ModelFile); ret != 0 {
        return fmt.Errorf("Failed to load Whisper transcription model")
    }

    return nil
}

func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
    audio := req.Audio
    // We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
    segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
    segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)

    if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
        return pb.VADResponse{}, fmt.Errorf("Failed VAD")
    }

    // Happens when CPP vector has not had any elements pushed to it
    if segsPtr == 0 {
        return pb.VADResponse{
            Segments: []*pb.VADSegment{},
        }, nil
    }

    // unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
    // however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
    segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen)

    vadSegments := []*pb.VADSegment{}
    for i := range len(segs) >> 1 {
        s := segs[2*i] / 100
        t := segs[2*i+1] / 100
        vadSegments = append(vadSegments, &pb.VADSegment{
            Start: s,
            End:   t,
        })
    }

    return pb.VADResponse{
        Segments: vadSegments,
    }, nil
}

func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
    dir, err := os.MkdirTemp("", "whisper")
    if err != nil {
        return pb.TranscriptResult{}, err
    }
    defer os.RemoveAll(dir)

    convertedPath := filepath.Join(dir, "converted.wav")

    if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
        return pb.TranscriptResult{}, err
    }

    // Open samples
    fh, err := os.Open(convertedPath)
    if err != nil {
        return pb.TranscriptResult{}, err
    }
    defer fh.Close()

    // Read samples
    d := wav.NewDecoder(fh)
    buf, err := d.FullPCMBuffer()
    if err != nil {
        return pb.TranscriptResult{}, err
    }

    data := buf.AsFloat32Buffer().Data
    segsLen := uintptr(0xdeadbeef)
    segsLenPtr := unsafe.Pointer(&segsLen)

    if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
        return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
    }

    segments := []*pb.TranscriptSegment{}
    text := ""
    for i := range int(segsLen) {
        s := CppGetSegmentStart(i)
        t := CppGetSegmentEnd(i)
        txt := strings.Clone(CppGetSegmentText(i))
        tokens := make([]int32, CppNTokens(i))

        if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
            txt += " [SPEAKER_TURN]"
        }

        for j := range tokens {
            tokens[j] = int32(CppGetTokenID(i, j))
        }
        segment := &pb.TranscriptSegment{
            Id:    int32(i),
            Text:  txt,
            Start: s, End: t,
            Tokens: tokens,
        }

        segments = append(segments, segment)

        text += " " + strings.TrimSpace(txt)
    }

    return pb.TranscriptResult{
        Segments: segments,
        Text:     strings.TrimSpace(text),
    }, nil
}
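The VAD bridge above returns segments as one flat `[t0, t1, t0, t1, ...]` float array through an out-pointer, which the Go side rebuilds with `unsafe.Slice`. A self-contained sketch of just that decoding step, using a stand-in Go slice in place of the C++-owned memory (the sample values are made up for illustration):

```go
package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Stand-in for the buffer the C++ side fills: two speech segments,
	// flattened as [t0, t1, t0, t1] in centiseconds.
	flat := []float32{120, 480, 900, 1550}
	segsPtr := uintptr(unsafe.Pointer(&flat[0]))
	segsLen := uintptr(len(flat))

	// The same reconstruction gowhisper.go performs after CppVAD returns.
	segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen)
	for i := 0; i < len(segs)/2; i++ {
		start := segs[2*i] / 100 // centiseconds -> seconds
		end := segs[2*i+1] / 100
		fmt.Printf("segment %d: %.2fs - %.2fs\n", i, start, end)
	}
}
```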
backend/go/whisper/gowhisper.h (new file, 17 lines)
@@ -0,0 +1,17 @@
#include <cstddef>
#include <cstdint>

extern "C" {
int load_model(const char *const model_path);
int load_model_vad(const char *const model_path);
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
        size_t *segs_out_len);
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
               float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
const char *get_segment_text(int i);
int64_t get_segment_t0(int i);
int64_t get_segment_t1(int i);
int n_tokens(int i);
int32_t get_token_id(int i, int j);
bool get_segment_speaker_turn_next(int i);
}
@@ -1,10 +1,10 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
    "flag"

    "github.com/ebitengine/purego"
    grpc "github.com/mudler/LocalAI/pkg/grpc"
)

@@ -12,7 +12,34 @@ var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

type LibFuncs struct {
    FuncPtr any
    Name    string
}

func main() {
    gosd, err := purego.Dlopen("./libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
    if err != nil {
        panic(err)
    }

    libFuncs := []LibFuncs{
        {&CppLoadModel, "load_model"},
        {&CppLoadModelVAD, "load_model_vad"},
        {&CppVAD, "vad"},
        {&CppTranscribe, "transcribe"},
        {&CppGetSegmentText, "get_segment_text"},
        {&CppGetSegmentStart, "get_segment_t0"},
        {&CppGetSegmentEnd, "get_segment_t1"},
        {&CppNTokens, "n_tokens"},
        {&CppGetTokenID, "get_token_id"},
        {&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
    }

    for _, lf := range libFuncs {
        purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
    }

    flag.Parse()

    if err := grpc.StartServer(*addr, &Whisper{}); err != nil {
@@ -10,8 +10,8 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib

cp -avrf $CURDIR/whisper $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/
cp -avf $CURDIR/whisper $CURDIR/libgowhisper.so $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/

# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
@@ -42,11 +42,13 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
    cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
    cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
    echo "Detected Darwin"
else
    echo "Error: Could not detect architecture"
    exit 1
fi

echo "Packaging completed successfully"
echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/
ls -liah $CURDIR/package/lib/
@@ -1,105 +0,0 @@
package main

// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
    "os"
    "path/filepath"

    "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
    "github.com/go-audio/wav"
    "github.com/mudler/LocalAI/pkg/grpc/base"
    pb "github.com/mudler/LocalAI/pkg/grpc/proto"
    "github.com/mudler/LocalAI/pkg/utils"
)

type Whisper struct {
    base.SingleThread
    whisper whisper.Model
}

func (sd *Whisper) Load(opts *pb.ModelOptions) error {
    // Note: the Model here is a path to a directory containing the model files
    w, err := whisper.New(opts.ModelFile)
    sd.whisper = w
    return err
}

func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {

    dir, err := os.MkdirTemp("", "whisper")
    if err != nil {
        return pb.TranscriptResult{}, err
    }
    defer os.RemoveAll(dir)

    convertedPath := filepath.Join(dir, "converted.wav")

    if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
        return pb.TranscriptResult{}, err
    }

    // Open samples
    fh, err := os.Open(convertedPath)
    if err != nil {
        return pb.TranscriptResult{}, err
    }
    defer fh.Close()

    // Read samples
    d := wav.NewDecoder(fh)
    buf, err := d.FullPCMBuffer()
    if err != nil {
        return pb.TranscriptResult{}, err
    }

    data := buf.AsFloat32Buffer().Data

    // Process samples
    context, err := sd.whisper.NewContext()
    if err != nil {
        return pb.TranscriptResult{}, err

    }

    context.SetThreads(uint(opts.Threads))

    if opts.Language != "" {
        context.SetLanguage(opts.Language)
    } else {
        context.SetLanguage("auto")
    }

    if opts.Translate {
        context.SetTranslate(true)
    }

    if err := context.Process(data, nil, nil, nil); err != nil {
        return pb.TranscriptResult{}, err
    }

    segments := []*pb.TranscriptSegment{}
    text := ""
    for {
        s, err := context.NextSegment()
        if err != nil {
            break
        }

        var tokens []int32
        for _, t := range s.Tokens {
            tokens = append(tokens, int32(t.Id))
        }

        segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
        segments = append(segments, segment)

        text += s.Text
    }

    return pb.TranscriptResult{
        Segments: segments,
        Text:     text,
    }, nil

}
@@ -45,6 +45,7 @@
    default: "cpu-whisper"
    nvidia: "cuda12-whisper"
    intel: "intel-sycl-f16-whisper"
    metal: "metal-whisper"
    amd: "rocm-whisper"
    vulkan: "vulkan-whisper"
    nvidia-l4t: "nvidia-l4t-arm64-whisper"
@@ -71,7 +72,7 @@
    # amd: "rocm-stablediffusion-ggml"
    vulkan: "vulkan-stablediffusion-ggml"
    nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml"
    # metal: "metal-stablediffusion-ggml"
    metal: "metal-stablediffusion-ggml"
    # darwin-x86: "darwin-x86-stablediffusion-ggml"
- &rfdetr
  name: "rfdetr"
@@ -147,7 +148,7 @@
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm"
  icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
  urls:
    - https://github.com/ml-explore/mlx-vlm
    - https://github.com/Blaizzy/mlx-vlm
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm
  license: MIT
@@ -159,6 +160,23 @@
    - vision-language
    - LLM
    - MLX
- &mlx-audio
  name: "mlx-audio"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio"
  icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
  urls:
    - https://github.com/Blaizzy/mlx-audio
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-mlx-audio
  license: MIT
  description: |
    Run Audio Models with MLX
  tags:
    - audio-to-text
    - audio-generation
    - text-to-audio
    - LLM
    - MLX
- &rerankers
  name: "rerankers"
  alias: "rerankers"
@@ -183,8 +201,6 @@
    nvidia: "cuda12-transformers"
    intel: "intel-transformers"
    amd: "rocm-transformers"
    metal: "metal-transformers"
    default: "cpu-transformers"
- &diffusers
  name: "diffusers"
  icon: https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg
@@ -254,6 +270,7 @@
    nvidia: "cuda12-kokoro"
    intel: "intel-kokoro"
    amd: "rocm-kokoro"
    nvidia-l4t: "nvidia-l4t-kokoro"
- &coqui
  urls:
    - https://github.com/idiap/coqui-ai-TTS
@@ -334,6 +351,9 @@
  alias: "chatterbox"
  capabilities:
    nvidia: "cuda12-chatterbox"
    metal: "metal-chatterbox"
    default: "cpu-chatterbox"
    nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
- &piper
  name: "piper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
@@ -417,6 +437,11 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
- !!merge <<: *mlx-audio
  name: "mlx-audio-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-mlx-audio
- !!merge <<: *kitten-tts
  name: "kitten-tts-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
@@ -559,6 +584,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisper"
  mirrors:
    - localai/localai-backends:latest-cpu-whisper
- !!merge <<: *whispercpp
  name: "metal-whisper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper"
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
  name: "metal-whisper-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-whisper
- !!merge <<: *whispercpp
  name: "cpu-whisper-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisper"
@@ -645,6 +680,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-stablediffusion-ggml"
  mirrors:
    - localai/localai-backends:master-cpu-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
  name: "metal-stablediffusion-ggml"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-stablediffusion-ggml"
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
  name: "metal-stablediffusion-ggml-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-stablediffusion-ggml"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-stablediffusion-ggml
- !!merge <<: *stablediffusionggml
  name: "vulkan-stablediffusion-ggml"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-stablediffusion-ggml"
@@ -853,28 +898,6 @@
    nvidia: "cuda12-transformers-development"
    intel: "intel-transformers-development"
    amd: "rocm-transformers-development"
    default: "cpu-transformers-development"
    metal: "metal-transformers-development"
- !!merge <<: *transformers
  name: "cpu-transformers"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-transformers"
  mirrors:
    - localai/localai-backends:latest-cpu-transformers
- !!merge <<: *transformers
  name: "cpu-transformers-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-transformers"
  mirrors:
    - localai/localai-backends:master-cpu-transformers
- !!merge <<: *transformers
  name: "metal-transformers"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-transformers"
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-transformers
- !!merge <<: *transformers
  name: "metal-transformers-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-transformers"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-transformers
- !!merge <<: *transformers
  name: "cuda12-transformers"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-transformers"
@@ -1028,6 +1051,7 @@
    nvidia: "cuda12-kokoro-development"
    intel: "intel-kokoro-development"
    amd: "rocm-kokoro-development"
    nvidia-l4t: "nvidia-l4t-kokoro-development"
- !!merge <<: *kokoro
  name: "cuda11-kokoro-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-11-kokoro"
@@ -1053,6 +1077,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro"
  mirrors:
    - localai/localai-backends:master-gpu-intel-kokoro
- !!merge <<: *kokoro
  name: "nvidia-l4t-kokoro"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-kokoro"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-l4t-kokoro
- !!merge <<: *kokoro
  name: "nvidia-l4t-kokoro-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-kokoro"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-l4t-kokoro
- !!merge <<: *kokoro
  name: "cuda11-kokoro"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-11-kokoro"
@@ -1204,6 +1238,39 @@
  name: "chatterbox-development"
  capabilities:
    nvidia: "cuda12-chatterbox-development"
    metal: "metal-chatterbox-development"
    default: "cpu-chatterbox-development"
    nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
- !!merge <<: *chatterbox
  name: "cpu-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
  mirrors:
    - localai/localai-backends:latest-cpu-chatterbox
- !!merge <<: *chatterbox
  name: "cpu-chatterbox-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
  mirrors:
    - localai/localai-backends:master-cpu-chatterbox
- !!merge <<: *chatterbox
  name: "nvidia-l4t-arm64-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "nvidia-l4t-arm64-chatterbox-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-l4t-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "metal-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:latest-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "metal-chatterbox-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:master-metal-darwin-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "cuda12-chatterbox-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"
@@ -1,38 +1,190 @@
# Common commands about conda environment
# Python Backends for LocalAI

## Create a new empty conda environment
This directory contains Python-based AI backends for LocalAI, providing support for various AI models and hardware acceleration targets.

```
conda create --name <env-name> python=<your version> -y
## Overview

conda create --name autogptq python=3.11 -y
The Python backends use a unified build system based on `libbackend.sh` that provides:
- **Automatic virtual environment management** with support for both `uv` and `pip`
- **Hardware-specific dependency installation** (CPU, CUDA, Intel, MLX, etc.)
- **Portable Python support** for standalone deployments
- **Consistent backend execution** across different environments

## Available Backends

### Core AI Models
- **transformers** - Hugging Face Transformers framework (PyTorch-based)
- **vllm** - High-performance LLM inference engine
- **mlx** - Apple Silicon optimized ML framework
- **exllama2** - ExLlama2 quantized models

### Audio & Speech
- **bark** - Text-to-speech synthesis
- **coqui** - Coqui TTS models
- **faster-whisper** - Fast Whisper speech recognition
- **kitten-tts** - Lightweight TTS
- **mlx-audio** - Apple Silicon audio processing
- **chatterbox** - TTS model
- **kokoro** - TTS models

### Computer Vision
- **diffusers** - Stable Diffusion and image generation
- **mlx-vlm** - Vision-language models for Apple Silicon
- **rfdetr** - Object detection models

### Specialized

- **rerankers** - Text reranking models

## Quick Start

### Prerequisites
- Python 3.10+ (default: 3.10.18)
- `uv` package manager (recommended) or `pip`
- Appropriate hardware drivers for your target (CUDA, Intel, etc.)

### Installation

Each backend can be installed individually:

```bash
# Navigate to a specific backend
cd backend/python/transformers

# Install dependencies
make transformers
# or
bash install.sh

# Run the backend
make run
# or
bash run.sh
```

## To activate the environment
### Using the Unified Build System

As of conda 4.4
```
conda activate autogptq
The `libbackend.sh` script provides consistent commands across all backends:

```bash
# Source the library in your backend script
source $(dirname $0)/../common/libbackend.sh

# Install requirements (automatically handles hardware detection)
installRequirements

# Start the backend server
startBackend $@

# Run tests
runUnittests
```

The conda version older than 4.4
## Hardware Targets

```
source activate autogptq
The build system automatically detects and configures for different hardware:

- **CPU** - Standard CPU-only builds
- **CUDA** - NVIDIA GPU acceleration (supports CUDA 11/12)
- **Intel** - Intel XPU/GPU optimization
- **MLX** - Apple Silicon (M1/M2/M3) optimization
- **HIP** - AMD GPU acceleration

### Target-Specific Requirements

Backends can specify hardware-specific dependencies:
- `requirements.txt` - Base requirements
- `requirements-cpu.txt` - CPU-specific packages
- `requirements-cublas11.txt` - CUDA 11 packages
- `requirements-cublas12.txt` - CUDA 12 packages
- `requirements-intel.txt` - Intel-optimized packages
- `requirements-mps.txt` - Apple Silicon packages
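For illustration only, a minimal sketch of how an install step could map `BUILD_TYPE` onto one of these files; the real logic lives in `common/libbackend.sh`, so the variable names and fallback order shown here are assumptions, not the library's actual implementation:

```bash
#!/bin/bash
# Hypothetical sketch: pick a hardware-specific requirements file from
# BUILD_TYPE, falling back to the base requirements.txt when none matches.
case "${BUILD_TYPE:-}" in
  cublas11) REQS="requirements-cublas11.txt" ;;
  cublas12) REQS="requirements-cublas12.txt" ;;
  intel)    REQS="requirements-intel.txt" ;;
  mps)      REQS="requirements-mps.txt" ;;
  *)        REQS="requirements-cpu.txt" ;;
esac
[ -f "$REQS" ] || REQS="requirements.txt"
uv pip install -r "$REQS"
```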
## Configuration Options

### Environment Variables

- `PYTHON_VERSION` - Python version (default: 3.10)
- `PYTHON_PATCH` - Python patch version (default: 18)
- `BUILD_TYPE` - Force specific build target
- `USE_PIP` - Use pip instead of uv (default: false)
- `PORTABLE_PYTHON` - Enable portable Python builds
- `LIMIT_TARGETS` - Restrict backend to specific targets

### Example: CUDA 12 Only Backend

```bash
# In your backend script
LIMIT_TARGETS="cublas12"
source $(dirname $0)/../common/libbackend.sh
```

## Install the packages to your environment
### Example: Intel-Optimized Backend

Sometimes you need to install the packages from the conda-forge channel

By using `conda`
```
conda install <your-package-name>

conda install -c conda-forge <your package-name>
```bash
# In your backend script
LIMIT_TARGETS="intel"
source $(dirname $0)/../common/libbackend.sh
```

Or by using `pip`
## Development

### Adding a New Backend

1. Create a new directory in `backend/python/`
2. Copy the template structure from `common/template/`
3. Implement your `backend.py` with the required gRPC interface
4. Add appropriate requirements files for your target hardware
5. Use `libbackend.sh` for consistent build and execution

### Testing

```bash
# Run backend tests
make test
# or
bash test.sh
```
pip install <your-package-name>

### Building

```bash
# Install dependencies
make <backend-name>

# Clean build artifacts
make clean
```

## Architecture

Each backend follows a consistent structure:
```
backend-name/
├── backend.py             # Main backend implementation
├── requirements.txt       # Base dependencies
├── requirements-*.txt     # Hardware-specific dependencies
├── install.sh             # Installation script
├── run.sh                 # Execution script
├── test.sh                # Test script
├── Makefile               # Build targets
└── test.py                # Unit tests
```

## Troubleshooting

### Common Issues

1. **Missing dependencies**: Ensure all requirements files are properly configured
2. **Hardware detection**: Check that `BUILD_TYPE` matches your system
3. **Python version**: Verify Python 3.10+ is available
4. **Virtual environment**: Use `ensureVenv` to create/activate environments

## Contributing

When adding new backends or modifying existing ones:
1. Follow the established directory structure
2. Use `libbackend.sh` for consistent behavior
3. Include appropriate requirements files for all target hardware
4. Add comprehensive tests
5. Update this README if adding new backend types

@@ -1,4 +1,4 @@
 bark==0.1.5
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 certifi

@@ -14,9 +14,23 @@ import backend_pb2_grpc
 import torch
 import torchaudio as ta
 from chatterbox.tts import ChatterboxTTS
+from chatterbox.mtl_tts import ChatterboxMultilingualTTS
 import grpc
 
+def is_float(s):
+    """Check if a string can be converted to float."""
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+
+def is_int(s):
+    """Check if a string can be converted to int."""
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -47,6 +61,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if not torch.cuda.is_available() and request.CUDA:
             return backend_pb2.Result(success=False, message="CUDA is not available")
 
+        options = request.Options
+
+        # empty dict
+        self.options = {}
+
+        # The options are a list of strings in this form optname:optvalue
+        # We are storing all the options in a dict so we can use it later when
+        # generating the images
+        for opt in options:
+            if ":" not in opt:
+                continue
+            key, value = opt.split(":")
+            # if value is a number, convert it to the appropriate type
+            if is_float(value):
+                value = float(value)
+            elif is_int(value):
+                value = int(value)
+            elif value.lower() in ["true", "false"]:
+                value = value.lower() == "true"
+            self.options[key] = value
+
         self.AudioPath = None
 
         if os.path.isabs(request.AudioPath):

@@ -56,10 +92,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             modelFileBase = os.path.dirname(request.ModelFile)
             # modify LoraAdapter to be relative to modelFileBase
             self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
 
         try:
             print("Preparing models, please wait", file=sys.stderr)
-            self.model = ChatterboxTTS.from_pretrained(device=device)
+            if "multilingual" in self.options:
+                # remove key from options
+                del self.options["multilingual"]
+                self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
+            else:
+                self.model = ChatterboxTTS.from_pretrained(device=device)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service

@@ -68,12 +108,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
     def TTS(self, request, context):
         try:
-            # Generate audio using ChatterboxTTS
-            if self.AudioPath is not None:
-                wav = self.model.generate(request.text, audio_prompt_path=self.AudioPath)
-            else:
-                wav = self.model.generate(request.text)
+            kwargs = {}
+
+            if "language" in self.options:
+                kwargs["language_id"] = self.options["language"]
+            if self.AudioPath is not None:
+                kwargs["audio_prompt_path"] = self.AudioPath
+
+            # add options to kwargs
+            kwargs.update(self.options)
+
+            # Generate audio using ChatterboxTTS
+            wav = self.model.generate(request.text, **kwargs)
             # Save the generated audio
             ta.save(request.dst, wav, self.model.sr)

@@ -15,5 +15,6 @@
 if [ "x${BUILD_PROFILE}" == "xintel" ]; then
     EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
 fi
+EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
 
 installRequirements

@@ -1,5 +1,8 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
+accelerate
-torch==2.6.0
-torchaudio==2.6.0
-transformers==4.46.3
-chatterbox-tts
+torch
+torchaudio
+transformers
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+#chatterbox-tts==0.1.4

@@ -2,5 +2,6 @@
 torch==2.6.0+cu118
 torchaudio==2.6.0+cu118
 transformers==4.46.3
-chatterbox-tts
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate

@@ -1,5 +1,6 @@
-torch==2.6.0
-torchaudio==2.6.0
-transformers==4.46.3
-chatterbox-tts
+torch
+torchaudio
+transformers
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate

@@ -1,6 +1,7 @@
 --extra-index-url https://download.pytorch.org/whl/rocm6.0
 torch==2.6.0+rocm6.1
 torchaudio==2.6.0+rocm6.1
-transformers==4.46.3
-chatterbox-tts
+transformers
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate

@@ -2,10 +2,10 @@
 intel-extension-for-pytorch==2.3.110+xpu
 torch==2.3.1+cxx11.abi
 torchaudio==2.3.1+cxx11.abi
-transformers==4.46.3
-chatterbox-tts
+transformers
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
 oneccl_bind_pt==2.3.100+xpu
 optimum[openvino]
 setuptools

backend/python/chatterbox/requirements-l4t.txt (new file)
@@ -0,0 +1,6 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
torch
torchaudio
transformers
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

@@ -286,7 +286,8 @@ _makeVenvPortable() {
 function ensureVenv() {
     local interpreter=""
 
-    if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+    if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -e "$(_portable_python)" ]; then
         echo "Using portable Python"
         ensurePortablePython
         interpreter="$(_portable_python)"
     else

@@ -384,6 +385,11 @@ function installRequirements() {
         requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
     fi
 
+    # This is needed to build wheels that e.g. depends on Python.h
+    if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+        export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
+    fi
+
     for reqFile in ${requirementFiles[@]}; do
         if [ -f "${reqFile}" ]; then
             echo "starting requirements install for ${reqFile}"

@@ -1,3 +1,3 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 grpcio-tools

@@ -1,4 +1,4 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 certifi
 packaging==24.1

@@ -18,7 +18,7 @@ import backend_pb2_grpc
 import grpc
 
 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video

@@ -66,19 +66,21 @@ from diffusers.schedulers import (
 )
 
 def is_float(s):
+    """Check if a string can be converted to float."""
     try:
         float(s)
         return True
     except ValueError:
         return False
 
 def is_int(s):
+    """Check if a string can be converted to int."""
     try:
         int(s)
         return True
     except ValueError:
         return False
 
 
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111

@@ -187,6 +189,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 value = float(value)
             elif is_int(value):
                 value = int(value)
+            elif value.lower() in ["true", "false"]:
+                value = value.lower() == "true"
             self.options[key] = value
 
         # From options, extract if present "torch_dtype" and set it to the appropriate type

@@ -334,6 +338,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 torch_dtype=torch.bfloat16)
             self.pipe.vae.to(torch.bfloat16)
             self.pipe.text_encoder.to(torch.bfloat16)
+        elif request.PipelineType == "WanPipeline":
+            # WAN2.2 pipeline requires special VAE handling
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
+        elif request.PipelineType == "WanImageToVideoPipeline":
+            # WAN2.2 image-to-video pipeline
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanImageToVideoPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.img2vid = True  # WAN2.2 image-to-video pipeline
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip

@@ -475,11 +505,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             "num_inference_steps": steps,
         }
 
-        if request.src != "" and not self.controlnet and not self.img2vid:
-            image = Image.open(request.src)
+        # Handle image source: prioritize RefImages over request.src
+        image_src = None
+        if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
+            # Use the first reference image if available
+            image_src = request.ref_images[0]
+            print(f"Using reference image: {image_src}", file=sys.stderr)
+        elif request.src != "":
+            # Fall back to request.src if no ref_images
+            image_src = request.src
+            print(f"Using source image: {image_src}", file=sys.stderr)
+        else:
+            print("No image source provided", file=sys.stderr)
+
+        if image_src and not self.controlnet and not self.img2vid:
+            image = Image.open(image_src)
             options["image"] = image
-        elif self.controlnet and request.src:
-            pose_image = load_image(request.src)
+        elif self.controlnet and image_src:
+            pose_image = load_image(image_src)
             options["image"] = pose_image
 
         if CLIPSKIP and self.clip_skip != 0:

@@ -521,7 +564,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         if self.img2vid:
             # Load the conditioning image
-            image = load_image(request.src)
+            if image_src:
+                image = load_image(image_src)
+            else:
+                # Fallback to request.src for img2vid if no ref_images
+                image = load_image(request.src)
             image = image.resize((1024, 576))
 
             generator = torch.manual_seed(request.seed)

@@ -558,6 +605,96 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         return backend_pb2.Result(message="Media generated", success=True)
 
+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]  # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
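
As a worked example of the WAN2.2 resize logic above (the `mod_value` of 16 here is an assumption; the real value comes from the loaded pipeline's VAE scale factor and patch size): with the defaults `480x832` the target area is 399,360 pixels. For a `1280x720` start image, `aspect_ratio = 720/1280 = 0.5625`, so `height = round(sqrt(399360 * 0.5625) / 16) * 16 = 480` and `width = round(sqrt(399360 / 0.5625) / 16) * 16 = 848`, which preserves the source aspect ratio while keeping both sides divisible by `mod_value`.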

@@ -8,4 +8,5 @@ compel
 peft
 sentencepiece
 torch==2.7.1
 optimum-quanto
+ftfy

@@ -1,11 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.7.1+cu118
-torchvision==0.22.1+cu118
+torch==2.7.1
+torchvision==0.22.1
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
 accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy

@@ -1,10 +1,12 @@
-torch==2.7.1
-torchvision==0.22.1
+--extra-index-url https://download.pytorch.org/whl/cu121
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision
 accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+torch
+ftfy

@@ -8,4 +8,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy

@@ -12,4 +12,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy

@@ -8,4 +8,5 @@ peft
 optimum-quanto
 numpy<2
 sentencepiece
 torchvision
+ftfy

@@ -7,4 +7,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy

@@ -1,5 +1,5 @@
 setuptools
-grpcio==1.74.0
+grpcio==1.75.1
 pillow
 protobuf
 certifi

@@ -1,4 +1,4 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 certifi
 wheel

backend/python/kokoro/requirements-l4t.txt (new file)
@@ -0,0 +1,7 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
torch
torchaudio
transformers
accelerate
kokoro
soundfile

backend/python/mlx-audio/Makefile (new file)
@@ -0,0 +1,23 @@
.PHONY: mlx-audio
mlx-audio:
	bash install.sh

.PHONY: run
run: mlx-audio
	@echo "Running mlx-audio..."
	bash run.sh
	@echo "mlx run."

.PHONY: test
test: mlx-audio
	@echo "Testing mlx-audio..."
	bash test.sh
	@echo "mlx tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
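
The Makefile above can be driven from the repository root; the `-C` invocation below is an assumption about usage rather than a documented entry point:

```bash
# Install the mlx-audio backend, then run its unit tests
make -C backend/python/mlx-audio mlx-audio
make -C backend/python/mlx-audio test
```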

backend/python/mlx-audio/backend.py (new file)
@@ -0,0 +1,465 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
import shutil
import glob
from typing import List
import time
import tempfile

import backend_pb2
import backend_pb2_grpc

import grpc
from mlx_audio.tts.utils import load_model
import soundfile as sf
import numpy as np
import uuid

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False
def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
    """

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a TTS model using MLX-Audio.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the kokoro backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types
                if is_float(value):
                    value = float(value)
                elif is_int(value):
                    value = int(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Load the model using MLX-Audio's load_model function
            try:
                self.tts_model = load_model(request.Model)
                self.model_path = request.Model
                print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
            except Exception as model_err:
                print(f"Error loading TTS model: {model_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")

        except Exception as err:
            print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")

        print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)

    def TTS(self, request, context):
        """
        Generates TTS audio from text using MLX-Audio.

        Args:
            request: A TTSRequest object containing text, model, destination, voice, and language.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object indicating success or failure.
        """
        try:
            # Check if model is loaded
            if not hasattr(self, 'tts_model') or self.tts_model is None:
                return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")

            print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)

            # Handle speed parameter based on model type
            speed_value = self._handle_speed_parameter(request, self.model_path)

            # Map language names to codes if needed
            lang_code = self._map_language_code(request.language, request.voice)

            # Prepare generation parameters
            gen_params = {
                "text": request.text,
                "speed": speed_value,
                "verbose": False,
            }

            # Add model-specific parameters
            if request.voice and request.voice.strip():
                gen_params["voice"] = request.voice

            # Check if model supports language codes (primarily Kokoro)
            if "kokoro" in self.model_path.lower():
                gen_params["lang_code"] = lang_code

            # Add pitch and gender for Spark models
            if "spark" in self.model_path.lower():
                gen_params["pitch"] = 1.0  # Default to moderate
                gen_params["gender"] = "female"  # Default to female

            print(f"Generation parameters: {gen_params}", file=sys.stderr)

            # Generate audio using the loaded model
            try:
                results = self.tts_model.generate(**gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")

            # Process the generated audio segments
            audio_arrays = []
            for segment in results:
                audio_arrays.append(segment.audio)

            # If no segments, return error
            if not audio_arrays:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Result(success=False, message="No audio generated")

            # Concatenate all segments
            cat_audio = np.concatenate(audio_arrays, axis=0)

            # Generate output filename and path
            if request.dst:
                output_path = request.dst
            else:
                unique_id = str(uuid.uuid4())
                filename = f"tts_{unique_id}.wav"
                output_path = filename

            # Write the audio as a WAV
            try:
                sf.write(output_path, cat_audio, 24000)
                print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)

                # Verify the file exists and has content
                if not os.path.exists(output_path):
                    print(f"File was not created at {output_path}", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Failed to create audio file")

                file_size = os.path.getsize(output_path)
                if file_size == 0:
                    print("File was created but is empty", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Generated audio file is empty")

                print(f"Audio file size: {file_size} bytes", file=sys.stderr)

            except Exception as write_err:
                print(f"Error writing audio file: {write_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")

            return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")

        except Exception as e:
            print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")

    async def Predict(self, request, context):
        """
        Generates TTS audio based on the given prompt using MLX-Audio TTS.
        This is a fallback method for compatibility with the Predict endpoint.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        try:
            # Check if model is loaded
            if not hasattr(self, 'tts_model') or self.tts_model is None:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                context.set_details("TTS model not loaded. Please call LoadModel first.")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # For TTS, we expect the prompt to contain the text to synthesize
            if not request.Prompt:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("Prompt is required for TTS generation")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # Handle speed parameter based on model type
            speed_value = self._handle_speed_parameter(request, self.model_path)

            # Map language names to codes if needed
            lang_code = self._map_language_code(None, None)  # Use defaults for Predict

            # Prepare generation parameters
            gen_params = {
                "text": request.Prompt,
                "speed": speed_value,
                "verbose": False,
            }

            # Add model-specific parameters
            if hasattr(self, 'options') and 'voice' in self.options:
                gen_params["voice"] = self.options['voice']

            # Check if model supports language codes (primarily Kokoro)
            if "kokoro" in self.model_path.lower():
                gen_params["lang_code"] = lang_code

            print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)

            # Generate audio using the loaded model
            try:
                results = self.tts_model.generate(**gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details(f"TTS generation failed: {gen_err}")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # Process the generated audio segments
            audio_arrays = []
            for segment in results:
                audio_arrays.append(segment.audio)

            # If no segments, return error
            if not audio_arrays:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))

            # Concatenate all segments
            cat_audio = np.concatenate(audio_arrays, axis=0)
            duration = len(cat_audio) / 24000  # Assuming 24kHz sample rate

            # Return success message with audio information
            response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"TTS generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def _handle_speed_parameter(self, request, model_path):
        """
        Handle speed parameter based on model type.

        Args:
            request: The TTSRequest object.
            model_path: The model path to determine model type.

        Returns:
            float: The processed speed value.
        """
        # Get speed from options if available
        speed = 1.0
        if hasattr(self, 'options') and 'speed' in self.options:
            speed = self.options['speed']

        # Handle speed parameter based on model type
        if "spark" in model_path.lower():
            # Spark actually expects float values that map to speed descriptions
            speed_map = {
                "very_low": 0.0,
                "low": 0.5,
                "moderate": 1.0,
                "high": 1.5,
                "very_high": 2.0,
            }
            if isinstance(speed, str) and speed in speed_map:
                speed_value = speed_map[speed]
            else:
                # Try to use as float, default to 1.0 (moderate) if invalid
                try:
                    speed_value = float(speed)
                    if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
                        speed_value = 1.0  # Default to moderate
                except:
                    speed_value = 1.0  # Default to moderate
        else:
            # Other models use float speed values
            try:
                speed_value = float(speed)
                if speed_value < 0.5 or speed_value > 2.0:
                    speed_value = 1.0  # Default to 1.0 if out of range
            except ValueError:
                speed_value = 1.0  # Default to 1.0 if invalid

        return speed_value

    def _map_language_code(self, language, voice):
        """
        Map language names to codes if needed.

        Args:
            language: The language parameter from the request.
            voice: The voice parameter from the request.

        Returns:
            str: The language code.
        """
        if not language:
            # Default to voice[0] if not found
            return voice[0] if voice else "a"

        # Map language names to codes if needed
        language_map = {
            "american_english": "a",
            "british_english": "b",
            "spanish": "e",
            "french": "f",
            "hindi": "h",
            "italian": "i",
            "portuguese": "p",
            "japanese": "j",
            "mandarin_chinese": "z",
            # Also accept direct language codes
            "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
        }

        return language_map.get(language.lower(), language)

    def _build_generation_params(self, request, default_speed=1.0):
        """
        Build generation parameters from request attributes and options for MLX-Audio TTS.

        Args:
            request: The gRPC request.
            default_speed: Default speed if not specified.

        Returns:
            dict: Generation parameters for MLX-Audio
        """
        # Initialize generation parameters for MLX-Audio TTS
        generation_params = {
            'speed': default_speed,
            'voice': 'af_heart',  # Default voice
            'lang_code': 'a',  # Default language code
        }

        # Extract parameters from request attributes
        if hasattr(request, 'Temperature') and request.Temperature > 0:
            # Temperature could be mapped to speed variation
            generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5

        # Override with options if available
        if hasattr(self, 'options'):
            # Speed from options
            if 'speed' in self.options:
                generation_params['speed'] = self.options['speed']

            # Voice from options
            if 'voice' in self.options:
                generation_params['voice'] = self.options['voice']

            # Language code from options
            if 'lang_code' in self.options:
                generation_params['lang_code'] = self.options['lang_code']

            # Model-specific parameters
            param_option_mapping = {
                'temp': 'speed',
                'temperature': 'speed',
                'top_p': 'speed',  # Map top_p to speed variation
            }

            for option_key, param_key in param_option_mapping.items():
                if option_key in self.options:
                    if param_key == 'speed':
                        # Ensure speed is within reasonable bounds
                        speed_val = float(self.options[option_key])
                        if 0.5 <= speed_val <= 2.0:
                            generation_params[param_key] = speed_val

        return generation_params

async def serve(address):
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))

backend/python/mlx-audio/install.sh (new executable file)
@@ -0,0 +1,14 @@
#!/bin/bash
set -e

USE_PIP=true

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

installRequirements

backend/python/mlx-audio/requirements-mps.txt (new file)
@@ -0,0 +1 @@
git+https://github.com/Blaizzy/mlx-audio

backend/python/mlx-audio/requirements.txt (new file)
@@ -0,0 +1,7 @@
grpcio==1.71.0
protobuf
certifi
setuptools
mlx-audio
soundfile
numpy

backend/python/mlx-audio/run.sh (new executable file)
@@ -0,0 +1,11 @@
#!/bin/bash

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@

backend/python/mlx-audio/test.py (new file)
@@ -0,0 +1,142 @@
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc

import unittest
import subprocess
import time
import grpc
import backend_pb2_grpc
import backend_pb2

class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    This class contains methods to test the startup and shutdown of the gRPC service.
    """
    def setUp(self):
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()
    def test_load_model(self):
        """
        This method tests if the TTS model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts_generation(self):
        """
        This method tests if TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
                self.assertTrue(response.success)

                # Test TTS generation
                tts_req = backend_pb2.TTSRequest(
                    text="Hello, this is a test of the MLX-Audio TTS system.",
                    model="mlx-community/Kokoro-82M-4bit",
                    voice="af_heart",
                    language="a"
                )
                tts_resp = stub.TTS(tts_req)
                self.assertTrue(tts_resp.success)
                self.assertIn("TTS audio generated successfully", tts_resp.message)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()

    def test_tts_with_options(self):
        """
        This method tests if TTS works with various options and parameters
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model="mlx-community/Kokoro-82M-4bit",
                    Options=["voice:af_soft", "speed:1.2", "lang_code:b"]
                ))
                self.assertTrue(response.success)

                # Test TTS generation with different voice and language
                tts_req = backend_pb2.TTSRequest(
                    text="Hello, this is a test with British English accent.",
                    model="mlx-community/Kokoro-82M-4bit",
                    voice="af_soft",
                    language="b"
                )
                tts_resp = stub.TTS(tts_req)
                self.assertTrue(tts_resp.success)
                self.assertIn("TTS audio generated successfully", tts_resp.message)
        except Exception as err:
            print(err)
            self.fail("TTS with options service failed")
        finally:
            self.tearDown()


    def test_tts_multilingual(self):
        """
        This method tests if TTS works with different languages
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
                self.assertTrue(response.success)

                # Test Spanish TTS
                tts_req = backend_pb2.TTSRequest(
                    text="Hola, esto es una prueba del sistema TTS MLX-Audio.",
                    model="mlx-community/Kokoro-82M-4bit",
                    voice="af_heart",
                    language="e"
                )
                tts_resp = stub.TTS(tts_req)
                self.assertTrue(tts_resp.success)
                self.assertIn("TTS audio generated successfully", tts_resp.message)
        except Exception as err:
            print(err)
            self.fail("Multilingual TTS service failed")
        finally:
            self.tearDown()

backend/python/mlx-audio/test.sh (new executable file)
@@ -0,0 +1,12 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests

@@ -21,6 +21,21 @@ import io
 from PIL import Image
 import tempfile
 
+def is_float(s):
+    """Check if a string can be converted to float."""
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+def is_int(s):
+    """Check if a string can be converted to int."""
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
+
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1

@@ -32,22 +47,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
     A gRPC servicer that implements the Backend service defined in backend.proto.
     """
 
-    def _is_float(self, s):
-        """Check if a string can be converted to float."""
-        try:
-            float(s)
-            return True
-        except ValueError:
-            return False
-
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.

@@ -87,10 +86,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 continue
             key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons
 
             # Convert numeric values to appropriate types
-            if self._is_float(value):
+            if is_float(value):
                 value = float(value)
-            elif self._is_int(value):
+            elif is_int(value):
                 value = int(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"

@@ -24,28 +24,27 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
 
+def is_float(s):
+    """Check if a string can be converted to float."""
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+def is_int(s):
+    """Check if a string can be converted to int."""
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
+
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     """
     A gRPC servicer that implements the Backend service defined in backend.proto.
     """
 
-    def _is_float(self, s):
-        """Check if a string can be converted to float."""
-        try:
-            float(s)
-            return True
-        except ValueError:
-            return False
-
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.

@@ -86,9 +85,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons
 
             # Convert numeric values to appropriate types
-            if self._is_float(value):
+            if is_float(value):
                 value = float(value)
-            elif self._is_int(value):
+            elif is_int(value):
                 value = int(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"

@@ -1,3 +1,3 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 certifi

@@ -1,4 +1,3 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.7.1
 llvmlite==0.43.0
 numba==0.60.0

@@ -1,9 +0,0 @@
-torch==2.7.1
-accelerate
-llvmlite==0.43.0
-numba==0.60.0
-transformers
-bitsandbytes
-outetts
-sentence-transformers==5.1.0
-protobuf==6.32.0

@@ -1,4 +1,4 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf==6.32.0
 certifi
 setuptools

@@ -1,4 +1,4 @@
-grpcio==1.74.0
+grpcio==1.75.1
 protobuf
 certifi
 setuptools

cmd/launcher/icon.go (new file)
@@ -0,0 +1,16 @@
package main

import (
	_ "embed"

	"fyne.io/fyne/v2"
)

//go:embed logo.png
var logoData []byte

// resourceIconPng is the LocalAI logo icon
var resourceIconPng = &fyne.StaticResource{
	StaticName:    "logo.png",
	StaticContent: logoData,
}
866
cmd/launcher/internal/launcher.go
Normal file
866
cmd/launcher/internal/launcher.go
Normal file
@@ -0,0 +1,866 @@
|
||||
package launcher
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
)
|
||||
|
||||
// Config represents the launcher configuration
|
||||
type Config struct {
|
||||
ModelsPath string `json:"models_path"`
|
||||
BackendsPath string `json:"backends_path"`
|
||||
Address string `json:"address"`
|
||||
AutoStart bool `json:"auto_start"`
|
||||
StartOnBoot bool `json:"start_on_boot"`
|
||||
LogLevel string `json:"log_level"`
|
||||
EnvironmentVars map[string]string `json:"environment_vars"`
|
||||
ShowWelcome *bool `json:"show_welcome"`
|
||||
}
|
||||
|
||||
// Launcher represents the main launcher application
|
||||
type Launcher struct {
|
||||
// Core components
|
||||
releaseManager *ReleaseManager
|
||||
config *Config
|
||||
ui *LauncherUI
|
||||
systray *SystrayManager
|
||||
ctx context.Context
|
||||
window fyne.Window
|
||||
app fyne.App
|
||||
|
||||
// Process management
|
||||
localaiCmd *exec.Cmd
|
||||
isRunning bool
|
||||
logBuffer *strings.Builder
|
||||
logMutex sync.RWMutex
|
||||
statusChannel chan string
|
||||
|
||||
// Logging
|
||||
logFile *os.File
|
||||
logPath string
|
||||
|
||||
// UI state
|
||||
lastUpdateCheck time.Time
|
||||
}
|
||||
|
||||
// NewLauncher creates a new launcher instance
|
||||
func NewLauncher(ui *LauncherUI, window fyne.Window, app fyne.App) *Launcher {
|
||||
return &Launcher{
|
||||
releaseManager: NewReleaseManager(),
|
||||
config: &Config{},
|
||||
logBuffer: &strings.Builder{},
|
||||
statusChannel: make(chan string, 100),
|
||||
ctx: context.Background(),
|
||||
ui: ui,
|
||||
window: window,
|
||||
app: app,
|
||||
}
|
||||
}
|
||||
|
||||
// setupLogging sets up log file for LocalAI process output
|
||||
func (l *Launcher) setupLogging() error {
|
||||
// Create logs directory in data folder
|
||||
dataPath := l.GetDataPath()
|
||||
logsDir := filepath.Join(dataPath, "logs")
|
||||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Create log file with timestamp
|
||||
timestamp := time.Now().Format("2006-01-02_15-04-05")
|
||||
l.logPath = filepath.Join(logsDir, fmt.Sprintf("localai_%s.log", timestamp))
|
||||
|
||||
logFile, err := os.Create(l.logPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
|
||||
l.logFile = logFile
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize sets up the launcher
|
||||
func (l *Launcher) Initialize() error {
|
||||
if l.app == nil {
|
||||
return fmt.Errorf("app is nil")
|
||||
}
|
||||
log.Printf("Initializing launcher...")
|
||||
|
||||
// Setup logging
|
||||
if err := l.setupLogging(); err != nil {
|
||||
return fmt.Errorf("failed to setup logging: %w", err)
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
log.Printf("Loading configuration...")
|
||||
if err := l.loadConfig(); err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
log.Printf("Configuration loaded, current state: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
|
||||
l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
|
||||
|
||||
// Clean up any partial downloads
|
||||
log.Printf("Cleaning up partial downloads...")
|
||||
if err := l.releaseManager.CleanupPartialDownloads(); err != nil {
|
||||
log.Printf("Warning: failed to cleanup partial downloads: %v", err)
|
||||
}
|
||||
|
||||
if l.config.StartOnBoot {
|
||||
l.StartLocalAI()
|
||||
}
|
||||
// Set default paths if not configured (only if not already loaded from config)
|
||||
if l.config.ModelsPath == "" {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
l.config.ModelsPath = filepath.Join(homeDir, ".localai", "models")
|
||||
log.Printf("Setting default ModelsPath: %s", l.config.ModelsPath)
|
||||
}
|
||||
if l.config.BackendsPath == "" {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
l.config.BackendsPath = filepath.Join(homeDir, ".localai", "backends")
|
||||
log.Printf("Setting default BackendsPath: %s", l.config.BackendsPath)
|
||||
}
|
||||
if l.config.Address == "" {
|
||||
l.config.Address = "127.0.0.1:8080"
|
||||
log.Printf("Setting default Address: %s", l.config.Address)
|
||||
}
|
||||
if l.config.LogLevel == "" {
|
||||
l.config.LogLevel = "info"
|
||||
log.Printf("Setting default LogLevel: %s", l.config.LogLevel)
|
||||
}
|
||||
if l.config.EnvironmentVars == nil {
|
||||
l.config.EnvironmentVars = make(map[string]string)
|
||||
log.Printf("Initializing empty EnvironmentVars map")
|
||||
}
|
||||
|
||||
// Set default welcome window preference
|
||||
if l.config.ShowWelcome == nil {
|
||||
true := true
|
||||
l.config.ShowWelcome = &true
|
||||
log.Printf("Setting default ShowWelcome: true")
|
||||
}
|
||||
|
||||
// Create directories
|
||||
os.MkdirAll(l.config.ModelsPath, 0755)
|
||||
os.MkdirAll(l.config.BackendsPath, 0755)
|
||||
|
||||
// Save the configuration with default values
|
||||
if err := l.saveConfig(); err != nil {
|
||||
log.Printf("Warning: failed to save default configuration: %v", err)
|
||||
}
|
||||
|
||||
// System tray is now handled in main.go using Fyne's built-in approach
|
||||
|
||||
// Check if LocalAI is installed
|
||||
	if !l.releaseManager.IsLocalAIInstalled() {
		log.Printf("No LocalAI installation found")
		fyne.Do(func() {
			l.updateStatus("No LocalAI installation found")
			if l.ui != nil {
				// Show dialog offering to download LocalAI
				l.showDownloadLocalAIDialog()
			}
		})
	}

	// Check for updates periodically
	go l.periodicUpdateCheck()

	return nil
}

// StartLocalAI starts the LocalAI server
func (l *Launcher) StartLocalAI() error {
	if l.isRunning {
		return fmt.Errorf("LocalAI is already running")
	}

	// Verify binary integrity before starting
	if err := l.releaseManager.VerifyInstalledBinary(); err != nil {
		// Binary is corrupted, remove it and offer to reinstall
		binaryPath := l.releaseManager.GetBinaryPath()
		if removeErr := os.Remove(binaryPath); removeErr != nil {
			log.Printf("Failed to remove corrupted binary: %v", removeErr)
		}
		return fmt.Errorf("LocalAI binary is corrupted: %v. Please reinstall LocalAI", err)
	}

	binaryPath := l.releaseManager.GetBinaryPath()
	if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
		return fmt.Errorf("LocalAI binary not found. Please download a release first")
	}

	// Build command arguments
	args := []string{
		"run",
		"--models-path", l.config.ModelsPath,
		"--backends-path", l.config.BackendsPath,
		"--address", l.config.Address,
		"--log-level", l.config.LogLevel,
	}

	l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...)

	// Apply environment variables
	if len(l.config.EnvironmentVars) > 0 {
		env := os.Environ()
		for key, value := range l.config.EnvironmentVars {
			env = append(env, fmt.Sprintf("%s=%s", key, value))
		}
		l.localaiCmd.Env = env
	}

	// Setup logging
	stdout, err := l.localaiCmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create stdout pipe: %w", err)
	}

	stderr, err := l.localaiCmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	// Start the process
	if err := l.localaiCmd.Start(); err != nil {
		return fmt.Errorf("failed to start LocalAI: %w", err)
	}

	l.isRunning = true

	fyne.Do(func() {
		l.updateStatus("LocalAI is starting...")
		l.updateRunningState(true)
	})

	// Start log monitoring
	go l.monitorLogs(stdout, "STDOUT")
	go l.monitorLogs(stderr, "STDERR")

	// Monitor process with startup timeout
	go func() {
		// Wait blocks until the process exits
		err := l.localaiCmd.Wait()
		l.isRunning = false
		fyne.Do(func() {
			l.updateRunningState(false)
			if err != nil {
				l.updateStatus(fmt.Sprintf("LocalAI stopped with error: %v", err))
			} else {
				l.updateStatus("LocalAI stopped")
			}
		})
	}()

	// Add startup timeout detection
	go func() {
		time.Sleep(10 * time.Second) // Wait 10 seconds for startup
		if l.isRunning {
			// Check if process is still alive
			if l.localaiCmd.Process != nil {
				if err := l.localaiCmd.Process.Signal(syscall.Signal(0)); err != nil {
					// Process is dead, mark as not running
					l.isRunning = false
					fyne.Do(func() {
						l.updateRunningState(false)
						l.updateStatus("LocalAI failed to start properly")
					})
				}
			}
		}
	}()

	return nil
}

// StopLocalAI stops the LocalAI server
func (l *Launcher) StopLocalAI() error {
	if !l.isRunning || l.localaiCmd == nil {
		return fmt.Errorf("LocalAI is not running")
	}

	// Gracefully terminate the process
	if err := l.localaiCmd.Process.Signal(os.Interrupt); err != nil {
		// If graceful termination fails, force kill
		if killErr := l.localaiCmd.Process.Kill(); killErr != nil {
			return fmt.Errorf("failed to kill LocalAI process: %w", killErr)
		}
	}

	l.isRunning = false
	fyne.Do(func() {
		l.updateRunningState(false)
		l.updateStatus("LocalAI stopped")
	})
	return nil
}

// IsRunning returns whether LocalAI is currently running
func (l *Launcher) IsRunning() bool {
	return l.isRunning
}

// Shutdown performs cleanup when the application is closing
func (l *Launcher) Shutdown() error {
	log.Printf("Launcher shutting down, stopping LocalAI...")

	// Stop LocalAI if it's running
	if l.isRunning {
		if err := l.StopLocalAI(); err != nil {
			log.Printf("Error stopping LocalAI during shutdown: %v", err)
		}
	}

	// Close log file if open
	if l.logFile != nil {
		if err := l.logFile.Close(); err != nil {
			log.Printf("Error closing log file: %v", err)
		}
		l.logFile = nil
	}

	log.Printf("Launcher shutdown complete")
	return nil
}

// GetLogs returns the current log buffer
func (l *Launcher) GetLogs() string {
	l.logMutex.RLock()
	defer l.logMutex.RUnlock()
	return l.logBuffer.String()
}

// GetRecentLogs returns the most recent logs (last 50 lines) for better error display
func (l *Launcher) GetRecentLogs() string {
	l.logMutex.RLock()
	defer l.logMutex.RUnlock()

	content := l.logBuffer.String()
	lines := strings.Split(content, "\n")

	// Get last 50 lines
	if len(lines) > 50 {
		lines = lines[len(lines)-50:]
	}

	return strings.Join(lines, "\n")
}

// GetConfig returns the current configuration
func (l *Launcher) GetConfig() *Config {
	return l.config
}

// SetConfig updates the configuration
func (l *Launcher) SetConfig(config *Config) error {
	l.config = config
	return l.saveConfig()
}

// GetUI returns the launcher UI
func (l *Launcher) GetUI() *LauncherUI {
	return l.ui
}

// SetSystray attaches the systray manager to the launcher
func (l *Launcher) SetSystray(systray *SystrayManager) {
	l.systray = systray
}

// GetReleaseManager returns the release manager
func (l *Launcher) GetReleaseManager() *ReleaseManager {
	return l.releaseManager
}

// GetWebUIURL returns the URL for the WebUI
func (l *Launcher) GetWebUIURL() string {
	address := l.config.Address
	if strings.HasPrefix(address, ":") {
		address = "localhost" + address
	}
	if !strings.HasPrefix(address, "http") {
		address = "http://" + address
	}
	return address
}
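
// GetWebUIURL example mappings (a sketch derived from the launcher tests
// further below):
//
//	":8080"                 -> "http://localhost:8080"
//	"127.0.0.1:8080"        -> "http://127.0.0.1:8080"
//	"http://localhost:8080" -> "http://localhost:8080" (already a full URL)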

// GetDataPath returns the path where LocalAI data and logs are stored
func (l *Launcher) GetDataPath() string {
	// LocalAI typically stores data in the current working directory or a models directory
	// First check if models path is configured
	if l.config != nil && l.config.ModelsPath != "" {
		// Return the parent directory of models path
		return filepath.Dir(l.config.ModelsPath)
	}

	// Fallback to home directory LocalAI folder
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return "."
	}
	return filepath.Join(homeDir, ".localai")
}
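
// For example, with ModelsPath set to "/home/user/.localai/models",
// GetDataPath returns "/home/user/.localai" (the home directory shown is
// hypothetical; filepath.Dir simply strips the last path element).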

// CheckForUpdates checks if there are any available updates
func (l *Launcher) CheckForUpdates() (bool, string, error) {
	log.Printf("CheckForUpdates: checking for available updates...")
	available, version, err := l.releaseManager.IsUpdateAvailable()
	if err != nil {
		log.Printf("CheckForUpdates: error occurred: %v", err)
		return false, "", err
	}
	log.Printf("CheckForUpdates: result - available=%v, version=%s", available, version)
	l.lastUpdateCheck = time.Now()
	return available, version, nil
}

// DownloadUpdate downloads the latest version
func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
	return l.releaseManager.DownloadRelease(version, progressCallback)
}

// GetCurrentVersion returns the currently installed version
func (l *Launcher) GetCurrentVersion() string {
	return l.releaseManager.GetInstalledVersion()
}

// GetCurrentStatus returns the current status
func (l *Launcher) GetCurrentStatus() string {
	select {
	case status := <-l.statusChannel:
		return status
	default:
		if l.isRunning {
			return "LocalAI is running"
		}
		return "Ready"
	}
}

// GetLastStatus returns the last known status without consuming from the channel
func (l *Launcher) GetLastStatus() string {
	if l.isRunning {
		return "LocalAI is running"
	}

	// Check if LocalAI is installed
	if !l.releaseManager.IsLocalAIInstalled() {
		return "LocalAI not installed"
	}

	return "Ready"
}

// githubReleaseNotesURL builds the GitHub release notes URL for the given version
func (l *Launcher) githubReleaseNotesURL(version string) (*url.URL, error) {
	// Construct GitHub release URL
	releaseURL := fmt.Sprintf("https://github.com/%s/%s/releases/tag/%s",
		l.releaseManager.GitHubOwner,
		l.releaseManager.GitHubRepo,
		version)

	// Convert string to *url.URL
	return url.Parse(releaseURL)
}
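
// With the default "mudler"/"LocalAI" repository settings (see
// NewReleaseManager in release_manager.go below), githubReleaseNotesURL("v3.4.0")
// yields https://github.com/mudler/LocalAI/releases/tag/v3.4.0.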

// showDownloadLocalAIDialog shows a dialog offering to download LocalAI
func (l *Launcher) showDownloadLocalAIDialog() {
	if l.app == nil {
		log.Printf("Cannot show download dialog: app is nil")
		return
	}

	fyne.DoAndWait(func() {
		// Create a standalone window for the download dialog
		dialogWindow := l.app.NewWindow("LocalAI Installation Required")
		dialogWindow.Resize(fyne.NewSize(500, 350))
		dialogWindow.CenterOnScreen()
		dialogWindow.SetCloseIntercept(func() {
			dialogWindow.Close()
		})

		// Create the dialog content
		titleLabel := widget.NewLabel("LocalAI Not Found")
		titleLabel.TextStyle = fyne.TextStyle{Bold: true}
		titleLabel.Alignment = fyne.TextAlignCenter

		messageLabel := widget.NewLabel("LocalAI is not installed on your system.\n\nWould you like to download and install the latest version?")
		messageLabel.Wrapping = fyne.TextWrapWord
		messageLabel.Alignment = fyne.TextAlignCenter

		// Buttons
		downloadButton := widget.NewButton("Download & Install", func() {
			dialogWindow.Close()
			l.downloadAndInstallLocalAI()
			if l.systray != nil {
				l.systray.recreateMenu()
			}
		})
		downloadButton.Importance = widget.HighImportance

		// Release notes button
		releaseNotesButton := widget.NewButton("View Release Notes", func() {
			// Get latest release info and open release notes
			go func() {
				release, err := l.releaseManager.GetLatestRelease()
				if err != nil {
					log.Printf("Failed to get latest release info: %v", err)
					return
				}

				releaseNotesURL, err := l.githubReleaseNotesURL(release.Version)
				if err != nil {
					log.Printf("Failed to parse URL: %v", err)
					return
				}

				l.app.OpenURL(releaseNotesURL)
			}()
		})

		skipButton := widget.NewButton("Skip for Now", func() {
			dialogWindow.Close()
		})

		// Layout - put release notes button above the main action buttons
		actionButtons := container.NewHBox(skipButton, downloadButton)
		content := container.NewVBox(
			titleLabel,
			widget.NewSeparator(),
			messageLabel,
			widget.NewSeparator(),
			releaseNotesButton,
			widget.NewSeparator(),
			actionButtons,
		)

		dialogWindow.SetContent(content)
		dialogWindow.Show()
	})
}

// downloadAndInstallLocalAI downloads and installs the latest LocalAI version
func (l *Launcher) downloadAndInstallLocalAI() {
	if l.app == nil {
		log.Printf("Cannot download LocalAI: app is nil")
		return
	}

	// First check what the latest version is
	go func() {
		log.Printf("Checking for latest LocalAI version...")
		available, version, err := l.CheckForUpdates()
		if err != nil {
			log.Printf("Failed to check for updates: %v", err)
			l.showDownloadError("Failed to check for latest version", err.Error())
			return
		}

		if !available {
			log.Printf("No updates available, but LocalAI is not installed")
			l.showDownloadError("No Version Available", "Could not determine the latest LocalAI version. Please check your internet connection and try again.")
			return
		}

		log.Printf("Latest version available: %s", version)
		// Show progress window with the specific version
		l.showDownloadProgress(version, fmt.Sprintf("Downloading LocalAI %s...", version))
	}()
}

// showDownloadError shows an error dialog for download failures
func (l *Launcher) showDownloadError(title, message string) {
	fyne.DoAndWait(func() {
		// Create error window
		errorWindow := l.app.NewWindow("Download Error")
		errorWindow.Resize(fyne.NewSize(400, 200))
		errorWindow.CenterOnScreen()
		errorWindow.SetCloseIntercept(func() {
			errorWindow.Close()
		})

		// Error content
		titleLabel := widget.NewLabel(title)
		titleLabel.TextStyle = fyne.TextStyle{Bold: true}
		titleLabel.Alignment = fyne.TextAlignCenter

		messageLabel := widget.NewLabel(message)
		messageLabel.Wrapping = fyne.TextWrapWord
		messageLabel.Alignment = fyne.TextAlignCenter

		// Close button
		closeButton := widget.NewButton("Close", func() {
			errorWindow.Close()
		})

		// Layout
		content := container.NewVBox(
			titleLabel,
			widget.NewSeparator(),
			messageLabel,
			widget.NewSeparator(),
			closeButton,
		)

		errorWindow.SetContent(content)
		errorWindow.Show()
	})
}

// showDownloadProgress shows a standalone progress window for downloading LocalAI
func (l *Launcher) showDownloadProgress(version, title string) {
	fyne.DoAndWait(func() {
		// Create progress window
		progressWindow := l.app.NewWindow("Downloading LocalAI")
		progressWindow.Resize(fyne.NewSize(400, 250))
		progressWindow.CenterOnScreen()
		progressWindow.SetCloseIntercept(func() {
			progressWindow.Close()
		})

		// Progress bar
		progressBar := widget.NewProgressBar()
		progressBar.SetValue(0)

		// Status label
		statusLabel := widget.NewLabel("Preparing download...")

		// Release notes button
		releaseNotesButton := widget.NewButton("View Release Notes", func() {
			releaseNotesURL, err := l.githubReleaseNotesURL(version)
			if err != nil {
				log.Printf("Failed to parse URL: %v", err)
				return
			}

			l.app.OpenURL(releaseNotesURL)
		})

		// Progress container
		progressContainer := container.NewVBox(
			widget.NewLabel(title),
			progressBar,
			statusLabel,
			widget.NewSeparator(),
			releaseNotesButton,
		)

		progressWindow.SetContent(progressContainer)
		progressWindow.Show()

		// Start download in background
		go func() {
			err := l.DownloadUpdate(version, func(progress float64) {
				// Update progress bar
				fyne.Do(func() {
					progressBar.SetValue(progress)
					percentage := int(progress * 100)
					statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
				})
			})

			// Handle completion
			fyne.Do(func() {
				if err != nil {
					statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
					// Show error dialog
					dialog.ShowError(err, progressWindow)
				} else {
					statusLabel.SetText("Download completed successfully!")
					progressBar.SetValue(1.0)

					// Show success dialog
					dialog.ShowConfirm("Installation Complete",
						"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
						func(close bool) {
							progressWindow.Close()
							// Update status and refresh systray menu
							l.updateStatus("LocalAI installed successfully")

							if l.systray != nil {
								l.systray.recreateMenu()
							}
						}, progressWindow)
				}
			})
		}()
	})
}

// monitorLogs monitors the output of LocalAI and adds it to the log buffer
func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
	scanner := bufio.NewScanner(reader)
	for scanner.Scan() {
		line := scanner.Text()
		timestamp := time.Now().Format("15:04:05")
		logLine := fmt.Sprintf("[%s] %s: %s\n", timestamp, prefix, line)

		l.logMutex.Lock()
		l.logBuffer.WriteString(logLine)
		// Keep log buffer size reasonable
		if l.logBuffer.Len() > 100000 { // 100KB
			content := l.logBuffer.String()
			// Keep last 50KB
			if len(content) > 50000 {
				l.logBuffer.Reset()
				l.logBuffer.WriteString(content[len(content)-50000:])
			}
		}
		l.logMutex.Unlock()

		// Write to log file if available
		if l.logFile != nil {
			if _, err := l.logFile.WriteString(logLine); err != nil {
				log.Printf("Failed to write to log file: %v", err)
			}
		}

		fyne.Do(func() {
			// Notify UI of new log content
			if l.ui != nil {
				l.ui.OnLogUpdate(logLine)
			}

			// Check for startup completion
			if strings.Contains(line, "API server listening") {
				l.updateStatus("LocalAI is running")
			}
		})
	}
}
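
// A produced log line looks like "[15:04:05] STDOUT: <output line>": the
// prefix is the "STDOUT"/"STDERR" tag passed in by StartLocalAI, and the
// buffer is trimmed back to its most recent 50KB once it grows past 100KB.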

// updateStatus updates the status and notifies UI
func (l *Launcher) updateStatus(status string) {
	select {
	case l.statusChannel <- status:
	default:
		// Channel full, skip
	}

	if l.ui != nil {
		l.ui.UpdateStatus(status)
	}

	if l.systray != nil {
		l.systray.UpdateStatus(status)
	}
}

// updateRunningState updates the running state in UI and systray
func (l *Launcher) updateRunningState(isRunning bool) {
	if l.ui != nil {
		l.ui.UpdateRunningState(isRunning)
	}

	if l.systray != nil {
		l.systray.UpdateRunningState(isRunning)
	}
}

// periodicUpdateCheck checks for updates periodically
func (l *Launcher) periodicUpdateCheck() {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			available, version, err := l.CheckForUpdates()
			if err == nil && available {
				fyne.Do(func() {
					l.updateStatus(fmt.Sprintf("Update available: %s", version))
					if l.systray != nil {
						l.systray.NotifyUpdateAvailable(version)
					}
					if l.ui != nil {
						l.ui.NotifyUpdateAvailable(version)
					}
				})
			}
		case <-l.ctx.Done():
			return
		}
	}
}

// loadConfig loads configuration from file
func (l *Launcher) loadConfig() error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("failed to get home directory: %w", err)
	}

	configPath := filepath.Join(homeDir, ".localai", "launcher.json")
	log.Printf("Loading config from: %s", configPath)

	if _, err := os.Stat(configPath); os.IsNotExist(err) {
		log.Printf("Config file not found, creating default config")
		// Create default config
		return l.saveConfig()
	}

	// Load existing config
	configData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("failed to read config file: %w", err)
	}

	log.Printf("Config file content: %s", string(configData))

	log.Printf("loadConfig: about to unmarshal JSON data")
	if err := json.Unmarshal(configData, l.config); err != nil {
		return fmt.Errorf("failed to parse config file: %w", err)
	}
	log.Printf("loadConfig: JSON unmarshaled successfully")

	log.Printf("Loaded config: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
		l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
	log.Printf("Environment vars: %v", l.config.EnvironmentVars)

	return nil
}

// saveConfig saves configuration to file
func (l *Launcher) saveConfig() error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("failed to get home directory: %w", err)
	}

	configDir := filepath.Join(homeDir, ".localai")
	if err := os.MkdirAll(configDir, 0755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}

	// Marshal config to JSON
	log.Printf("saveConfig: marshaling config with EnvironmentVars: %v", l.config.EnvironmentVars)
	configData, err := json.MarshalIndent(l.config, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}
	log.Printf("saveConfig: JSON marshaled successfully, length: %d", len(configData))

	configPath := filepath.Join(configDir, "launcher.json")
	log.Printf("Saving config to: %s", configPath)
	log.Printf("Config content: %s", string(configData))

	if err := os.WriteFile(configPath, configData, 0644); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	log.Printf("Config saved successfully")
	return nil
}
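
// The resulting ~/.localai/launcher.json is a small JSON document along these
// lines (a sketch only: the concrete key names come from the Config struct's
// JSON tags, which are not part of this hunk, so they are assumptions here):
//
//	{
//	  "models_path": "/home/user/.localai/models",
//	  "backends_path": "/home/user/.localai/backends",
//	  "address": "127.0.0.1:8080",
//	  "log_level": "info",
//	  "environment_vars": {}
//	}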

13  cmd/launcher/internal/launcher_suite_test.go  Normal file
@@ -0,0 +1,13 @@
package launcher_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestLauncher(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Launcher Suite")
}

213  cmd/launcher/internal/launcher_test.go  Normal file
@@ -0,0 +1,213 @@
package launcher_test

import (
	"os"
	"path/filepath"
	"strings"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"fyne.io/fyne/v2/app"

	launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)

var _ = Describe("Launcher", func() {
	var (
		launcherInstance *launcher.Launcher
		tempDir          string
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "launcher-test-*")
		Expect(err).ToNot(HaveOccurred())

		ui := launcher.NewLauncherUI()
		app := app.NewWithID("com.localai.launcher")

		launcherInstance = launcher.NewLauncher(ui, nil, app)
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Describe("NewLauncher", func() {
		It("should create a launcher with default configuration", func() {
			Expect(launcherInstance.GetConfig()).ToNot(BeNil())
		})
	})

	Describe("Initialize", func() {
		It("should set default paths when not configured", func() {
			err := launcherInstance.Initialize()
			Expect(err).ToNot(HaveOccurred())

			config := launcherInstance.GetConfig()
			Expect(config.ModelsPath).ToNot(BeEmpty())
			Expect(config.BackendsPath).ToNot(BeEmpty())
		})

		It("should set default ShowWelcome to true", func() {
			err := launcherInstance.Initialize()
			Expect(err).ToNot(HaveOccurred())

			config := launcherInstance.GetConfig()
			Expect(config.ShowWelcome).To(BeTrue())
			Expect(config.Address).To(Equal("127.0.0.1:8080"))
			Expect(config.LogLevel).To(Equal("info"))
		})

		It("should create models and backends directories", func() {
			// Set custom paths for testing
			config := launcherInstance.GetConfig()
			config.ModelsPath = filepath.Join(tempDir, "models")
			config.BackendsPath = filepath.Join(tempDir, "backends")
			launcherInstance.SetConfig(config)

			err := launcherInstance.Initialize()
			Expect(err).ToNot(HaveOccurred())

			// Check if directories were created
			_, err = os.Stat(config.ModelsPath)
			Expect(err).ToNot(HaveOccurred())

			_, err = os.Stat(config.BackendsPath)
			Expect(err).ToNot(HaveOccurred())
		})
	})

	Describe("Configuration", func() {
		It("should get and set configuration", func() {
			config := launcherInstance.GetConfig()
			config.ModelsPath = "/test/models"
			config.BackendsPath = "/test/backends"
			config.Address = ":9090"
			config.LogLevel = "debug"

			err := launcherInstance.SetConfig(config)
			Expect(err).ToNot(HaveOccurred())

			retrievedConfig := launcherInstance.GetConfig()
			Expect(retrievedConfig.ModelsPath).To(Equal("/test/models"))
			Expect(retrievedConfig.BackendsPath).To(Equal("/test/backends"))
			Expect(retrievedConfig.Address).To(Equal(":9090"))
			Expect(retrievedConfig.LogLevel).To(Equal("debug"))
		})
	})

	Describe("WebUI URL", func() {
		It("should return correct WebUI URL for localhost", func() {
			config := launcherInstance.GetConfig()
			config.Address = ":8080"
			launcherInstance.SetConfig(config)

			url := launcherInstance.GetWebUIURL()
			Expect(url).To(Equal("http://localhost:8080"))
		})

		It("should return correct WebUI URL for full address", func() {
			config := launcherInstance.GetConfig()
			config.Address = "127.0.0.1:8080"
			launcherInstance.SetConfig(config)

			url := launcherInstance.GetWebUIURL()
			Expect(url).To(Equal("http://127.0.0.1:8080"))
		})

		It("should handle http prefix correctly", func() {
			config := launcherInstance.GetConfig()
			config.Address = "http://localhost:8080"
			launcherInstance.SetConfig(config)

			url := launcherInstance.GetWebUIURL()
			Expect(url).To(Equal("http://localhost:8080"))
		})
	})

	Describe("Process Management", func() {
		It("should not be running initially", func() {
			Expect(launcherInstance.IsRunning()).To(BeFalse())
		})

		It("should handle start when binary doesn't exist", func() {
			err := launcherInstance.StartLocalAI()
			Expect(err).To(HaveOccurred())
			// Could be either "not found" or "permission denied" depending on test environment
			errMsg := err.Error()
			hasExpectedError := strings.Contains(errMsg, "LocalAI binary") ||
				strings.Contains(errMsg, "permission denied")
			Expect(hasExpectedError).To(BeTrue(), "Expected error about binary not found or permission denied, got: %s", errMsg)
		})

		It("should handle stop when not running", func() {
			err := launcherInstance.StopLocalAI()
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("LocalAI is not running"))
		})
	})

	Describe("Logs", func() {
		It("should return empty logs initially", func() {
			logs := launcherInstance.GetLogs()
			Expect(logs).To(BeEmpty())
		})
	})

	Describe("Version Management", func() {
		It("should return empty version when no binary installed", func() {
			version := launcherInstance.GetCurrentVersion()
			Expect(version).To(BeEmpty()) // No binary installed in test environment
		})

		It("should handle update checks", func() {
			// This test would require mocking HTTP responses
			// For now, we'll just test that the method doesn't panic
			_, _, err := launcherInstance.CheckForUpdates()
			// We expect either success or a network error, not a panic
			if err != nil {
				// Network error is acceptable in tests
				Expect(err.Error()).To(ContainSubstring("failed to fetch"))
			}
		})
	})
})

var _ = Describe("Config", func() {
	It("should have proper JSON tags", func() {
		config := &launcher.Config{
			ModelsPath:      "/test/models",
			BackendsPath:    "/test/backends",
			Address:         ":8080",
			AutoStart:       true,
			LogLevel:        "info",
			EnvironmentVars: map[string]string{"TEST": "value"},
		}

		Expect(config.ModelsPath).To(Equal("/test/models"))
		Expect(config.BackendsPath).To(Equal("/test/backends"))
		Expect(config.Address).To(Equal(":8080"))
		Expect(config.AutoStart).To(BeTrue())
		Expect(config.LogLevel).To(Equal("info"))
		Expect(config.EnvironmentVars).To(HaveKeyWithValue("TEST", "value"))
	})

	It("should initialize environment variables map", func() {
		config := &launcher.Config{}
		Expect(config.EnvironmentVars).To(BeNil())

		ui := launcher.NewLauncherUI()
		app := app.NewWithID("com.localai.launcher")

		launcher := launcher.NewLauncher(ui, nil, app)

		err := launcher.Initialize()
		Expect(err).ToNot(HaveOccurred())

		retrievedConfig := launcher.GetConfig()
		Expect(retrievedConfig.EnvironmentVars).ToNot(BeNil())
		Expect(retrievedConfig.EnvironmentVars).To(BeEmpty())
	})
})

502  cmd/launcher/internal/release_manager.go  Normal file
@@ -0,0 +1,502 @@
package launcher

import (
	"bufio"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"github.com/mudler/LocalAI/internal"
)

// Release represents a LocalAI release
type Release struct {
	Version     string    `json:"tag_name"`
	Name        string    `json:"name"`
	Body        string    `json:"body"`
	PublishedAt time.Time `json:"published_at"`
	Assets      []Asset   `json:"assets"`
}

// Asset represents a release asset
type Asset struct {
	Name               string `json:"name"`
	BrowserDownloadURL string `json:"browser_download_url"`
	Size               int64  `json:"size"`
}

// ReleaseManager handles LocalAI release management
type ReleaseManager struct {
	// GitHubOwner is the GitHub repository owner
	GitHubOwner string
	// GitHubRepo is the GitHub repository name
	GitHubRepo string
	// BinaryPath is where the LocalAI binary is stored locally
	BinaryPath string
	// CurrentVersion is the currently installed version
	CurrentVersion string
	// ChecksumsPath is where checksums are stored
	ChecksumsPath string
	// MetadataPath is where version metadata is stored
	MetadataPath string
}

// NewReleaseManager creates a new release manager
func NewReleaseManager() *ReleaseManager {
	homeDir, _ := os.UserHomeDir()
	binaryPath := filepath.Join(homeDir, ".localai", "bin")
	checksumsPath := filepath.Join(homeDir, ".localai", "checksums")
	metadataPath := filepath.Join(homeDir, ".localai", "metadata")

	return &ReleaseManager{
		GitHubOwner:    "mudler",
		GitHubRepo:     "LocalAI",
		BinaryPath:     binaryPath,
		CurrentVersion: internal.PrintableVersion(),
		ChecksumsPath:  checksumsPath,
		MetadataPath:   metadataPath,
	}
}

// GetLatestRelease fetches the latest release information from GitHub
func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
	url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)

	resp, err := http.Get(url)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch latest release: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
	}

	// Parse the JSON response properly
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	release := &Release{}
	if err := json.Unmarshal(body, release); err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}

	// Validate the release data
	if release.Version == "" {
		return nil, fmt.Errorf("no version found in release data")
	}

	return release, nil
}

// DownloadRelease downloads a specific version of LocalAI
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
	// Ensure the binary directory exists
	if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
		return fmt.Errorf("failed to create binary directory: %w", err)
	}

	// Determine the binary name based on OS and architecture
	binaryName := rm.GetBinaryName(version)
	localPath := filepath.Join(rm.BinaryPath, "local-ai")

	// Download the binary
	downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
		rm.GitHubOwner, rm.GitHubRepo, version, binaryName)

	if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
		return fmt.Errorf("failed to download binary: %w", err)
	}

	// Download and verify checksums
	checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
		rm.GitHubOwner, rm.GitHubRepo, version, version)

	checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
	if err := rm.downloadFile(checksumURL, checksumPath, nil); err != nil {
		return fmt.Errorf("failed to download checksums: %w", err)
	}

	// Verify the checksum
	if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
		return fmt.Errorf("checksum verification failed: %w", err)
	}

	// Save checksums persistently for future verification
	if err := rm.saveChecksums(version, checksumPath, binaryName); err != nil {
		log.Printf("Warning: failed to save checksums: %v", err)
	}

	// Make the binary executable
	if err := os.Chmod(localPath, 0755); err != nil {
		return fmt.Errorf("failed to make binary executable: %w", err)
	}

	return nil
}

// GetBinaryName returns the appropriate binary name for the current platform
func (rm *ReleaseManager) GetBinaryName(version string) string {
	versionStr := strings.TrimPrefix(version, "v")
	os := runtime.GOOS
	arch := runtime.GOARCH

	// Map Go arch names to the release naming convention
	switch arch {
	case "amd64":
		arch = "amd64"
	case "arm64":
		arch = "arm64"
	default:
		arch = "amd64" // fallback
	}

	return fmt.Sprintf("local-ai-v%s-%s-%s", versionStr, os, arch)
}
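
// For example, GetBinaryName("v3.4.0") on linux/amd64 yields
// "local-ai-v3.4.0-linux-amd64"; the "v" prefix is normalized away first, so
// "3.4.0" produces the same name (see the release manager tests below).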

// downloadFile downloads a file from a URL to a local path with optional progress callback
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("bad status: %s", resp.Status)
	}

	out, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer out.Close()

	// Create a progress reader if callback is provided
	var reader io.Reader = resp.Body
	if progressCallback != nil && resp.ContentLength > 0 {
		reader = &progressReader{
			Reader:   resp.Body,
			Total:    resp.ContentLength,
			Callback: progressCallback,
		}
	}

	_, err = io.Copy(out, reader)
	return err
}

// saveChecksums saves checksums persistently for future verification
func (rm *ReleaseManager) saveChecksums(version, checksumPath, binaryName string) error {
	// Ensure checksums directory exists
	if err := os.MkdirAll(rm.ChecksumsPath, 0755); err != nil {
		return fmt.Errorf("failed to create checksums directory: %w", err)
	}

	// Read the downloaded checksums file
	checksumData, err := os.ReadFile(checksumPath)
	if err != nil {
		return fmt.Errorf("failed to read checksums file: %w", err)
	}

	// Save to persistent location with version info
	persistentPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
	if err := os.WriteFile(persistentPath, checksumData, 0644); err != nil {
		return fmt.Errorf("failed to write persistent checksums: %w", err)
	}

	// Also save a "latest" checksums file for the current version
	latestPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
	if err := os.WriteFile(latestPath, checksumData, 0644); err != nil {
		return fmt.Errorf("failed to write latest checksums: %w", err)
	}

	// Save version metadata
	if err := rm.saveVersionMetadata(version); err != nil {
		log.Printf("Warning: failed to save version metadata: %v", err)
	}

	log.Printf("Checksums saved for version %s", version)
	return nil
}

// saveVersionMetadata saves the installed version information
func (rm *ReleaseManager) saveVersionMetadata(version string) error {
	// Ensure metadata directory exists
	if err := os.MkdirAll(rm.MetadataPath, 0755); err != nil {
		return fmt.Errorf("failed to create metadata directory: %w", err)
	}

	// Create metadata structure
	metadata := struct {
		Version     string    `json:"version"`
		InstalledAt time.Time `json:"installed_at"`
		BinaryPath  string    `json:"binary_path"`
	}{
		Version:     version,
		InstalledAt: time.Now(),
		BinaryPath:  rm.GetBinaryPath(),
	}

	// Marshal to JSON
	metadataData, err := json.MarshalIndent(metadata, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal metadata: %w", err)
	}

	// Save metadata file
	metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
	if err := os.WriteFile(metadataPath, metadataData, 0644); err != nil {
		return fmt.Errorf("failed to write metadata file: %w", err)
	}

	log.Printf("Version metadata saved: %s", version)
	return nil
}
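
// The resulting ~/.localai/metadata/installed-version.json looks roughly like
// this (a sketch; the home directory shown is hypothetical):
//
//	{
//	 "version": "v3.4.0",
//	 "installed_at": "2025-01-02T15:04:05Z",
//	 "binary_path": "/home/user/.localai/bin/local-ai"
//	}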

// progressReader wraps an io.Reader to provide download progress
type progressReader struct {
	io.Reader
	Total    int64
	Current  int64
	Callback func(float64)
}

func (pr *progressReader) Read(p []byte) (int, error) {
	n, err := pr.Reader.Read(p)
	pr.Current += int64(n)
	if pr.Callback != nil {
		progress := float64(pr.Current) / float64(pr.Total)
		pr.Callback(progress)
	}
	return n, err
}
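
// Each Read reports Current/Total to the callback, so callers receive a
// fraction in [0, 1]; downloadFile only installs this wrapper when the server
// sent a positive Content-Length, which avoids dividing by zero here.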

// VerifyChecksum verifies the downloaded file against the provided checksums
func (rm *ReleaseManager) VerifyChecksum(filePath, checksumPath, binaryName string) error {
	// Calculate the SHA256 of the downloaded file
	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("failed to open file for checksum: %w", err)
	}
	defer file.Close()

	hasher := sha256.New()
	if _, err := io.Copy(hasher, file); err != nil {
		return fmt.Errorf("failed to calculate checksum: %w", err)
	}

	calculatedHash := hex.EncodeToString(hasher.Sum(nil))

	// Read the checksums file
	checksumFile, err := os.Open(checksumPath)
	if err != nil {
		return fmt.Errorf("failed to open checksums file: %w", err)
	}
	defer checksumFile.Close()

	scanner := bufio.NewScanner(checksumFile)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if strings.Contains(line, binaryName) {
			parts := strings.Fields(line)
			if len(parts) >= 2 {
				expectedHash := parts[0]
				if calculatedHash == expectedHash {
					return nil // Checksum verified
				}
				return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, calculatedHash)
			}
		}
	}

	return fmt.Errorf("checksum not found for %s", binaryName)
}
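
// The checksums file is parsed as sha256sum-style lines, e.g.:
//
//	e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855  local-ai-v3.4.0-linux-amd64
//
// (hash first, asset name second; format inferred from the parsing above and
// the test fixtures below).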

// GetInstalledVersion returns the currently installed version
func (rm *ReleaseManager) GetInstalledVersion() string {
	// Check whether the LocalAI binary exists at all
	binaryPath := rm.GetBinaryPath()
	if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
		return "" // No version installed
	}

	// Try to get the version from metadata
	if version := rm.loadVersionMetadata(); version != "" {
		return version
	}

	// Try to run the binary to get the version (fallback method)
	version, err := exec.Command(binaryPath, "--version").Output()
	if err != nil {
		// If the binary exists but --version fails, try to determine from filename or other means
		log.Printf("Binary exists but --version failed: %v", err)
		return ""
	}

	stringVersion := strings.TrimSpace(string(version))
	stringVersion = strings.TrimRight(stringVersion, "\n")

	return stringVersion
}

// loadVersionMetadata loads the installed version from the metadata file
func (rm *ReleaseManager) loadVersionMetadata() string {
	metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")

	// Check if metadata file exists
	if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
		return ""
	}

	// Read metadata file
	metadataData, err := os.ReadFile(metadataPath)
	if err != nil {
		log.Printf("Failed to read metadata file: %v", err)
		return ""
	}

	// Parse metadata
	var metadata struct {
		Version     string    `json:"version"`
		InstalledAt time.Time `json:"installed_at"`
		BinaryPath  string    `json:"binary_path"`
	}

	if err := json.Unmarshal(metadataData, &metadata); err != nil {
		log.Printf("Failed to parse metadata file: %v", err)
		return ""
	}

	// Verify that the binary path in metadata matches the current binary path
	if metadata.BinaryPath != rm.GetBinaryPath() {
		log.Printf("Binary path mismatch in metadata, ignoring")
		return ""
	}

	log.Printf("Loaded version from metadata: %s (installed at %s)", metadata.Version, metadata.InstalledAt.Format("2006-01-02 15:04:05"))
	return metadata.Version
}

// GetBinaryPath returns the path to the LocalAI binary
func (rm *ReleaseManager) GetBinaryPath() string {
	return filepath.Join(rm.BinaryPath, "local-ai")
}

// IsUpdateAvailable checks if an update is available
func (rm *ReleaseManager) IsUpdateAvailable() (bool, string, error) {
	log.Printf("IsUpdateAvailable: checking for updates...")

	latest, err := rm.GetLatestRelease()
	if err != nil {
		log.Printf("IsUpdateAvailable: failed to get latest release: %v", err)
		return false, "", err
	}
	log.Printf("IsUpdateAvailable: latest release version: %s", latest.Version)

	current := rm.GetInstalledVersion()
	log.Printf("IsUpdateAvailable: current installed version: %s", current)

	if current == "" {
		// No version installed, offer to download latest
		log.Printf("IsUpdateAvailable: no version installed, offering latest: %s", latest.Version)
		return true, latest.Version, nil
	}

	updateAvailable := latest.Version != current
	log.Printf("IsUpdateAvailable: update available: %v (latest: %s, current: %s)", updateAvailable, latest.Version, current)
	return updateAvailable, latest.Version, nil
}
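
// Note that the comparison above is a plain string inequality, not a semantic
// version comparison: any installed version that differs from the latest tag
// (including a locally newer build) is reported as an available update.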

// IsLocalAIInstalled checks if LocalAI binary exists and is valid
func (rm *ReleaseManager) IsLocalAIInstalled() bool {
	binaryPath := rm.GetBinaryPath()
	if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
		return false
	}

	// Verify the binary integrity
	if err := rm.VerifyInstalledBinary(); err != nil {
		log.Printf("Binary integrity check failed: %v", err)
		// Remove corrupted binary
		if removeErr := os.Remove(binaryPath); removeErr != nil {
			log.Printf("Failed to remove corrupted binary: %v", removeErr)
		}
		return false
	}

	return true
}

// VerifyInstalledBinary verifies the installed binary against saved checksums
func (rm *ReleaseManager) VerifyInstalledBinary() error {
	binaryPath := rm.GetBinaryPath()

	// Check if we have saved checksums
	latestChecksumsPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
	if _, err := os.Stat(latestChecksumsPath); os.IsNotExist(err) {
		return fmt.Errorf("no saved checksums found")
	}

	// Get the binary name for the current version from metadata
	currentVersion := rm.loadVersionMetadata()
	if currentVersion == "" {
		return fmt.Errorf("cannot determine current version from metadata")
	}

	binaryName := rm.GetBinaryName(currentVersion)

	// Verify against saved checksums
	return rm.VerifyChecksum(binaryPath, latestChecksumsPath, binaryName)
}

// CleanupPartialDownloads removes any partial or corrupted downloads
func (rm *ReleaseManager) CleanupPartialDownloads() error {
	binaryPath := rm.GetBinaryPath()

	// Check if binary exists but is corrupted
	if _, err := os.Stat(binaryPath); err == nil {
		// Binary exists, verify it
		if verifyErr := rm.VerifyInstalledBinary(); verifyErr != nil {
			log.Printf("Found corrupted binary, removing: %v", verifyErr)
			if removeErr := os.Remove(binaryPath); removeErr != nil {
				log.Printf("Failed to remove corrupted binary: %v", removeErr)
			}
			// Clear metadata since binary is corrupted
			rm.clearVersionMetadata()
		}
	}

	// Clean up any temporary checksum files
	tempChecksumsPath := filepath.Join(rm.BinaryPath, "checksums.txt")
	if _, err := os.Stat(tempChecksumsPath); err == nil {
		if removeErr := os.Remove(tempChecksumsPath); removeErr != nil {
			log.Printf("Failed to remove temporary checksums: %v", removeErr)
		}
	}

	return nil
}

// clearVersionMetadata clears the version metadata (used when binary is corrupted or removed)
func (rm *ReleaseManager) clearVersionMetadata() {
	metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
	if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
		log.Printf("Failed to clear version metadata: %v", err)
	} else {
		log.Printf("Version metadata cleared")
	}
}

178  cmd/launcher/internal/release_manager_test.go  Normal file
@@ -0,0 +1,178 @@
package launcher_test

import (
	"os"
	"path/filepath"
	"runtime"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)

var _ = Describe("ReleaseManager", func() {
	var (
		rm      *launcher.ReleaseManager
		tempDir string
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "launcher-test-*")
		Expect(err).ToNot(HaveOccurred())

		rm = launcher.NewReleaseManager()
		// Override binary path for testing
		rm.BinaryPath = tempDir
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Describe("NewReleaseManager", func() {
		It("should create a release manager with correct defaults", func() {
			newRM := launcher.NewReleaseManager()
			Expect(newRM.GitHubOwner).To(Equal("mudler"))
			Expect(newRM.GitHubRepo).To(Equal("LocalAI"))
			Expect(newRM.BinaryPath).To(ContainSubstring(".localai"))
		})
	})

	Describe("GetBinaryName", func() {
		It("should return correct binary name for current platform", func() {
			binaryName := rm.GetBinaryName("v3.4.0")
			expectedOS := runtime.GOOS
			expectedArch := runtime.GOARCH

			expected := "local-ai-v3.4.0-" + expectedOS + "-" + expectedArch
			Expect(binaryName).To(Equal(expected))
		})

		It("should handle version with and without 'v' prefix", func() {
			withV := rm.GetBinaryName("v3.4.0")
			withoutV := rm.GetBinaryName("3.4.0")

			// Both should produce the same result
			Expect(withV).To(Equal(withoutV))
		})
	})

	Describe("GetBinaryPath", func() {
		It("should return the correct binary path", func() {
			path := rm.GetBinaryPath()
			expected := filepath.Join(tempDir, "local-ai")
			Expect(path).To(Equal(expected))
		})
	})

	Describe("GetInstalledVersion", func() {
		It("should return empty when no binary exists", func() {
			version := rm.GetInstalledVersion()
			Expect(version).To(BeEmpty()) // No binary installed in test
		})

		It("should return empty version when binary exists but no metadata", func() {
			// Create a fake binary for testing
			err := os.MkdirAll(rm.BinaryPath, 0755)
			Expect(err).ToNot(HaveOccurred())

			binaryPath := rm.GetBinaryPath()
			err = os.WriteFile(binaryPath, []byte("fake binary"), 0755)
			Expect(err).ToNot(HaveOccurred())

			version := rm.GetInstalledVersion()
			Expect(version).To(BeEmpty())
		})
	})

	Context("with mocked responses", func() {
		// Note: In a real implementation, we'd mock HTTP responses
		// For now, we'll test the structure and error handling

		Describe("GetLatestRelease", func() {
			It("should handle network errors gracefully", func() {
				// This test would require mocking the HTTP client
				// For demonstration, we're just testing the method exists
				_, err := rm.GetLatestRelease()
				// We expect either success or a network error, not a panic
				// In a real test, we'd mock the HTTP response
				if err != nil {
					Expect(err.Error()).To(ContainSubstring("failed to fetch"))
				}
			})
		})

		Describe("DownloadRelease", func() {
			It("should create binary directory if it doesn't exist", func() {
				// Remove the temp directory to test creation
				os.RemoveAll(tempDir)

				// This will fail due to network, but should create the directory
				rm.DownloadRelease("v3.4.0", nil)

				// Check if directory was created
				_, err := os.Stat(tempDir)
				Expect(err).ToNot(HaveOccurred())
			})
		})
	})

	Describe("VerifyChecksum functionality", func() {
		var (
			testFile     string
			checksumFile string
		)

		BeforeEach(func() {
			testFile = filepath.Join(tempDir, "test-binary")
			checksumFile = filepath.Join(tempDir, "checksums.txt")
		})

		It("should verify checksums correctly", func() {
			// Create a test file with known content
			testContent := []byte("test content for checksum")
			err := os.WriteFile(testFile, testContent, 0644)
			Expect(err).ToNot(HaveOccurred())

			// Use the SHA256 of the empty string as a dummy checksum -
			// in practice we'd use the actual checksum of the test content
			checksumContent := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 test-binary\n"
			err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			// Test checksum verification
			// Note: This will fail because our content doesn't match the empty string hash
			// In a real test, we'd calculate the actual hash
			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			// We expect this to fail since we're using a dummy checksum
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("checksum mismatch"))
		})

		It("should handle missing checksum file", func() {
			// Create test file but no checksum file
			err := os.WriteFile(testFile, []byte("test"), 0644)
			Expect(err).ToNot(HaveOccurred())

			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("failed to open checksums file"))
		})

		It("should handle missing binary in checksums", func() {
			// Create files, but the checksum file doesn't contain our binary
			err := os.WriteFile(testFile, []byte("test"), 0644)
			Expect(err).ToNot(HaveOccurred())

			checksumContent := "hash other-binary\n"
			err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("checksum not found"))
		})
	})
})

523  cmd/launcher/internal/systray_manager.go  Normal file
@@ -0,0 +1,523 @@
package launcher

import (
	"fmt"
	"log"
	"net/url"

	"fyne.io/fyne/v2"
	"fyne.io/fyne/v2/container"
	"fyne.io/fyne/v2/dialog"
	"fyne.io/fyne/v2/driver/desktop"
	"fyne.io/fyne/v2/widget"
)

// SystrayManager manages the system tray functionality
type SystrayManager struct {
	launcher *Launcher
	window   fyne.Window
	app      fyne.App
	desk     desktop.App

	// Menu items that need dynamic updates
	startStopItem      *fyne.MenuItem
	hasUpdateAvailable bool
	latestVersion      string
	icon               *fyne.StaticResource
}

// NewSystrayManager creates a new systray manager
func NewSystrayManager(launcher *Launcher, window fyne.Window, desktop desktop.App, app fyne.App, icon *fyne.StaticResource) *SystrayManager {
	sm := &SystrayManager{
		launcher: launcher,
		window:   window,
		app:      app,
		desk:     desktop,
		icon:     icon,
	}
	sm.setupMenu(desktop)
	return sm
}

// setupMenu sets up the system tray menu
func (sm *SystrayManager) setupMenu(desk desktop.App) {
	sm.desk = desk

	// Create the start/stop toggle item
	sm.startStopItem = fyne.NewMenuItem("Start LocalAI", func() {
		sm.toggleLocalAI()
	})

	desk.SetSystemTrayIcon(sm.icon)

	// Initialize the menu state using recreateMenu
	sm.recreateMenu()
}

// toggleLocalAI starts or stops LocalAI based on the current state
func (sm *SystrayManager) toggleLocalAI() {
	if sm.launcher.IsRunning() {
		go func() {
			if err := sm.launcher.StopLocalAI(); err != nil {
				log.Printf("Failed to stop LocalAI: %v", err)
				sm.showErrorDialog("Failed to Stop LocalAI", err.Error())
			}
		}()
	} else {
		go func() {
			if err := sm.launcher.StartLocalAI(); err != nil {
				log.Printf("Failed to start LocalAI: %v", err)
				sm.showStartupErrorDialog(err)
			}
		}()
	}
}

// openWebUI opens the LocalAI WebUI in the default browser
func (sm *SystrayManager) openWebUI() {
	if !sm.launcher.IsRunning() {
		return // LocalAI is not running
	}

	webURL := sm.launcher.GetWebUIURL()
	if parsedURL, err := url.Parse(webURL); err == nil {
		sm.app.OpenURL(parsedURL)
	}
}

// openDocumentation opens the LocalAI documentation
func (sm *SystrayManager) openDocumentation() {
	if parsedURL, err := url.Parse("https://localai.io"); err == nil {
		sm.app.OpenURL(parsedURL)
	}
}

// updateStartStopItem updates the start/stop menu item based on the current state
func (sm *SystrayManager) updateStartStopItem() {
	// Since Fyne menu items can't change text dynamically, we recreate the menu
	sm.recreateMenu()
}

// recreateMenu recreates the entire menu with updated state
func (sm *SystrayManager) recreateMenu() {
	if sm.desk == nil {
		return
	}

	// Determine the action based on LocalAI installation and running state
	var actionItem *fyne.MenuItem
	if !sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
		// LocalAI not installed - show install option
		actionItem = fyne.NewMenuItem("📥 Install Latest Version", func() {
			sm.launcher.showDownloadLocalAIDialog()
		})
	} else if sm.launcher.IsRunning() {
		// LocalAI is running - show stop option
		actionItem = fyne.NewMenuItem("🛑 Stop LocalAI", func() {
			sm.toggleLocalAI()
		})
	} else {
		// LocalAI is installed but not running - show start option
		actionItem = fyne.NewMenuItem("▶️ Start LocalAI", func() {
			sm.toggleLocalAI()
		})
	}

	menuItems := []*fyne.MenuItem{}

	// Add status at the top (clickable for details)
	status := sm.launcher.GetLastStatus()
	statusText := sm.truncateText(status, 30)
	statusItem := fyne.NewMenuItem("📊 Status: "+statusText, func() {
		sm.showStatusDetails(status, "")
	})
	menuItems = append(menuItems, statusItem)

	// Only show version if LocalAI is installed
	if sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
		version := sm.launcher.GetCurrentVersion()
		versionText := sm.truncateText(version, 25)
		versionItem := fyne.NewMenuItem("🔧 Version: "+versionText, func() {
			sm.showStatusDetails(status, version)
		})
		menuItems = append(menuItems, versionItem)
	}

	menuItems = append(menuItems, fyne.NewMenuItemSeparator())

	// Add update notification if available
	if sm.hasUpdateAvailable {
		updateItem := fyne.NewMenuItem("🔔 New version available ("+sm.latestVersion+")", func() {
			sm.downloadUpdate()
		})
		menuItems = append(menuItems, updateItem)
		menuItems = append(menuItems, fyne.NewMenuItemSeparator())
	}

	// Core actions
	menuItems = append(menuItems,
		actionItem,
	)

	// Only show the WebUI option if LocalAI is installed and running
	if sm.launcher.GetReleaseManager().IsLocalAIInstalled() && sm.launcher.IsRunning() {
		menuItems = append(menuItems,
			fyne.NewMenuItem("Open WebUI", func() {
				sm.openWebUI()
			}),
		)
	}

	menuItems = append(menuItems,
		fyne.NewMenuItemSeparator(),
		fyne.NewMenuItem("Check for Updates", func() {
			sm.checkForUpdates()
		}),
		fyne.NewMenuItemSeparator(),
		fyne.NewMenuItem("Settings", func() {
			sm.showSettings()
		}),
		fyne.NewMenuItem("Show Welcome Window", func() {
			sm.showWelcomeWindow()
		}),
		fyne.NewMenuItem("Open Data Folder", func() {
			sm.openDataFolder()
		}),
		fyne.NewMenuItemSeparator(),
		fyne.NewMenuItem("Documentation", func() {
			sm.openDocumentation()
		}),
		fyne.NewMenuItemSeparator(),
		fyne.NewMenuItem("Quit", func() {
			// Perform cleanup before quitting
			if err := sm.launcher.Shutdown(); err != nil {
				log.Printf("Error during shutdown: %v", err)
			}
			sm.app.Quit()
		}),
	)

	menu := fyne.NewMenu("LocalAI", menuItems...)
	sm.desk.SetSystemTrayMenu(menu)
}

// UpdateRunningState updates the systray based on the running state
func (sm *SystrayManager) UpdateRunningState(isRunning bool) {
	sm.updateStartStopItem()
}

// UpdateStatus updates the systray menu to reflect status changes
func (sm *SystrayManager) UpdateStatus(status string) {
	sm.recreateMenu()
}

// checkForUpdates checks for available updates
func (sm *SystrayManager) checkForUpdates() {
	go func() {
		log.Printf("Checking for updates...")
		available, version, err := sm.launcher.CheckForUpdates()
		if err != nil {
			log.Printf("Failed to check for updates: %v", err)
			return
		}

		log.Printf("Update check result: available=%v, version=%s", available, version)
		if available {
			sm.hasUpdateAvailable = true
			sm.latestVersion = version
			sm.recreateMenu()
		}
	}()
}

// downloadUpdate downloads the latest update
func (sm *SystrayManager) downloadUpdate() {
	if !sm.hasUpdateAvailable {
		return
	}

	// Show progress window
	sm.showDownloadProgress(sm.latestVersion)
}

// showSettings shows the settings window
func (sm *SystrayManager) showSettings() {
	sm.window.Show()
	sm.window.RequestFocus()
}

// showWelcomeWindow shows the welcome window
func (sm *SystrayManager) showWelcomeWindow() {
	if sm.launcher.GetUI() != nil {
		sm.launcher.GetUI().ShowWelcomeWindow()
	}
}

// openDataFolder opens the data folder in the file manager
func (sm *SystrayManager) openDataFolder() {
	dataPath := sm.launcher.GetDataPath()
	if parsedURL, err := url.Parse("file://" + dataPath); err == nil {
		sm.app.OpenURL(parsedURL)
	}
}

// NotifyUpdateAvailable sets update notification in systray
func (sm *SystrayManager) NotifyUpdateAvailable(version string) {
|
||||
sm.hasUpdateAvailable = true
|
||||
sm.latestVersion = version
|
||||
sm.recreateMenu()
|
||||
}
|
||||
|
||||
// truncateText truncates text to specified length and adds ellipsis if needed
|
||||
func (sm *SystrayManager) truncateText(text string, maxLength int) string {
|
||||
if len(text) <= maxLength {
|
||||
return text
|
||||
}
|
||||
return text[:maxLength-3] + "..."
|
||||
}
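A side note on the helper above: len(text) counts bytes, not characters, so a status string containing multibyte characters (several of the menu labels here use emoji) can be cut in the middle of a rune. A rune-safe variant is a small change; the sketch below is a hypothetical alternative, not part of this commit:

// truncateRunes is a hypothetical rune-safe variant of truncateText.
func truncateRunes(text string, maxLength int) string {
	r := []rune(text)
	if len(r) <= maxLength {
		return text
	}
	return string(r[:maxLength-3]) + "..."
}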
// showStatusDetails shows a detailed status window with full information
func (sm *SystrayManager) showStatusDetails(status, version string) {
	fyne.DoAndWait(func() {
		// Create status details window
		statusWindow := sm.app.NewWindow("LocalAI Status Details")
		statusWindow.Resize(fyne.NewSize(500, 400))
		statusWindow.CenterOnScreen()

		// Status information
		statusLabel := widget.NewLabel("Current Status:")
		statusValue := widget.NewLabel(status)
		statusValue.Wrapping = fyne.TextWrapWord

		// Version information (only show if version exists)
		var versionContainer fyne.CanvasObject
		if version != "" {
			versionLabel := widget.NewLabel("Installed Version:")
			versionValue := widget.NewLabel(version)
			versionValue.Wrapping = fyne.TextWrapWord
			versionContainer = container.NewVBox(versionLabel, versionValue)
		}

		// Running state
		runningLabel := widget.NewLabel("Running State:")
		runningValue := widget.NewLabel("")
		if sm.launcher.IsRunning() {
			runningValue.SetText("🟢 Running")
		} else {
			runningValue.SetText("🔴 Stopped")
		}

		// WebUI URL
		webuiLabel := widget.NewLabel("WebUI URL:")
		webuiValue := widget.NewLabel(sm.launcher.GetWebUIURL())
		webuiValue.Wrapping = fyne.TextWrapWord

		// Recent logs (last 20 lines)
		logsLabel := widget.NewLabel("Recent Logs:")
		logsText := widget.NewMultiLineEntry()
		logsText.SetText(sm.launcher.GetRecentLogs())
		logsText.Wrapping = fyne.TextWrapWord
		logsText.Disable() // Make it read-only

		// Buttons
		closeButton := widget.NewButton("Close", func() {
			statusWindow.Close()
		})

		refreshButton := widget.NewButton("Refresh", func() {
			// Refresh the status information
			statusValue.SetText(sm.launcher.GetLastStatus())

			// Note: Version refresh is not implemented for simplicity
			// The version will be updated when the status details window is reopened

			if sm.launcher.IsRunning() {
				runningValue.SetText("🟢 Running")
			} else {
				runningValue.SetText("🔴 Stopped")
			}
			logsText.SetText(sm.launcher.GetRecentLogs())
		})

		openWebUIButton := widget.NewButton("Open WebUI", func() {
			sm.openWebUI()
		})

		// Layout
		buttons := container.NewHBox(closeButton, refreshButton, openWebUIButton)

		// Build info container dynamically
		infoItems := []fyne.CanvasObject{
			statusLabel, statusValue,
			widget.NewSeparator(),
		}

		// Add version section if it exists
		if versionContainer != nil {
			infoItems = append(infoItems, versionContainer, widget.NewSeparator())
		}

		infoItems = append(infoItems,
			runningLabel, runningValue,
			widget.NewSeparator(),
			webuiLabel, webuiValue,
		)

		infoContainer := container.NewVBox(infoItems...)

		content := container.NewVBox(
			infoContainer,
			widget.NewSeparator(),
			logsLabel,
			logsText,
			widget.NewSeparator(),
			buttons,
		)

		statusWindow.SetContent(content)
		statusWindow.Show()
	})
}

// showErrorDialog shows a simple error dialog
func (sm *SystrayManager) showErrorDialog(title, message string) {
	fyne.DoAndWait(func() {
		dialog.ShowError(fmt.Errorf("%s", message), sm.window)
	})
}

// showStartupErrorDialog shows a detailed error dialog with process logs
func (sm *SystrayManager) showStartupErrorDialog(err error) {
	fyne.DoAndWait(func() {
		// Get the recent process logs (more useful for debugging)
		logs := sm.launcher.GetRecentLogs()

		// Create error window
		errorWindow := sm.app.NewWindow("LocalAI Startup Failed")
		errorWindow.Resize(fyne.NewSize(600, 500))
		errorWindow.CenterOnScreen()

		// Error message
		errorLabel := widget.NewLabel(fmt.Sprintf("Failed to start LocalAI:\n%s", err.Error()))
		errorLabel.Wrapping = fyne.TextWrapWord

		// Logs display
		logsLabel := widget.NewLabel("Process Logs:")
		logsText := widget.NewMultiLineEntry()
		logsText.SetText(logs)
		logsText.Wrapping = fyne.TextWrapWord
		logsText.Disable() // Make it read-only

		// Buttons
		closeButton := widget.NewButton("Close", func() {
			errorWindow.Close()
		})

		retryButton := widget.NewButton("Retry", func() {
			errorWindow.Close()
			// Try to start again
			go func() {
				if retryErr := sm.launcher.StartLocalAI(); retryErr != nil {
					sm.showStartupErrorDialog(retryErr)
				}
			}()
		})

		openLogsButton := widget.NewButton("Open Logs Folder", func() {
			sm.openDataFolder()
		})

		// Layout
		buttons := container.NewHBox(closeButton, retryButton, openLogsButton)
		content := container.NewVBox(
			errorLabel,
			widget.NewSeparator(),
			logsLabel,
			logsText,
			widget.NewSeparator(),
			buttons,
		)

		errorWindow.SetContent(content)
		errorWindow.Show()
	})
}

// showDownloadProgress shows a progress window for downloading updates
func (sm *SystrayManager) showDownloadProgress(version string) {
	// Create a new window for download progress
	progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
	progressWindow.Resize(fyne.NewSize(400, 250))
	progressWindow.CenterOnScreen()

	// Progress bar
	progressBar := widget.NewProgressBar()
	progressBar.SetValue(0)

	// Status label
	statusLabel := widget.NewLabel("Preparing download...")

	// Release notes button
	releaseNotesButton := widget.NewButton("View Release Notes", func() {
		releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
		if err != nil {
			log.Printf("Failed to parse URL: %v", err)
			return
		}

		sm.app.OpenURL(releaseNotesURL)
	})

	// Progress container
	progressContainer := container.NewVBox(
		widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
		progressBar,
		statusLabel,
		widget.NewSeparator(),
		releaseNotesButton,
	)

	progressWindow.SetContent(progressContainer)
	progressWindow.Show()

	// Start download in background
	go func() {
		err := sm.launcher.DownloadUpdate(version, func(progress float64) {
			// Update progress bar
			fyne.Do(func() {
				progressBar.SetValue(progress)
				percentage := int(progress * 100)
				statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
			})
		})

		// Handle completion
		fyne.Do(func() {
			if err != nil {
				statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
				// Show error dialog
				dialog.ShowError(err, progressWindow)
			} else {
				statusLabel.SetText("Download completed successfully!")
				progressBar.SetValue(1.0)

				// Show restart dialog
				dialog.ShowConfirm("Update Downloaded",
					"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
					func(restart bool) {
						if restart {
							sm.app.Quit()
						}
						progressWindow.Close()
					}, progressWindow)
			}
		})

		// Update systray menu
		if err == nil {
			sm.hasUpdateAvailable = false
			sm.latestVersion = ""
			sm.recreateMenu()
		}
	}()
}
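A note on the threading convention visible throughout the systray code: Fyne widgets must only be mutated on the UI goroutine, so background goroutines wrap widget updates in fyne.Do (asynchronous) or fyne.DoAndWait (blocks until the update has run). A minimal sketch of the pattern, with hypothetical names for the background work and the label:

go func() {
	result := doSlowWork() // hypothetical long-running task off the UI thread
	fyne.Do(func() {
		statusLabel.SetText(result) // safe: executed on the UI goroutine
	})
}()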
cmd/launcher/internal/ui.go (new file, 795 lines)
@@ -0,0 +1,795 @@
package launcher

import (
	"fmt"
	"log"
	"net/url"

	"fyne.io/fyne/v2"
	"fyne.io/fyne/v2/container"
	"fyne.io/fyne/v2/dialog"
	"fyne.io/fyne/v2/widget"
)

// EnvVar represents an environment variable
type EnvVar struct {
	Key   string
	Value string
}

// LauncherUI handles the user interface
type LauncherUI struct {
	// Status display
	statusLabel  *widget.Label
	versionLabel *widget.Label

	// Control buttons
	startStopButton *widget.Button
	webUIButton     *widget.Button
	updateButton    *widget.Button
	downloadButton  *widget.Button

	// Configuration
	modelsPathEntry   *widget.Entry
	backendsPathEntry *widget.Entry
	addressEntry      *widget.Entry
	logLevelSelect    *widget.Select
	startOnBootCheck  *widget.Check

	// Environment Variables
	envVarsData              []EnvVar
	newEnvKeyEntry           *widget.Entry
	newEnvValueEntry         *widget.Entry
	updateEnvironmentDisplay func()

	// Logs
	logText *widget.Entry

	// Progress
	progressBar *widget.ProgressBar

	// Update management
	latestVersion string

	// Reference to launcher
	launcher *Launcher
}

// NewLauncherUI creates a new UI instance
func NewLauncherUI() *LauncherUI {
	return &LauncherUI{
		statusLabel:       widget.NewLabel("Initializing..."),
		versionLabel:      widget.NewLabel("Version: Unknown"),
		startStopButton:   widget.NewButton("Start LocalAI", nil),
		webUIButton:       widget.NewButton("Open WebUI", nil),
		updateButton:      widget.NewButton("Check for Updates", nil),
		modelsPathEntry:   widget.NewEntry(),
		backendsPathEntry: widget.NewEntry(),
		addressEntry:      widget.NewEntry(),
		logLevelSelect:    widget.NewSelect([]string{"error", "warn", "info", "debug", "trace"}, nil),
		startOnBootCheck:  widget.NewCheck("Start LocalAI on system boot", nil),
		logText:           widget.NewMultiLineEntry(),
		progressBar:       widget.NewProgressBar(),
		envVarsData:       []EnvVar{}, // Initialize the environment variables slice
	}
}

// CreateMainUI creates the main UI layout
func (ui *LauncherUI) CreateMainUI(launcher *Launcher) *fyne.Container {
	ui.launcher = launcher
	ui.setupBindings()

	// Main tab with status and controls
	// Configuration is now the main content
	configTab := ui.createConfigTab()

	// Create a simple container instead of tabs since we only have settings
	tabs := container.NewVBox(
		widget.NewCard("LocalAI Launcher Settings", "", configTab),
	)

	return tabs
}

// createConfigTab creates the configuration tab
func (ui *LauncherUI) createConfigTab() *fyne.Container {
	// Path configuration
	pathsCard := widget.NewCard("Paths", "", container.NewGridWithColumns(2,
		widget.NewLabel("Models Path:"),
		ui.modelsPathEntry,
		widget.NewLabel("Backends Path:"),
		ui.backendsPathEntry,
	))

	// Server configuration
	serverCard := widget.NewCard("Server", "", container.NewVBox(
		container.NewGridWithColumns(2,
			widget.NewLabel("Address:"),
			ui.addressEntry,
			widget.NewLabel("Log Level:"),
			ui.logLevelSelect,
		),
		ui.startOnBootCheck,
	))

	// Save button
	saveButton := widget.NewButton("Save Configuration", func() {
		ui.saveConfiguration()
	})

	// Environment Variables section
	envCard := ui.createEnvironmentSection()

	return container.NewVBox(
		pathsCard,
		serverCard,
		envCard,
		saveButton,
	)
}

// createEnvironmentSection creates the environment variables section for the config tab
func (ui *LauncherUI) createEnvironmentSection() *fyne.Container {
	// Initialize environment variables widgets
	ui.newEnvKeyEntry = widget.NewEntry()
	ui.newEnvKeyEntry.SetPlaceHolder("Environment Variable Name")

	ui.newEnvValueEntry = widget.NewEntry()
	ui.newEnvValueEntry.SetPlaceHolder("Environment Variable Value")

	// Add button
	addButton := widget.NewButton("Add Environment Variable", func() {
		ui.addEnvironmentVariable()
	})

	// Environment variables list with delete buttons
	ui.envVarsData = []EnvVar{}

	// Create container for environment variables
	envVarsContainer := container.NewVBox()

	// Update function to rebuild the environment variables display
	ui.updateEnvironmentDisplay = func() {
		envVarsContainer.Objects = nil
		for i, envVar := range ui.envVarsData {
			index := i // Capture index for closure

			// Create row with label and delete button
			envLabel := widget.NewLabel(fmt.Sprintf("%s = %s", envVar.Key, envVar.Value))
			deleteBtn := widget.NewButton("Delete", func() {
				ui.confirmDeleteEnvironmentVariable(index)
			})
			deleteBtn.Importance = widget.DangerImportance

			row := container.NewBorder(nil, nil, nil, deleteBtn, envLabel)
			envVarsContainer.Add(row)
		}
		envVarsContainer.Refresh()
	}

	// Create a scrollable container for the environment variables
	envScroll := container.NewScroll(envVarsContainer)
	envScroll.SetMinSize(fyne.NewSize(400, 150))

	// Input section for adding new environment variables
	inputSection := container.NewVBox(
		container.NewGridWithColumns(2,
			ui.newEnvKeyEntry,
			ui.newEnvValueEntry,
		),
		addButton,
	)

	// Environment variables card
	envCard := widget.NewCard("Environment Variables", "", container.NewVBox(
		inputSection,
		widget.NewSeparator(),
		envScroll,
	))

	return container.NewVBox(envCard)
}

// addEnvironmentVariable adds a new environment variable
func (ui *LauncherUI) addEnvironmentVariable() {
	key := ui.newEnvKeyEntry.Text
	value := ui.newEnvValueEntry.Text

	log.Printf("addEnvironmentVariable: attempting to add %s=%s", key, value)
	log.Printf("addEnvironmentVariable: current ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)

	if key == "" {
		log.Printf("addEnvironmentVariable: key is empty, showing error")
		dialog.ShowError(fmt.Errorf("environment variable name cannot be empty"), ui.launcher.window)
		return
	}

	// Check if key already exists
	for _, envVar := range ui.envVarsData {
		if envVar.Key == key {
			log.Printf("addEnvironmentVariable: key %s already exists, showing error", key)
			dialog.ShowError(fmt.Errorf("environment variable '%s' already exists", key), ui.launcher.window)
			return
		}
	}

	log.Printf("addEnvironmentVariable: adding new env var %s=%s", key, value)
	ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
	log.Printf("addEnvironmentVariable: after adding, ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)

	fyne.Do(func() {
		if ui.updateEnvironmentDisplay != nil {
			ui.updateEnvironmentDisplay()
		}
		// Clear input fields
		ui.newEnvKeyEntry.SetText("")
		ui.newEnvValueEntry.SetText("")
	})

	log.Printf("addEnvironmentVariable: calling saveEnvironmentVariables")
	// Save to configuration
	ui.saveEnvironmentVariables()
}

// removeEnvironmentVariable removes an environment variable by index
func (ui *LauncherUI) removeEnvironmentVariable(index int) {
	if index >= 0 && index < len(ui.envVarsData) {
		ui.envVarsData = append(ui.envVarsData[:index], ui.envVarsData[index+1:]...)
		fyne.Do(func() {
			if ui.updateEnvironmentDisplay != nil {
				ui.updateEnvironmentDisplay()
			}
		})
		ui.saveEnvironmentVariables()
	}
}

// saveEnvironmentVariables saves environment variables to the configuration
func (ui *LauncherUI) saveEnvironmentVariables() {
	if ui.launcher == nil {
		log.Printf("saveEnvironmentVariables: launcher is nil")
		return
	}

	config := ui.launcher.GetConfig()
	log.Printf("saveEnvironmentVariables: before - Environment vars: %v", config.EnvironmentVars)

	config.EnvironmentVars = make(map[string]string)
	for _, envVar := range ui.envVarsData {
		config.EnvironmentVars[envVar.Key] = envVar.Value
		log.Printf("saveEnvironmentVariables: adding %s=%s", envVar.Key, envVar.Value)
	}

	log.Printf("saveEnvironmentVariables: after - Environment vars: %v", config.EnvironmentVars)
	log.Printf("saveEnvironmentVariables: calling SetConfig with %d environment variables", len(config.EnvironmentVars))

	err := ui.launcher.SetConfig(config)
	if err != nil {
		log.Printf("saveEnvironmentVariables: failed to save config: %v", err)
	} else {
		log.Printf("saveEnvironmentVariables: config saved successfully")
	}
}

// confirmDeleteEnvironmentVariable shows confirmation dialog for deleting an environment variable
func (ui *LauncherUI) confirmDeleteEnvironmentVariable(index int) {
	if index >= 0 && index < len(ui.envVarsData) {
		envVar := ui.envVarsData[index]
		dialog.ShowConfirm("Remove Environment Variable",
			fmt.Sprintf("Remove environment variable '%s'?", envVar.Key),
			func(remove bool) {
				if remove {
					ui.removeEnvironmentVariable(index)
				}
			}, ui.launcher.window)
	}
}

// setupBindings sets up event handlers for UI elements
func (ui *LauncherUI) setupBindings() {
	// Start/Stop button
	ui.startStopButton.OnTapped = func() {
		if ui.launcher.IsRunning() {
			ui.stopLocalAI()
		} else {
			ui.startLocalAI()
		}
	}

	// WebUI button
	ui.webUIButton.OnTapped = func() {
		ui.openWebUI()
	}
	ui.webUIButton.Disable() // Disabled until LocalAI is running

	// Update button
	ui.updateButton.OnTapped = func() {
		ui.checkForUpdates()
	}

	// Log level selection
	ui.logLevelSelect.OnChanged = func(selected string) {
		if ui.launcher != nil {
			config := ui.launcher.GetConfig()
			config.LogLevel = selected
			ui.launcher.SetConfig(config)
		}
	}
}

// startLocalAI starts the LocalAI service
func (ui *LauncherUI) startLocalAI() {
	fyne.Do(func() {
		ui.startStopButton.Disable()
	})
	ui.UpdateStatus("Starting LocalAI...")

	go func() {
		err := ui.launcher.StartLocalAI()
		if err != nil {
			ui.UpdateStatus("Failed to start: " + err.Error())
			fyne.DoAndWait(func() {
				dialog.ShowError(err, ui.launcher.window)
			})
		} else {
			fyne.Do(func() {
				ui.startStopButton.SetText("Stop LocalAI")
				ui.webUIButton.Enable()
			})
		}
		fyne.Do(func() {
			ui.startStopButton.Enable()
		})
	}()
}

// stopLocalAI stops the LocalAI service
func (ui *LauncherUI) stopLocalAI() {
	fyne.Do(func() {
		ui.startStopButton.Disable()
	})
	ui.UpdateStatus("Stopping LocalAI...")

	go func() {
		err := ui.launcher.StopLocalAI()
		if err != nil {
			fyne.DoAndWait(func() {
				dialog.ShowError(err, ui.launcher.window)
			})
		} else {
			fyne.Do(func() {
				ui.startStopButton.SetText("Start LocalAI")
				ui.webUIButton.Disable()
			})
		}
		fyne.Do(func() {
			ui.startStopButton.Enable()
		})
	}()
}

// openWebUI opens the LocalAI WebUI in the default browser
func (ui *LauncherUI) openWebUI() {
	webURL := ui.launcher.GetWebUIURL()
	parsedURL, err := url.Parse(webURL)
	if err != nil {
		dialog.ShowError(err, ui.launcher.window)
		return
	}

	// Open URL in default browser
	fyne.CurrentApp().OpenURL(parsedURL)
}

// saveConfiguration saves the current configuration
func (ui *LauncherUI) saveConfiguration() {
	log.Printf("saveConfiguration: starting to save configuration")

	config := ui.launcher.GetConfig()
	log.Printf("saveConfiguration: current config Environment vars: %v", config.EnvironmentVars)
	log.Printf("saveConfiguration: ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)

	config.ModelsPath = ui.modelsPathEntry.Text
	config.BackendsPath = ui.backendsPathEntry.Text
	config.Address = ui.addressEntry.Text
	config.LogLevel = ui.logLevelSelect.Selected
	config.StartOnBoot = ui.startOnBootCheck.Checked

	// Ensure environment variables are included in the configuration
	config.EnvironmentVars = make(map[string]string)
	for _, envVar := range ui.envVarsData {
		config.EnvironmentVars[envVar.Key] = envVar.Value
		log.Printf("saveConfiguration: adding env var %s=%s", envVar.Key, envVar.Value)
	}

	log.Printf("saveConfiguration: final config Environment vars: %v", config.EnvironmentVars)

	err := ui.launcher.SetConfig(config)
	if err != nil {
		log.Printf("saveConfiguration: failed to save config: %v", err)
		dialog.ShowError(err, ui.launcher.window)
	} else {
		log.Printf("saveConfiguration: config saved successfully")
		dialog.ShowInformation("Configuration", "Configuration saved successfully", ui.launcher.window)
	}
}

// checkForUpdates checks for available updates
func (ui *LauncherUI) checkForUpdates() {
	fyne.Do(func() {
		ui.updateButton.Disable()
	})
	ui.UpdateStatus("Checking for updates...")

	go func() {
		available, version, err := ui.launcher.CheckForUpdates()
		if err != nil {
			ui.UpdateStatus("Failed to check updates: " + err.Error())
			fyne.DoAndWait(func() {
				dialog.ShowError(err, ui.launcher.window)
			})
		} else if available {
			ui.latestVersion = version // Store the latest version
			ui.UpdateStatus("Update available: " + version)
			fyne.Do(func() {
				if ui.downloadButton != nil {
					ui.downloadButton.Enable()
				}
			})
			ui.NotifyUpdateAvailable(version)
		} else {
			ui.UpdateStatus("No updates available")
			fyne.DoAndWait(func() {
				dialog.ShowInformation("Updates", "You are running the latest version", ui.launcher.window)
			})
		}
		fyne.Do(func() {
			ui.updateButton.Enable()
		})
	}()
}

// downloadUpdate downloads the latest update
func (ui *LauncherUI) downloadUpdate() {
	// Use stored version or check for updates
	version := ui.latestVersion
	if version == "" {
		_, v, err := ui.launcher.CheckForUpdates()
		if err != nil {
			dialog.ShowError(err, ui.launcher.window)
			return
		}
		version = v
		ui.latestVersion = version
	}

	if version == "" {
		dialog.ShowError(fmt.Errorf("no version information available"), ui.launcher.window)
		return
	}

	// Disable buttons during download
	if ui.downloadButton != nil {
		fyne.Do(func() {
			ui.downloadButton.Disable()
		})
	}

	fyne.Do(func() {
		ui.progressBar.Show()
		ui.progressBar.SetValue(0)
	})
	ui.UpdateStatus("Downloading update " + version + "...")

	go func() {
		err := ui.launcher.DownloadUpdate(version, func(progress float64) {
			// Update progress bar
			fyne.Do(func() {
				ui.progressBar.SetValue(progress)
			})
			// Update status with percentage
			percentage := int(progress * 100)
			ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
		})

		fyne.Do(func() {
			ui.progressBar.Hide()
		})

		// Re-enable buttons after download
		if ui.downloadButton != nil {
			fyne.Do(func() {
				ui.downloadButton.Enable()
			})
		}

		if err != nil {
			fyne.DoAndWait(func() {
				ui.UpdateStatus("Failed to download update: " + err.Error())
				dialog.ShowError(err, ui.launcher.window)
			})
		} else {
			fyne.DoAndWait(func() {
				ui.UpdateStatus("Update downloaded successfully")
				dialog.ShowInformation("Update", "Update downloaded successfully. Please restart the launcher to use the new version.", ui.launcher.window)
			})
		}
	}()
}

// UpdateStatus updates the status label
func (ui *LauncherUI) UpdateStatus(status string) {
	if ui.statusLabel != nil {
		fyne.Do(func() {
			ui.statusLabel.SetText(status)
		})
	}
}

// OnLogUpdate handles new log content
func (ui *LauncherUI) OnLogUpdate(logLine string) {
	if ui.logText != nil {
		fyne.Do(func() {
			currentText := ui.logText.Text
			ui.logText.SetText(currentText + logLine)

			// Auto-scroll to bottom (simplified)
			ui.logText.CursorRow = len(ui.logText.Text)
		})
	}
}

// NotifyUpdateAvailable shows an update notification
func (ui *LauncherUI) NotifyUpdateAvailable(version string) {
	if ui.launcher != nil && ui.launcher.window != nil {
		fyne.DoAndWait(func() {
			dialog.ShowConfirm("Update Available",
				"A new version ("+version+") is available. Would you like to download it?",
				func(confirmed bool) {
					if confirmed {
						ui.downloadUpdate()
					}
				}, ui.launcher.window)
		})
	}
}

// LoadConfiguration loads the current configuration into UI elements
func (ui *LauncherUI) LoadConfiguration() {
	if ui.launcher == nil {
		log.Printf("UI LoadConfiguration: launcher is nil")
		return
	}

	config := ui.launcher.GetConfig()
	log.Printf("UI LoadConfiguration: loading config - ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
		config.ModelsPath, config.BackendsPath, config.Address, config.LogLevel)
	log.Printf("UI LoadConfiguration: Environment vars: %v", config.EnvironmentVars)

	ui.modelsPathEntry.SetText(config.ModelsPath)
	ui.backendsPathEntry.SetText(config.BackendsPath)
	ui.addressEntry.SetText(config.Address)
	ui.logLevelSelect.SetSelected(config.LogLevel)
	ui.startOnBootCheck.SetChecked(config.StartOnBoot)

	// Load environment variables
	ui.envVarsData = []EnvVar{}
	for key, value := range config.EnvironmentVars {
		ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
	}
	if ui.updateEnvironmentDisplay != nil {
		fyne.Do(func() {
			ui.updateEnvironmentDisplay()
		})
	}

	// Update version display
	version := ui.launcher.GetCurrentVersion()
	ui.versionLabel.SetText("Version: " + version)

	log.Printf("UI LoadConfiguration: configuration loaded successfully")
}

// showDownloadProgress shows a progress window for downloading LocalAI
func (ui *LauncherUI) showDownloadProgress(version, title string) {
	fyne.DoAndWait(func() {
		// Create progress window using the launcher's app
		progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
		progressWindow.Resize(fyne.NewSize(400, 250))
		progressWindow.CenterOnScreen()

		// Progress bar
		progressBar := widget.NewProgressBar()
		progressBar.SetValue(0)

		// Status label
		statusLabel := widget.NewLabel("Preparing download...")

		// Release notes button
		releaseNotesButton := widget.NewButton("View Release Notes", func() {
			releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
			if err != nil {
				log.Printf("Failed to parse URL: %v", err)
				return
			}

			ui.launcher.app.OpenURL(releaseNotesURL)
		})

		// Progress container
		progressContainer := container.NewVBox(
			widget.NewLabel(title),
			progressBar,
			statusLabel,
			widget.NewSeparator(),
			releaseNotesButton,
		)

		progressWindow.SetContent(progressContainer)
		progressWindow.Show()

		// Start download in background
		go func() {
			err := ui.launcher.DownloadUpdate(version, func(progress float64) {
				// Update progress bar
				fyne.Do(func() {
					progressBar.SetValue(progress)
					percentage := int(progress * 100)
					statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
				})
			})

			// Handle completion
			fyne.Do(func() {
				if err != nil {
					statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
					// Show error dialog
					dialog.ShowError(err, progressWindow)
				} else {
					statusLabel.SetText("Download completed successfully!")
					progressBar.SetValue(1.0)

					// Show success dialog
					dialog.ShowConfirm("Installation Complete",
						"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
						func(close bool) {
							progressWindow.Close()
							// Update status
							ui.UpdateStatus("LocalAI installed successfully")
						}, progressWindow)
				}
			})
		}()
	})
}

// UpdateRunningState updates UI based on LocalAI running state
func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
	fyne.Do(func() {
		if isRunning {
			ui.startStopButton.SetText("Stop LocalAI")
			ui.webUIButton.Enable()
		} else {
			ui.startStopButton.SetText("Start LocalAI")
			ui.webUIButton.Disable()
		}
	})
}

// ShowWelcomeWindow displays the welcome window with helpful information
func (ui *LauncherUI) ShowWelcomeWindow() {
	if ui.launcher == nil || ui.launcher.window == nil {
		log.Printf("Cannot show welcome window: launcher or window is nil")
		return
	}

	fyne.DoAndWait(func() {
		// Create welcome window
		welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher")
		welcomeWindow.Resize(fyne.NewSize(600, 500))
		welcomeWindow.CenterOnScreen()
		welcomeWindow.SetCloseIntercept(func() {
			welcomeWindow.Close()
		})

		// Title
		titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!")
		titleLabel.TextStyle = fyne.TextStyle{Bold: true}
		titleLabel.Alignment = fyne.TextAlignCenter

		// Welcome message
		welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system.

What you can do:
• Start and stop LocalAI server
• Configure models and backends paths
• Set environment variables
• Check for updates automatically
• Access LocalAI WebUI when running

Getting Started:
1. Configure your models and backends paths
2. Click "Start LocalAI" to begin
3. Use "Open WebUI" to access the interface
4. Check the system tray for quick access`

		welcomeLabel := widget.NewLabel(welcomeText)
		welcomeLabel.Wrapping = fyne.TextWrapWord

		// Useful links section
		linksTitle := widget.NewLabel("Useful Links:")
		linksTitle.TextStyle = fyne.TextStyle{Bold: true}

		// Create link buttons
		docsButton := widget.NewButton("📚 Documentation", func() {
			ui.openURL("https://localai.io/docs/")
		})

		githubButton := widget.NewButton("🐙 GitHub Repository", func() {
			ui.openURL("https://github.com/mudler/LocalAI")
		})

		modelsButton := widget.NewButton("🤖 Model Gallery", func() {
			ui.openURL("https://localai.io/models/")
		})

		communityButton := widget.NewButton("💬 Community", func() {
			ui.openURL("https://discord.gg/XgwjKptP7Z")
		})

		// Checkbox to disable welcome window
		dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) {
			if ui.launcher != nil {
				config := ui.launcher.GetConfig()
				v := !checked
				config.ShowWelcome = &v
				ui.launcher.SetConfig(config)
			}
		})

		config := ui.launcher.GetConfig()
		if config.ShowWelcome != nil {
			// Checked means "don't show again", i.e. the inverse of ShowWelcome
			dontShowAgainCheck.SetChecked(!*config.ShowWelcome)
		}

		// Close button
		closeButton := widget.NewButton("Get Started", func() {
			welcomeWindow.Close()
		})
		closeButton.Importance = widget.HighImportance

		// Layout
		linksContainer := container.NewVBox(
			linksTitle,
			docsButton,
			githubButton,
			modelsButton,
			communityButton,
		)

		content := container.NewVBox(
			titleLabel,
			widget.NewSeparator(),
			welcomeLabel,
			widget.NewSeparator(),
			linksContainer,
			widget.NewSeparator(),
			dontShowAgainCheck,
			widget.NewSeparator(),
			closeButton,
		)

		welcomeWindow.SetContent(content)
		welcomeWindow.Show()
	})
}

// openURL opens a URL in the default browser
func (ui *LauncherUI) openURL(urlString string) {
	parsedURL, err := url.Parse(urlString)
	if err != nil {
		log.Printf("Failed to parse URL %s: %v", urlString, err)
		return
	}
	fyne.CurrentApp().OpenURL(parsedURL)
}
cmd/launcher/logo.png (new binary file, 6.0 KiB; binary not shown)
cmd/launcher/main.go (new file, 92 lines)
@@ -0,0 +1,92 @@
package main

import (
	"log"
	"os"
	"os/signal"
	"syscall"

	"fyne.io/fyne/v2"
	"fyne.io/fyne/v2/app"
	"fyne.io/fyne/v2/driver/desktop"
	coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
)

func main() {
	// Create the application with unique ID
	myApp := app.NewWithID("com.localai.launcher")
	myApp.SetIcon(resourceIconPng)
	myWindow := myApp.NewWindow("LocalAI Launcher")
	myWindow.Resize(fyne.NewSize(800, 600))

	// Create the launcher UI
	ui := coreLauncher.NewLauncherUI()

	// Initialize the launcher with UI context
	launcher := coreLauncher.NewLauncher(ui, myWindow, myApp)

	// Setup the UI
	content := ui.CreateMainUI(launcher)
	myWindow.SetContent(content)

	// Setup window close behavior - minimize to tray instead of closing
	myWindow.SetCloseIntercept(func() {
		myWindow.Hide()
	})

	// Setup system tray using Fyne's built-in approach
	if desk, ok := myApp.(desktop.App); ok {
		// Create a dynamic systray manager
		systray := coreLauncher.NewSystrayManager(launcher, myWindow, desk, myApp, resourceIconPng)
		launcher.SetSystray(systray)
	}

	// Setup signal handling for graceful shutdown
	setupSignalHandling(launcher)

	// Initialize the launcher state
	go func() {
		if err := launcher.Initialize(); err != nil {
			log.Printf("Failed to initialize launcher: %v", err)
			if launcher.GetUI() != nil {
				launcher.GetUI().UpdateStatus("Failed to initialize: " + err.Error())
			}
		} else {
			// Load configuration into UI
			launcher.GetUI().LoadConfiguration()
			launcher.GetUI().UpdateStatus("Ready")

			// Show welcome window if configured to do so (unset defaults to showing it)
			config := launcher.GetConfig()
			if config.ShowWelcome == nil || *config.ShowWelcome {
				launcher.GetUI().ShowWelcomeWindow()
			}
		}
	}()

	// Run the application in background (window only shown when "Settings" is clicked)
	myApp.Run()
}

// setupSignalHandling sets up signal handlers for graceful shutdown
func setupSignalHandling(launcher *coreLauncher.Launcher) {
	// Create a channel to receive OS signals
	sigChan := make(chan os.Signal, 1)

	// Register for interrupt and terminate signals
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// Handle signals in a separate goroutine
	go func() {
		sig := <-sigChan
		log.Printf("Received signal %v, shutting down gracefully...", sig)

		// Perform cleanup
		if err := launcher.Shutdown(); err != nil {
			log.Printf("Error during shutdown: %v", err)
		}

		// Exit the application
		os.Exit(0)
	}()
}
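resourceIconPng is not declared in main.go; it is presumably a Fyne bundled resource generated from the launcher's icon image. Assuming the standard fyne CLI workflow (the output file name and source image below are guesses, not taken from this diff), regeneration would look roughly like:

// Hypothetical go:generate directive for the bundled icon resource:
//go:generate fyne bundle -o bundled.go --name resourceIconPng icon.png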
@@ -2,9 +2,7 @@ package main

 import (
 	"os"
-	"os/signal"
 	"path/filepath"
-	"syscall"

 	"github.com/alecthomas/kong"
 	"github.com/joho/godotenv"
@@ -24,15 +22,7 @@ func main() {
 	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
 	zerolog.SetGlobalLevel(zerolog.InfoLevel)

-	// Catch signals from the OS requesting us to exit
-	go func() {
-		c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
-		signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-		<-c
-		os.Exit(1)
-	}()
-
-	// handle loading environment variabled from .env files
+	// handle loading environment variables from .env files
 	envFiles := []string{".env", "localai.env"}
 	homeDir, err := os.UserHomeDir()
 	if err == nil {
@@ -42,7 +32,7 @@ func main() {

 	for _, envFile := range envFiles {
 		if _, err := os.Stat(envFile); err == nil {
-			log.Info().Str("envFile", envFile).Msg("env file found, loading environment variables from file")
+			log.Debug().Str("envFile", envFile).Msg("env file found, loading environment variables from file")
 			err = godotenv.Load(envFile)
 			if err != nil {
 				log.Error().Err(err).Str("envFile", envFile).Msg("failed to load environment variables from file")
@@ -97,19 +87,19 @@ Version: ${version}
 	switch *cli.CLI.LogLevel {
 	case "error":
 		zerolog.SetGlobalLevel(zerolog.ErrorLevel)
-		log.Info().Msg("Setting logging to error")
+		log.Debug().Msg("Setting logging to error")
 	case "warn":
 		zerolog.SetGlobalLevel(zerolog.WarnLevel)
-		log.Info().Msg("Setting logging to warn")
+		log.Debug().Msg("Setting logging to warn")
 	case "info":
 		zerolog.SetGlobalLevel(zerolog.InfoLevel)
-		log.Info().Msg("Setting logging to info")
+		log.Debug().Msg("Setting logging to info")
 	case "debug":
 		zerolog.SetGlobalLevel(zerolog.DebugLevel)
 		log.Debug().Msg("Setting logging to debug")
 	case "trace":
 		zerolog.SetGlobalLevel(zerolog.TraceLevel)
-		log.Trace().Msg("Setting logging to trace")
+		log.Debug().Msg("Setting logging to trace")
 	}

 	// Run the thing!
@@ -56,12 +56,12 @@ func New(opts ...config.AppOption) (*Application, error) {
 		}
 	}

-	if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
+	if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
 		log.Error().Err(err).Msg("error installing models")
 	}

 	for _, backend := range options.ExternalBackends {
-		if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil {
+		if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
 			log.Error().Err(err).Msg("error installing external backend")
 		}
 	}
@@ -87,13 +87,13 @@ func New(opts ...config.AppOption) (*Application, error) {
 	}

 	if options.PreloadJSONModels != "" {
-		if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
+		if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
 			return nil, err
 		}
 	}

 	if options.PreloadModelsFromPath != "" {
-		if err := services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
+		if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
 			return nil, err
 		}
 	}
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 	if !slices.Contains(modelNames, c.Name) {
 		utils.ResetDownloadTimers()
 		// if we failed to load the model, we try to download it
-		err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+		err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
 		if err != nil {
 			log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
 			//return nil, err
@@ -78,6 +78,12 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
 		b = c.Batch
 	}

+	flashAttention := "auto"
+
+	if c.FlashAttention != nil {
+		flashAttention = *c.FlashAttention
+	}
+
 	f16 := false
 	if c.F16 != nil {
 		f16 = *c.F16
@@ -166,7 +172,7 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
 		LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
 		LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
 		MMProj:              c.MMProj,
-		FlashAttention:      c.FlashAttention,
+		FlashAttention:      flashAttention,
 		CacheTypeKey:        c.CacheTypeK,
 		CacheTypeValue:      c.CacheTypeV,
 		NoKVOffload:         c.NoKVOffloading,
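The flashAttention change above is the usual Go idiom for a tri-state option: the config field is a *string that stays nil when the user never set it, so an explicit value can be distinguished from "unset" and defaulted (here to "auto"). A standalone sketch of the same pattern, with illustrative names:

package main

import "fmt"

// pick returns *opt when the option was set, or def when it was left nil.
func pick(opt *string, def string) string {
	if opt != nil {
		return *opt
	}
	return def
}

func main() {
	var unset *string
	on := "on"
	fmt.Println(pick(unset, "auto")) // prints "auto"
	fmt.Println(pick(&on, "auto"))   // prints "on"
}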
@@ -12,7 +12,7 @@ import (
 	"github.com/mudler/LocalAI/pkg/model"
 )

-func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {

 	if modelConfig.Backend == "" {
 		modelConfig.Backend = model.WhisperBackend
@@ -34,6 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
 		Dst:       audio,
 		Language:  language,
 		Translate: translate,
+		Diarize:   diarize,
 		Threads:   uint32(*modelConfig.Threads),
 	})
 	if err != nil {
@@ -7,7 +7,7 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )

-func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
+func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {

 	opts := ModelOptions(modelConfig, appConfig)
 	inferenceModel, err := loader.Load(
@@ -22,12 +22,18 @@ func VideoGeneration(height, width int32, prompt, startImage, endImage, dst stri
 		_, err := inferenceModel.GenerateVideo(
 			appConfig.Context,
 			&proto.GenerateVideoRequest{
-				Height:     height,
-				Width:      width,
-				Prompt:     prompt,
-				StartImage: startImage,
-				EndImage:   endImage,
-				Dst:        dst,
+				Height:         height,
+				Width:          width,
+				Prompt:         prompt,
+				NegativePrompt: negativePrompt,
+				StartImage:     startImage,
+				EndImage:       endImage,
+				NumFrames:      numFrames,
+				Fps:            fps,
+				Seed:           seed,
+				CfgScale:       cfgScale,
+				Step:           step,
+				Dst:            dst,
 			})
 		return err
 	}
@@ -6,6 +6,7 @@ import (

 	cliContext "github.com/mudler/LocalAI/core/cli/context"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/system"

 	"github.com/mudler/LocalAI/core/gallery"
@@ -100,7 +101,8 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
 		}
 	}

-	err = startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
+	modelLoader := model.NewModelLoader(systemState, true)
+	err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
 	if err != nil {
 		return err
 	}

@@ -5,6 +5,7 @@ import (
 	"time"

 	cliContext "github.com/mudler/LocalAI/core/cli/context"
+	"github.com/mudler/LocalAI/core/cli/signals"
 	"github.com/mudler/LocalAI/core/explorer"
 	"github.com/mudler/LocalAI/core/http"
 )
@@ -45,5 +46,7 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {

 	appHTTP := http.Explorer(db)

+	signals.Handler(nil)
+
 	return appHTTP.Listen(e.Address)
 }

@@ -4,6 +4,7 @@ import (
 	"context"

 	cliContext "github.com/mudler/LocalAI/core/cli/context"
+	"github.com/mudler/LocalAI/core/cli/signals"
 	"github.com/mudler/LocalAI/core/p2p"
 )

@@ -19,5 +20,7 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {

 	fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)

+	signals.Handler(nil)
+
 	return fs.Start(context.Background())
 }

@@ -11,6 +11,7 @@ import (
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/core/startup"
 	"github.com/mudler/LocalAI/pkg/downloader"
+	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/system"
 	"github.com/rs/zerolog/log"
 	"github.com/schollz/progressbar/v3"
@@ -125,7 +126,8 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
 		log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
 	}

-	err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
+	modelLoader := model.NewModelLoader(systemState, true)
+	err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
 	if err != nil {
 		return err
 	}

@@ -10,9 +10,11 @@ import (
 	"github.com/mudler/LocalAI/core/application"
 	cli_api "github.com/mudler/LocalAI/core/cli/api"
 	cliContext "github.com/mudler/LocalAI/core/cli/context"
+	"github.com/mudler/LocalAI/core/cli/signals"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http"
 	"github.com/mudler/LocalAI/core/p2p"
+	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/system"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
@@ -73,9 +75,16 @@ type RunCMD struct {
 	DisableGalleryEndpoint bool     `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
 	MachineTag             string   `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
 	LoadToMemory           []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
+
+	Version bool
 }

 func (r *RunCMD) Run(ctx *cliContext.Context) error {
+	if r.Version {
+		fmt.Println(internal.Version)
+		return nil
+	}
+
 	os.MkdirAll(r.BackendsPath, 0750)
 	os.MkdirAll(r.ModelsPath, 0750)

@@ -216,5 +225,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		return err
 	}

+	// Catch signals from the OS requesting us to exit, and stop all backends
+	signals.Handler(app.ModelLoader())
+
 	return appHTTP.Listen(r.Address)
 }
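The recurring signals.Handler(...) calls above centralize the per-command signal goroutines that were removed from main.go; the call sites show it accepts a *model.ModelLoader (or nil) so loaded backends can be stopped before exiting. As a rough sketch only (this is an assumption about core/cli/signals, not its actual source), such a helper could look like:

package signals

import (
	"os"
	"os/signal"
	"syscall"

	"github.com/mudler/LocalAI/pkg/model"
)

// Handler installs a handler for SIGINT/SIGTERM that shuts down
// loaded backends (when a loader is supplied) before exiting.
func Handler(loader *model.ModelLoader) {
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		if loader != nil {
			loader.StopAllGRPC() // assumed cleanup entry point
		}
		os.Exit(0)
	}()
}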