mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-03 11:13:31 -05:00
Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2e8b6d1aa | ||
|
|
fba5b557a1 | ||
|
|
6db19c5cb9 | ||
|
|
5428678209 | ||
|
|
06129139eb | ||
|
|
05757e2738 | ||
|
|
240b790f29 | ||
|
|
5f221f5946 | ||
|
|
def7cdc0bf | ||
|
|
ea9bf3dba2 | ||
|
|
b8eca530b6 | ||
|
|
47034ddacd | ||
|
|
9a41331855 | ||
|
|
facc0181df | ||
|
|
4733adb983 | ||
|
|
326fda3223 | ||
|
|
abf61e5b42 | ||
|
|
2ae45e7635 | ||
|
|
7d41551e10 | ||
|
|
6fbd720515 | ||
|
|
4e40a8d1ed | ||
|
|
003b9292fe | ||
|
|
09457b9221 | ||
|
|
41aa7e107f | ||
|
|
bda875f962 | ||
|
|
224063f0f7 | ||
|
|
89978c8b57 | ||
|
|
987b5dcac1 | ||
|
|
ec1276e5a9 | ||
|
|
61ba98d43d | ||
|
|
b9a25b16e6 | ||
|
|
6a8149e1fd | ||
|
|
9c2840ac38 | ||
|
|
20a70e1244 | ||
|
|
3295a298f4 | ||
|
|
da6f37f000 | ||
|
|
c092633cd7 | ||
|
|
7e2a522229 | ||
|
|
03e8592450 | ||
|
|
f207bd1427 | ||
|
|
a5c0fe31c3 | ||
|
|
c68907ac65 | ||
|
|
9087ddc4de | ||
|
|
33bebd5114 | ||
|
|
2913676157 | ||
|
|
e83652489c | ||
|
|
d6274eaf4a | ||
|
|
4d90971424 | ||
|
|
90f5639639 | ||
|
|
a35a701052 | ||
|
|
3d8ec72dbf | ||
|
|
2a9d675d62 | ||
|
|
c782e8abf1 | ||
|
|
a1e1942d83 | ||
|
|
787302b204 | ||
|
|
0b085089b9 | ||
|
|
624f3b1fc8 | ||
|
|
c07bc55fee | ||
|
|
173e0774c0 | ||
|
|
8ece26ab7c | ||
|
|
d704cc7970 | ||
|
|
ab17baaae1 | ||
|
|
ca358fcdca | ||
|
|
9aadfd485f | ||
|
|
da3b0850de | ||
|
|
8b1e8b4cda | ||
|
|
3d22bfc27c | ||
|
|
4438b4361e | ||
|
|
04bad9a2da | ||
|
|
8235e53602 | ||
|
|
eb5c3670f1 | ||
|
|
89e61fca90 | ||
|
|
9d6efe8842 | ||
|
|
60726d16f2 | ||
|
|
9d7ec09ec0 | ||
|
|
36179ffbed | ||
|
|
d25145e641 | ||
|
|
949e5b9be8 | ||
|
|
73ecb7f90b | ||
|
|
053bed6e5f | ||
|
|
932360bf7e | ||
|
|
6d0b52843f | ||
|
|
078c22f485 | ||
|
|
6ef3852de5 | ||
|
|
a8057b952c | ||
|
|
fd5c1d916f | ||
|
|
5ce982b9c9 | ||
|
|
47ccfccf7a | ||
|
|
a760f7ff39 | ||
|
|
facf7625f3 | ||
|
|
b3600b3c50 | ||
|
|
f0b47cfe6a |
290
.github/workflows/backend.yml
vendored
290
.github/workflows/backend.yml
vendored
@@ -87,6 +87,18 @@ jobs:
|
||||
backend: "diffusers"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-l4t-diffusers'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "diffusers"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
# CUDA 11 additional backends
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "11"
|
||||
@@ -313,7 +325,7 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-transformers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "transformers"
|
||||
@@ -325,7 +337,7 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-diffusers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "diffusers"
|
||||
@@ -338,7 +350,7 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-kokoro'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "kokoro"
|
||||
@@ -374,31 +386,19 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-bark'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "bark"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
# sycl builds
|
||||
- build-type: 'sycl_f32'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-rerankers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "rerankers"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-rerankers'
|
||||
tag-suffix: '-gpu-intel-rerankers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
@@ -429,60 +429,36 @@ jobs:
|
||||
backend: "llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
- build-type: 'sycl_f32'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-vllm'
|
||||
tag-suffix: '-gpu-intel-vllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "vllm"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-vllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "vllm"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-transformers'
|
||||
tag-suffix: '-gpu-intel-transformers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "transformers"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-transformers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "transformers"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-diffusers'
|
||||
tag-suffix: '-gpu-intel-diffusers'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
@@ -490,96 +466,48 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
# SYCL additional backends
|
||||
- build-type: 'sycl_f32'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-kokoro'
|
||||
tag-suffix: '-gpu-intel-kokoro'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "kokoro"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-kokoro'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "kokoro"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-faster-whisper'
|
||||
tag-suffix: '-gpu-intel-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-coqui'
|
||||
tag-suffix: '-gpu-intel-coqui'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "coqui"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-coqui'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "coqui"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-bark'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "bark"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-bark'
|
||||
tag-suffix: '-gpu-intel-bark'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
@@ -868,7 +796,155 @@ jobs:
|
||||
skip-drivers: 'false'
|
||||
backend: "huggingface"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
context: "./"
|
||||
# rfdetr
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-rfdetr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-rfdetr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "11"
|
||||
cuda-minor-version: "7"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-11-rfdetr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-rfdetr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'true'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-rfdetr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "rfdetr"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
# exllama2
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-exllama2'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "exllama2"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-exllama2'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "exllama2"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "11"
|
||||
cuda-minor-version: "7"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-11-exllama2'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "exllama2"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-exllama2'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
skip-drivers: 'false'
|
||||
backend: "exllama2"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
skip-drivers: 'true'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-hipblas-exllama2'
|
||||
base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
backend: "exllama2"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
# runs out of space on the runner
|
||||
# - build-type: 'hipblas'
|
||||
# cuda-major-version: ""
|
||||
# cuda-minor-version: ""
|
||||
# platforms: 'linux/amd64'
|
||||
# tag-latest: 'auto'
|
||||
# tag-suffix: '-gpu-hipblas-rfdetr'
|
||||
# base-image: "rocm/dev-ubuntu-22.04:6.1"
|
||||
# runs-on: 'ubuntu-latest'
|
||||
# skip-drivers: 'false'
|
||||
# backend: "rfdetr"
|
||||
# dockerfile: "./backend/Dockerfile.python"
|
||||
# context: "./backend"
|
||||
# kitten-tts
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-kitten-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "kitten-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./backend"
|
||||
llama-cpp-darwin:
|
||||
runs-on: macOS-14
|
||||
strategy:
|
||||
@@ -904,6 +980,7 @@ jobs:
|
||||
path: build/llama-cpp.tar
|
||||
llama-cpp-darwin-publish:
|
||||
needs: llama-cpp-darwin
|
||||
if: github.event_name != 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Download llama-cpp.tar
|
||||
@@ -992,6 +1069,7 @@ jobs:
|
||||
name: llama-cpp-tar-x86
|
||||
path: build/llama-cpp.tar
|
||||
llama-cpp-darwin-x86-publish:
|
||||
if: github.event_name != 'pull_request'
|
||||
needs: llama-cpp-darwin-x86
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -1045,4 +1123,4 @@ jobs:
|
||||
run: |
|
||||
for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do
|
||||
crane push llama-cpp.tar $tag
|
||||
done
|
||||
done
|
||||
|
||||
6
.github/workflows/image-pr.yml
vendored
6
.github/workflows/image-pr.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'false'
|
||||
tag-suffix: '-gpu-nvidia-cuda12'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
makeflags: "--jobs=3 --output-sync=target"
|
||||
@@ -51,12 +51,12 @@ jobs:
|
||||
grpc-base-image: "ubuntu:22.04"
|
||||
runs-on: 'ubuntu-latest'
|
||||
makeflags: "--jobs=3 --output-sync=target"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'sycl'
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'false'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
grpc-base-image: "ubuntu:22.04"
|
||||
tag-suffix: 'sycl-f16'
|
||||
tag-suffix: 'sycl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
makeflags: "--jobs=3 --output-sync=target"
|
||||
- build-type: 'vulkan'
|
||||
|
||||
21
.github/workflows/image.yml
vendored
21
.github/workflows/image.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
cuda-minor-version: "7"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda11'
|
||||
tag-suffix: '-gpu-nvidia-cuda-11'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
makeflags: "--jobs=4 --output-sync=target"
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda12'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
@@ -103,30 +103,21 @@ jobs:
|
||||
- build-type: 'vulkan'
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-vulkan'
|
||||
tag-suffix: '-gpu-vulkan'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:22.04"
|
||||
skip-drivers: 'false'
|
||||
makeflags: "--jobs=4 --output-sync=target"
|
||||
aio: "-aio-gpu-vulkan"
|
||||
- build-type: 'sycl_f16'
|
||||
- build-type: 'intel'
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
grpc-base-image: "ubuntu:22.04"
|
||||
tag-suffix: '-gpu-intel-f16'
|
||||
tag-suffix: '-gpu-intel'
|
||||
runs-on: 'ubuntu-latest'
|
||||
makeflags: "--jobs=3 --output-sync=target"
|
||||
aio: "-aio-gpu-intel-f16"
|
||||
- build-type: 'sycl_f32'
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
|
||||
grpc-base-image: "ubuntu:22.04"
|
||||
tag-suffix: '-gpu-intel-f32'
|
||||
runs-on: 'ubuntu-latest'
|
||||
makeflags: "--jobs=3 --output-sync=target"
|
||||
aio: "-aio-gpu-intel-f32"
|
||||
aio: "-aio-gpu-intel"
|
||||
|
||||
gh-runner:
|
||||
uses: ./.github/workflows/image_build.yml
|
||||
|
||||
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@@ -23,6 +23,20 @@ jobs:
|
||||
matrix:
|
||||
go-version: ['1.21.x']
|
||||
steps:
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
# this might remove tools that are actually needed,
|
||||
# if set to "true" but frees about 6 GB
|
||||
tool-cache: true
|
||||
# all of these default to true, but feel free to set to
|
||||
# "false" if necessary for your workflow
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Release space from worker
|
||||
run: |
|
||||
echo "Listing top largest packages"
|
||||
|
||||
10
Dockerfile
10
Dockerfile
@@ -9,7 +9,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl wget espeak-ng libgomp1 \
|
||||
python3 python-is-python3 ffmpeg && \
|
||||
python3 python-is-python3 ffmpeg libopenblas-base libopenblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -72,6 +72,12 @@ RUN <<EOT bash
|
||||
fi
|
||||
EOT
|
||||
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
echo "nvidia-l4t" > /run/localai/capability
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
@@ -94,6 +100,8 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
|
||||
52
Makefile
52
Makefile
@@ -5,8 +5,6 @@ BINARY_NAME=local-ai
|
||||
|
||||
GORELEASER?=
|
||||
|
||||
ONEAPI_VERSION?=2025.2
|
||||
|
||||
export BUILD_TYPE?=
|
||||
|
||||
GO_TAGS?=
|
||||
@@ -134,6 +132,9 @@ test: test-models/testmodel.ggml protogen-go
|
||||
$(MAKE) test-tts
|
||||
$(MAKE) test-stablediffusion
|
||||
|
||||
backends/diffusers: docker-build-diffusers docker-save-diffusers build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)"
|
||||
|
||||
backends/llama-cpp: docker-build-llama-cpp docker-save-llama-cpp build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
|
||||
|
||||
@@ -155,6 +156,15 @@ backends/local-store: docker-build-local-store docker-save-local-store build
|
||||
backends/huggingface: docker-build-huggingface docker-save-huggingface build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/huggingface.tar)"
|
||||
|
||||
backends/rfdetr: docker-build-rfdetr docker-save-rfdetr build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/rfdetr.tar)"
|
||||
|
||||
backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/kitten-tts.tar)"
|
||||
|
||||
backends/kokoro: docker-build-kokoro docker-save-kokoro build
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
|
||||
|
||||
########################################################
|
||||
## AIO tests
|
||||
########################################################
|
||||
@@ -322,7 +332,7 @@ docker-cuda11:
|
||||
--build-arg GO_TAGS="$(GO_TAGS)" \
|
||||
--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
|
||||
--build-arg BUILD_TYPE=$(BUILD_TYPE) \
|
||||
-t $(DOCKER_IMAGE)-cuda11 .
|
||||
-t $(DOCKER_IMAGE)-cuda-11 .
|
||||
|
||||
docker-aio:
|
||||
@echo "Building AIO image with base $(BASE_IMAGE) as $(DOCKER_AIO_IMAGE)"
|
||||
@@ -337,19 +347,11 @@ docker-aio-all:
|
||||
|
||||
docker-image-intel:
|
||||
docker build \
|
||||
--build-arg BASE_IMAGE=intel/oneapi-basekit:${ONEAPI_VERSION}.0-0-devel-ubuntu24.04 \
|
||||
--build-arg BASE_IMAGE=quay.io/go-skynet/intel-oneapi-base:latest \
|
||||
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||
--build-arg GO_TAGS="$(GO_TAGS)" \
|
||||
--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
|
||||
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||
|
||||
docker-image-intel-xpu:
|
||||
docker build \
|
||||
--build-arg BASE_IMAGE=intel/oneapi-basekit:${ONEAPI_VERSION}.0-0-devel-ubuntu22.04 \
|
||||
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||
--build-arg GO_TAGS="$(GO_TAGS)" \
|
||||
--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
|
||||
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||
--build-arg BUILD_TYPE=intel -t $(DOCKER_IMAGE) .
|
||||
|
||||
########################################################
|
||||
## Backends
|
||||
@@ -373,6 +375,24 @@ docker-build-local-store:
|
||||
docker-build-huggingface:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:huggingface -f backend/Dockerfile.golang --build-arg BACKEND=huggingface .
|
||||
|
||||
docker-build-rfdetr:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:rfdetr -f backend/Dockerfile.python --build-arg BACKEND=rfdetr ./backend
|
||||
|
||||
docker-build-kitten-tts:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kitten-tts -f backend/Dockerfile.python --build-arg BACKEND=kitten-tts ./backend
|
||||
|
||||
docker-save-kitten-tts: backend-images
|
||||
docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar
|
||||
|
||||
docker-build-kokoro:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend
|
||||
|
||||
docker-save-kokoro: backend-images
|
||||
docker save local-ai-backend:kokoro -o backend-images/kokoro.tar
|
||||
|
||||
docker-save-rfdetr: backend-images
|
||||
docker save local-ai-backend:rfdetr -o backend-images/rfdetr.tar
|
||||
|
||||
docker-save-huggingface: backend-images
|
||||
docker save local-ai-backend:huggingface -o backend-images/huggingface.tar
|
||||
|
||||
@@ -410,10 +430,10 @@ docker-build-transformers:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:transformers -f backend/Dockerfile.python --build-arg BACKEND=transformers .
|
||||
|
||||
docker-build-diffusers:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:diffusers -f backend/Dockerfile.python --build-arg BACKEND=diffusers .
|
||||
docker build --progress=plain --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:diffusers -f backend/Dockerfile.python --build-arg BACKEND=diffusers ./backend
|
||||
|
||||
docker-build-kokoro:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro .
|
||||
docker-save-diffusers: backend-images
|
||||
docker save local-ai-backend:diffusers -o backend-images/diffusers.tar
|
||||
|
||||
docker-build-whisper:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
|
||||
|
||||
12
README.md
12
README.md
@@ -140,11 +140,7 @@ docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri
|
||||
### Intel GPU Images (oneAPI):
|
||||
|
||||
```bash
|
||||
# Intel GPU with FP16 support
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel-f16
|
||||
|
||||
# Intel GPU with FP32 support
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel-f32
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
||||
```
|
||||
|
||||
### Vulkan GPU Images:
|
||||
@@ -166,7 +162,7 @@ docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-ai
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-aio-gpu-nvidia-cuda-11
|
||||
|
||||
# Intel GPU version
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-gpu-intel-f16
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-gpu-intel
|
||||
|
||||
# AMD GPU version
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-aio-gpu-hipblas
|
||||
@@ -189,10 +185,13 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
|
||||
|
||||
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
|
||||
|
||||
## 📰 Latest project news
|
||||
|
||||
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
||||
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||
- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
|
||||
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
|
||||
@@ -226,6 +225,7 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
|
||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
|
||||
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
||||
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
||||
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- [Agentic capabilities](https://github.com/mudler/LocalAGI)
|
||||
|
||||
@@ -96,17 +96,6 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
# Intel oneAPI requirements
|
||||
RUN <<EOT bash
|
||||
if [[ "${BUILD_TYPE}" == sycl* ]] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
intel-oneapi-runtime-libs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install Go
|
||||
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
|
||||
ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin:/usr/local/bin
|
||||
|
||||
@@ -20,6 +20,7 @@ service Backend {
|
||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||
rpc Detect(DetectOptions) returns (DetectResponse) {}
|
||||
|
||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||
@@ -304,6 +305,9 @@ message GenerateImageRequest {
|
||||
// Diffusers
|
||||
string EnableParameters = 10;
|
||||
int32 CLIPSkip = 11;
|
||||
|
||||
// Reference images for models that support them (e.g., Flux Kontext)
|
||||
repeated string ref_images = 12;
|
||||
}
|
||||
|
||||
message GenerateVideoRequest {
|
||||
@@ -376,3 +380,20 @@ message Message {
|
||||
string role = 1;
|
||||
string content = 2;
|
||||
}
|
||||
|
||||
message DetectOptions {
|
||||
string src = 1;
|
||||
}
|
||||
|
||||
message Detection {
|
||||
float x = 1;
|
||||
float y = 2;
|
||||
float width = 3;
|
||||
float height = 4;
|
||||
float confidence = 5;
|
||||
string class_name = 6;
|
||||
}
|
||||
|
||||
message DetectResponse {
|
||||
repeated Detection Detections = 1;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=3f4fc97f1d745f1d5d3c853949503136d419e6de
|
||||
LLAMA_VERSION?=be48528b068111304e4a0bb82c028558b5705f05
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
@@ -26,7 +26,7 @@ else ifeq ($(BUILD_TYPE),openblas)
|
||||
# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
||||
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
|
||||
@@ -313,9 +313,11 @@ static void params_parse(const backend::ModelOptions* request,
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
}
|
||||
|
||||
|
||||
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
|
||||
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
|
||||
else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
|
||||
|
||||
if ( request->yarnextfactor() != 0.0f ) {
|
||||
params.yarn_ext_factor = request->yarnextfactor();
|
||||
}
|
||||
|
||||
@@ -19,10 +19,10 @@ LD_FLAGS?=
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=eed97a5e1d054f9c1e7ac01982ae480411d4157e
|
||||
STABLEDIFFUSION_GGML_VERSION?=5900ef6605c6fbf7934239f795c13c97bc993853
|
||||
|
||||
# Disable Shared libs as we are linking on static gRPC and we can't mix shared and static
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DGGML_MAX_NAME=128 -DSD_USE_SYSTEM_GGML=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#define GGML_MAX_NAME 128
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
@@ -5,6 +7,7 @@
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include "gosd.h"
|
||||
|
||||
// #include "preprocessing.hpp"
|
||||
@@ -85,7 +88,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(char *model, char* options[], int threads, int diff) {
|
||||
int load_model(char *model, char *model_path, char* options[], int threads, int diff) {
|
||||
fprintf (stderr, "Loading model!\n");
|
||||
|
||||
sd_set_log_callback(sd_log_cb, NULL);
|
||||
@@ -103,6 +106,8 @@ int load_model(char *model, char* options[], int threads, int diff) {
|
||||
char *vae_path = "";
|
||||
char *scheduler = "";
|
||||
char *sampler = "";
|
||||
char *lora_dir = model_path;
|
||||
bool lora_dir_allocated = false;
|
||||
|
||||
fprintf(stderr, "parsing options\n");
|
||||
|
||||
@@ -132,6 +137,20 @@ int load_model(char *model, char* options[], int threads, int diff) {
|
||||
if (!strcmp(optname, "sampler")) {
|
||||
sampler = optval;
|
||||
}
|
||||
if (!strcmp(optname, "lora_dir")) {
|
||||
// Path join with model dir
|
||||
if (model_path && strlen(model_path) > 0) {
|
||||
std::filesystem::path model_path_str(model_path);
|
||||
std::filesystem::path lora_path(optval);
|
||||
std::filesystem::path full_lora_path = model_path_str / lora_path;
|
||||
lora_dir = strdup(full_lora_path.string().c_str());
|
||||
lora_dir_allocated = true;
|
||||
fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir);
|
||||
} else {
|
||||
lora_dir = optval;
|
||||
fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "parsed options\n");
|
||||
@@ -176,7 +195,7 @@ int load_model(char *model, char* options[], int threads, int diff) {
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.taesd_path = "";
|
||||
ctx_params.control_net_path = "";
|
||||
ctx_params.lora_model_dir = "";
|
||||
ctx_params.lora_model_dir = lora_dir;
|
||||
ctx_params.embedding_dir = "";
|
||||
ctx_params.stacked_id_embed_dir = "";
|
||||
ctx_params.vae_decode_only = false;
|
||||
@@ -189,16 +208,25 @@ int load_model(char *model, char* options[], int threads, int diff) {
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
fprintf (stderr, "failed loading model (generic error)\n");
|
||||
// Clean up allocated memory
|
||||
if (lora_dir_allocated && lora_dir) {
|
||||
free(lora_dir);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
fprintf (stderr, "Created context: OK\n");
|
||||
|
||||
sd_c = sd_ctx;
|
||||
|
||||
// Clean up allocated memory
|
||||
if (lora_dir_allocated && lora_dir) {
|
||||
free(lora_dir);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
|
||||
|
||||
sd_image_t* results;
|
||||
|
||||
@@ -221,15 +249,187 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
||||
p.seed = seed;
|
||||
p.input_id_images_path = "";
|
||||
|
||||
// Handle input image for img2img
|
||||
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
|
||||
bool has_mask_image = (mask_image != NULL && strlen(mask_image) > 0);
|
||||
|
||||
uint8_t* input_image_buffer = NULL;
|
||||
uint8_t* mask_image_buffer = NULL;
|
||||
std::vector<uint8_t> default_mask_image_vec;
|
||||
|
||||
if (has_input_image) {
|
||||
fprintf(stderr, "Loading input image: %s\n", src_image);
|
||||
|
||||
int c = 0;
|
||||
int img_width = 0;
|
||||
int img_height = 0;
|
||||
input_image_buffer = stbi_load(src_image, &img_width, &img_height, &c, 3);
|
||||
if (input_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load input image from '%s'\n", src_image);
|
||||
return 1;
|
||||
}
|
||||
if (c < 3) {
|
||||
fprintf(stderr, "Input image must have at least 3 channels, got %d\n", c);
|
||||
free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Resize input image if dimensions don't match
|
||||
if (img_width != width || img_height != height) {
|
||||
fprintf(stderr, "Resizing input image from %dx%d to %dx%d\n", img_width, img_height, width, height);
|
||||
|
||||
uint8_t* resized_image_buffer = (uint8_t*)malloc(height * width * 3);
|
||||
if (resized_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized image\n");
|
||||
free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
stbir_resize(input_image_buffer, img_width, img_height, 0,
|
||||
resized_image_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
3, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(input_image_buffer);
|
||||
input_image_buffer = resized_image_buffer;
|
||||
}
|
||||
|
||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
|
||||
p.strength = strength;
|
||||
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
|
||||
} else {
|
||||
// No input image, use empty image for text-to-image
|
||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
|
||||
p.strength = 0.0f;
|
||||
}
|
||||
|
||||
// Handle mask image for inpainting
|
||||
if (has_mask_image) {
|
||||
fprintf(stderr, "Loading mask image: %s\n", mask_image);
|
||||
|
||||
int c = 0;
|
||||
int mask_width = 0;
|
||||
int mask_height = 0;
|
||||
mask_image_buffer = stbi_load(mask_image, &mask_width, &mask_height, &c, 1);
|
||||
if (mask_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load mask image from '%s'\n", mask_image);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Resize mask if dimensions don't match
|
||||
if (mask_width != width || mask_height != height) {
|
||||
fprintf(stderr, "Resizing mask image from %dx%d to %dx%d\n", mask_width, mask_height, width, height);
|
||||
|
||||
uint8_t* resized_mask_buffer = (uint8_t*)malloc(height * width);
|
||||
if (resized_mask_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized mask\n");
|
||||
free(mask_image_buffer);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
stbir_resize(mask_image_buffer, mask_width, mask_height, 0,
|
||||
resized_mask_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
1, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(mask_image_buffer);
|
||||
mask_image_buffer = resized_mask_buffer;
|
||||
}
|
||||
|
||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
|
||||
fprintf(stderr, "Using inpainting with mask\n");
|
||||
} else {
|
||||
// No mask image, create default full mask
|
||||
default_mask_image_vec.resize(width * height, 255);
|
||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
|
||||
}
|
||||
|
||||
// Handle reference images
|
||||
std::vector<sd_image_t> ref_images_vec;
|
||||
std::vector<uint8_t*> ref_image_buffers;
|
||||
|
||||
if (ref_images_count > 0 && ref_images != NULL) {
|
||||
fprintf(stderr, "Loading %d reference images\n", ref_images_count);
|
||||
|
||||
for (int i = 0; i < ref_images_count; i++) {
|
||||
if (ref_images[i] == NULL || strlen(ref_images[i]) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Loading reference image %d: %s\n", i + 1, ref_images[i]);
|
||||
|
||||
int c = 0;
|
||||
int ref_width = 0;
|
||||
int ref_height = 0;
|
||||
uint8_t* ref_image_buffer = stbi_load(ref_images[i], &ref_width, &ref_height, &c, 3);
|
||||
if (ref_image_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to load reference image from '%s'\n", ref_images[i]);
|
||||
continue;
|
||||
}
|
||||
if (c < 3) {
|
||||
fprintf(stderr, "Reference image must have at least 3 channels, got %d\n", c);
|
||||
free(ref_image_buffer);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resize reference image if dimensions don't match
|
||||
if (ref_width != width || ref_height != height) {
|
||||
fprintf(stderr, "Resizing reference image from %dx%d to %dx%d\n", ref_width, ref_height, width, height);
|
||||
|
||||
uint8_t* resized_ref_buffer = (uint8_t*)malloc(height * width * 3);
|
||||
if (resized_ref_buffer == NULL) {
|
||||
fprintf(stderr, "Failed to allocate memory for resized reference image\n");
|
||||
free(ref_image_buffer);
|
||||
continue;
|
||||
}
|
||||
|
||||
stbir_resize(ref_image_buffer, ref_width, ref_height, 0,
|
||||
resized_ref_buffer, width, height, 0, STBIR_TYPE_UINT8,
|
||||
3, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
|
||||
free(ref_image_buffer);
|
||||
ref_image_buffer = resized_ref_buffer;
|
||||
}
|
||||
|
||||
ref_image_buffers.push_back(ref_image_buffer);
|
||||
ref_images_vec.push_back({(uint32_t)width, (uint32_t)height, 3, ref_image_buffer});
|
||||
}
|
||||
|
||||
if (!ref_images_vec.empty()) {
|
||||
p.ref_images = ref_images_vec.data();
|
||||
p.ref_images_count = ref_images_vec.size();
|
||||
fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
|
||||
}
|
||||
}
|
||||
|
||||
results = generate_image(sd_c, &p);
|
||||
|
||||
if (results == NULL) {
|
||||
fprintf (stderr, "NO results\n");
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (results[0].data == NULL) {
|
||||
fprintf (stderr, "Results with no data\n");
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -245,11 +445,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
||||
results[0].data, 0, NULL);
|
||||
fprintf (stderr, "Saved resulting image to '%s'\n", dst);
|
||||
|
||||
// TODO: free results. Why does it crash?
|
||||
|
||||
// Clean up
|
||||
free(results[0].data);
|
||||
results[0].data = NULL;
|
||||
free(results);
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
if (mask_image_buffer) free(mask_image_buffer);
|
||||
for (auto buffer : ref_image_buffers) {
|
||||
if (buffer) free(buffer);
|
||||
}
|
||||
fprintf (stderr, "gen_image is done", dst);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -29,16 +29,21 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
|
||||
|
||||
sd.threads = int(opts.Threads)
|
||||
|
||||
modelPath := opts.ModelPath
|
||||
|
||||
modelFile := C.CString(opts.ModelFile)
|
||||
defer C.free(unsafe.Pointer(modelFile))
|
||||
|
||||
modelPathC := C.CString(modelPath)
|
||||
defer C.free(unsafe.Pointer(modelPathC))
|
||||
|
||||
var options **C.char
|
||||
// prepare the options array to pass to C
|
||||
|
||||
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
|
||||
length := C.size_t(len(opts.Options))
|
||||
options = (**C.char)(C.malloc((length + 1) * size))
|
||||
view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0:len(opts.Options) + 1:len(opts.Options) + 1]
|
||||
view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0 : len(opts.Options)+1 : len(opts.Options)+1]
|
||||
|
||||
var diffusionModel int
|
||||
|
||||
@@ -70,7 +75,7 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
|
||||
|
||||
sd.cfgScale = opts.CFGScale
|
||||
|
||||
ret := C.load_model(modelFile, options, C.int(opts.Threads), C.int(diffusionModel))
|
||||
ret := C.load_model(modelFile, modelPathC, options, C.int(opts.Threads), C.int(diffusionModel))
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("could not load model")
|
||||
}
|
||||
@@ -88,7 +93,56 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
|
||||
negative := C.CString(opts.NegativePrompt)
|
||||
defer C.free(unsafe.Pointer(negative))
|
||||
|
||||
ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale))
|
||||
// Handle source image path
|
||||
var srcImage *C.char
|
||||
if opts.Src != "" {
|
||||
srcImage = C.CString(opts.Src)
|
||||
defer C.free(unsafe.Pointer(srcImage))
|
||||
}
|
||||
|
||||
// Handle mask image path
|
||||
var maskImage *C.char
|
||||
if opts.EnableParameters != "" {
|
||||
// Parse EnableParameters for mask path if provided
|
||||
// This is a simple approach - in a real implementation you might want to parse JSON
|
||||
if strings.Contains(opts.EnableParameters, "mask:") {
|
||||
parts := strings.Split(opts.EnableParameters, "mask:")
|
||||
if len(parts) > 1 {
|
||||
maskPath := strings.TrimSpace(parts[1])
|
||||
if maskPath != "" {
|
||||
maskImage = C.CString(maskPath)
|
||||
defer C.free(unsafe.Pointer(maskImage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reference images
|
||||
var refImages **C.char
|
||||
var refImagesCount C.int
|
||||
if len(opts.RefImages) > 0 {
|
||||
refImagesCount = C.int(len(opts.RefImages))
|
||||
// Allocate array of C strings
|
||||
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
|
||||
refImages = (**C.char)(C.malloc((C.size_t(len(opts.RefImages)) + 1) * size))
|
||||
view := (*[1 << 30]*C.char)(unsafe.Pointer(refImages))[0 : len(opts.RefImages)+1 : len(opts.RefImages)+1]
|
||||
|
||||
for i, refImagePath := range opts.RefImages {
|
||||
view[i] = C.CString(refImagePath)
|
||||
defer C.free(unsafe.Pointer(view[i]))
|
||||
}
|
||||
view[len(opts.RefImages)] = nil
|
||||
}
|
||||
|
||||
// Default strength for img2img (0.75 is a good default)
|
||||
strength := C.float(0.75)
|
||||
if opts.Src != "" {
|
||||
// If we have a source image, use img2img mode
|
||||
// You could also parse strength from EnableParameters if needed
|
||||
strength = C.float(0.75)
|
||||
}
|
||||
|
||||
ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale), srcImage, strength, maskImage, refImages, refImagesCount)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("inference failed")
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int load_model(char *model, char* options[], int threads, int diffusionModel);
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale);
|
||||
int load_model(char *model, char *model_path, char* options[], int threads, int diffusionModel);
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -6,7 +6,7 @@ CMAKE_ARGS?=
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=7de8dd783f7b2eab56bff6bbc5d3369e34f0e77f
|
||||
WHISPER_CPP_VERSION?=b02242d0adb5c6c4896d59ac86d9ec9fe0d0fe33
|
||||
|
||||
export WHISPER_CMAKE_ARGS?=-DBUILD_SHARED_LIBS=OFF
|
||||
export WHISPER_DIR=$(abspath ./sources/whisper.cpp)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,6 +57,11 @@ function init() {
|
||||
# - hipblas
|
||||
# - intel
|
||||
function getBuildProfile() {
|
||||
if [ "x${BUILD_TYPE}" == "xl4t" ]; then
|
||||
echo "l4t"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# First check if we are a cublas build, and if so report the correct build profile
|
||||
if [ x"${BUILD_TYPE}" == "xcublas" ]; then
|
||||
if [ ! -z ${CUDA_MAJOR_VERSION} ]; then
|
||||
@@ -111,7 +116,7 @@ function ensureVenv() {
|
||||
# - requirements-${BUILD_TYPE}.txt
|
||||
# - requirements-${BUILD_PROFILE}.txt
|
||||
#
|
||||
# BUILD_PROFILE is a pore specific version of BUILD_TYPE, ex: cuda11 or cuda12
|
||||
# BUILD_PROFILE is a pore specific version of BUILD_TYPE, ex: cuda-11 or cuda-12
|
||||
# it can also include some options that we do not have BUILD_TYPES for, ex: intel
|
||||
#
|
||||
# NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index.
|
||||
|
||||
@@ -8,4 +8,6 @@ else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
ensureVenv
|
||||
|
||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
||||
@@ -65,6 +65,19 @@ from diffusers.schedulers import (
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
def is_float(s):
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def is_int(s):
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
||||
# Credits to https://github.com/neggles
|
||||
@@ -169,8 +182,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":")
|
||||
# if value is a number, convert it to the appropriate type
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
self.options[key] = value
|
||||
|
||||
# From options, extract if present "torch_dtype" and set it to the appropriate type
|
||||
if "torch_dtype" in self.options:
|
||||
if self.options["torch_dtype"] == "fp16":
|
||||
torchType = torch.float16
|
||||
elif self.options["torch_dtype"] == "bf16":
|
||||
torchType = torch.bfloat16
|
||||
elif self.options["torch_dtype"] == "fp32":
|
||||
torchType = torch.float32
|
||||
# remove it from options
|
||||
del self.options["torch_dtype"]
|
||||
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
local = False
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
diffusers
|
||||
git+https://github.com/huggingface/diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
compel
|
||||
peft
|
||||
sentencepiece
|
||||
torch==2.4.1
|
||||
torch==2.7.1
|
||||
optimum-quanto
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.4.1+cu118
|
||||
diffusers
|
||||
torch==2.7.1+cu118
|
||||
git+https://github.com/huggingface/diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
torch==2.4.1
|
||||
diffusers
|
||||
torch==2.7.1
|
||||
git+https://github.com/huggingface/diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||
torch==2.3.1+rocm6.0
|
||||
torchvision==0.18.1+rocm6.0
|
||||
diffusers
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
torchvision==0.22.1+rocm6.3
|
||||
git+https://github.com/huggingface/diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchvision==0.18.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
torchvision==0.20.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
diffusers
|
||||
git+https://github.com/huggingface/diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
|
||||
10
backend/python/diffusers/requirements-l4t.txt
Normal file
10
backend/python/diffusers/requirements-l4t.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
|
||||
torch
|
||||
diffusers
|
||||
transformers
|
||||
accelerate
|
||||
compel
|
||||
peft
|
||||
optimum-quanto
|
||||
numpy<2
|
||||
sentencepiece
|
||||
29
backend/python/kitten-tts/Makefile
Normal file
29
backend/python/kitten-tts/Makefile
Normal file
@@ -0,0 +1,29 @@
|
||||
.PHONY: kitten-tts
|
||||
kitten-tts: protogen
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: protogen
|
||||
@echo "Running kitten-tts..."
|
||||
bash run.sh
|
||||
@echo "kitten-tts run."
|
||||
|
||||
.PHONY: test
|
||||
test: protogen
|
||||
@echo "Testing kitten-tts..."
|
||||
bash test.sh
|
||||
@echo "kitten-tts tested."
|
||||
|
||||
.PHONY: protogen
|
||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
backend_pb2_grpc.py backend_pb2.py:
|
||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
121
backend/python/kitten-tts/backend.py
Normal file
121
backend/python/kitten-tts/backend.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Kitten TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import torch
|
||||
from kittentts import KittenTTS
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
KITTEN_LANGUAGE = os.environ.get('KITTEN_LANGUAGE', None)
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
|
||||
# Get device
|
||||
# device = "cuda" if request.CUDA else "cpu"
|
||||
if torch.cuda.is_available():
|
||||
print("CUDA is available", file=sys.stderr)
|
||||
device = "cuda"
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||
|
||||
self.AudioPath = None
|
||||
# List available KittenTTS models
|
||||
print("Available KittenTTS voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f")
|
||||
if os.path.isabs(request.AudioPath):
|
||||
self.AudioPath = request.AudioPath
|
||||
elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
|
||||
# get base path of modelFile
|
||||
modelFileBase = os.path.dirname(request.ModelFile)
|
||||
# modify LoraAdapter to be relative to modelFileBase
|
||||
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
||||
|
||||
try:
|
||||
print("Preparing KittenTTS model, please wait", file=sys.stderr)
|
||||
# Use the model name from request.Model, defaulting to "KittenML/kitten-tts-nano-0.1" if not specified
|
||||
model_name = request.Model if request.Model else "KittenML/kitten-tts-nano-0.1"
|
||||
self.tts = KittenTTS(model_name)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
# KittenTTS doesn't use language parameter like TTS, so we ignore it
|
||||
# For multi-speaker models, use voice parameter
|
||||
voice = request.voice if request.voice else "expr-voice-2-f"
|
||||
|
||||
# Generate audio using KittenTTS
|
||||
audio = self.tts.generate(request.text, voice=voice)
|
||||
|
||||
# Save the audio using soundfile
|
||||
sf.write(request.dst, audio, 24000)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
19
backend/python/kitten-tts/install.sh
Executable file
19
backend/python/kitten-tts/install.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
5
backend/python/kitten-tts/requirements.txt
Normal file
5
backend/python/kitten-tts/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
https://github.com/KittenML/KittenTTS/releases/download/0.1/kittentts-0.1.0-py3-none-any.whl
|
||||
9
backend/python/kitten-tts/run.sh
Executable file
9
backend/python/kitten-tts/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
82
backend/python/kitten-tts/test.py
Normal file
82
backend/python/kitten-tts/test.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
This method sets up the gRPC service by starting the server
|
||||
"""
|
||||
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
|
||||
time.sleep(30)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
This method tears down the gRPC service by terminating the server
|
||||
"""
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
"""
|
||||
This method tests if the server starts up successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_load_model(self):
|
||||
"""
|
||||
This method tests if the model is loaded successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
|
||||
print(response)
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts(self):
|
||||
"""
|
||||
This method tests if the embeddings are generated successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
|
||||
self.assertTrue(response.success)
|
||||
tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
|
||||
tts_response = stub.TTS(tts_request)
|
||||
self.assertIsNotNone(tts_response)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTS service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
11
backend/python/kitten-tts/test.sh
Executable file
11
backend/python/kitten-tts/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -1,9 +1,18 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
.PHONY: kokoro
|
||||
kokoro: protogen
|
||||
bash install.sh
|
||||
$(MAKE) protogen
|
||||
|
||||
.PHONY: run
|
||||
run: protogen
|
||||
@echo "Running kokoro..."
|
||||
bash run.sh
|
||||
@echo "kokoro run."
|
||||
|
||||
.PHONY: test
|
||||
test: protogen
|
||||
@echo "Testing kokoro..."
|
||||
bash test.sh
|
||||
@echo "kokoro tested."
|
||||
|
||||
.PHONY: protogen
|
||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
||||
@@ -13,7 +22,7 @@ protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
backend_pb2_grpc.py backend_pb2.py:
|
||||
bash protogen.sh
|
||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
|
||||
23
backend/python/kokoro/README.md
Normal file
23
backend/python/kokoro/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Kokoro TTS Backend for LocalAI
|
||||
|
||||
This is a gRPC server backend for LocalAI that uses the Kokoro TTS pipeline.
|
||||
|
||||
## Creating a separate environment for kokoro project
|
||||
|
||||
```bash
|
||||
make kokoro
|
||||
```
|
||||
|
||||
## Testing the gRPC server
|
||||
|
||||
```bash
|
||||
make test
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- Lightweight TTS model with 82 million parameters
|
||||
- Apache-licensed weights
|
||||
- Fast and cost-efficient
|
||||
- Multi-language support
|
||||
- Multiple voice options
|
||||
115
backend/python/kokoro/backend.py
Executable file → Normal file
115
backend/python/kokoro/backend.py
Executable file → Normal file
@@ -1,101 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Extra gRPC server for Kokoro models.
|
||||
This is an extra gRPC server of LocalAI for Kokoro TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import torch
|
||||
from kokoro import KPipeline
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
|
||||
from models import build_model
|
||||
from kokoro import generate
|
||||
import torch
|
||||
|
||||
SAMPLE_RATE = 22050
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
KOKORO_LANG_CODE = os.environ.get('KOKORO_LANG_CODE', 'a')
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
A gRPC servicer for the backend service.
|
||||
|
||||
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
A gRPC method that returns the health status of the backend service.
|
||||
|
||||
Args:
|
||||
request: A HealthRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Reply object that contains the health status of the backend service.
|
||||
"""
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""
|
||||
A gRPC method that loads a model into memory.
|
||||
# Get device
|
||||
if torch.cuda.is_available():
|
||||
print("CUDA is available", file=sys.stderr)
|
||||
device = "cuda"
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
|
||||
Args:
|
||||
request: A LoadModelRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||
|
||||
Returns:
|
||||
A Result object that contains the result of the LoadModel operation.
|
||||
"""
|
||||
model_name = request.Model
|
||||
try:
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
self.MODEL = build_model(request.ModelFile, device)
|
||||
print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)
|
||||
# empty dict
|
||||
self.options = {}
|
||||
options = request.Options
|
||||
# Find the voice from the options, options are a list of strings in this form optname:optvalue:
|
||||
VOICE_NAME = None
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when
|
||||
# generating the images
|
||||
for opt in options:
|
||||
if opt.startswith("voice:"):
|
||||
VOICE_NAME = opt.split(":")[1]
|
||||
break
|
||||
if VOICE_NAME is None:
|
||||
return backend_pb2.Result(success=False, message=f"No voice specified in options")
|
||||
MODELPATH = request.ModelPath
|
||||
# If voice name contains a plus, split it and load the two models and combine them
|
||||
if "+" in VOICE_NAME:
|
||||
voice1, voice2 = VOICE_NAME.split("+")
|
||||
voice1 = torch.load(f'{MODELPATH}/{voice1}.pt', weights_only=True).to(device)
|
||||
voice2 = torch.load(f'{MODELPATH}/{voice2}.pt', weights_only=True).to(device)
|
||||
self.VOICEPACK = torch.mean(torch.stack([voice1, voice2]), dim=0)
|
||||
else:
|
||||
self.VOICEPACK = torch.load(f'{MODELPATH}/{VOICE_NAME}.pt', weights_only=True).to(device)
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":")
|
||||
self.options[key] = value
|
||||
|
||||
self.VOICE_NAME = VOICE_NAME
|
||||
|
||||
print(f'Loaded voice: {VOICE_NAME}')
|
||||
# Initialize Kokoro pipeline with language code
|
||||
lang_code = self.options.get("lang_code", KOKORO_LANG_CODE)
|
||||
self.pipeline = KPipeline(lang_code=lang_code)
|
||||
print(f"Kokoro TTS pipeline loaded with language code: {lang_code}", file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
return backend_pb2.Result(message="Kokoro TTS pipeline loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
model_name = request.model
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
try:
|
||||
audio, out_ps = generate(self.MODEL, request.text, self.VOICEPACK, lang=self.VOICE_NAME)
|
||||
print(out_ps)
|
||||
sf.write(request.dst, audio, SAMPLE_RATE)
|
||||
# Get voice from request, default to 'af_heart' if not specified
|
||||
voice = request.voice if request.voice else 'af_heart'
|
||||
|
||||
# Generate audio using Kokoro pipeline
|
||||
generator = self.pipeline(request.text, voice=voice)
|
||||
|
||||
# Get the first (and typically only) audio segment
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
# Save audio to the destination file
|
||||
sf.write(request.dst, audio, 24000)
|
||||
print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
|
||||
# For now, we only process the first segment
|
||||
# If you need to handle multiple segments, you might want to modify this
|
||||
break
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
@@ -108,11 +99,11 @@ def serve(address):
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("[Kokoro] Server started. Listening on: " + address, file=sys.stderr)
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("[Kokoro] Received termination signal. Shutting down...")
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
@@ -132,5 +123,5 @@ if __name__ == "__main__":
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(f"[Kokoro] startup: {args}", file=sys.stderr)
|
||||
|
||||
serve(args.addr)
|
||||
|
||||
@@ -1,524 +0,0 @@
|
||||
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/istftnet.py
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
||||
from scipy.signal import get_window
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
class AdaIN1d(nn.Module):
|
||||
def __init__(self, style_dim, num_features):
|
||||
super().__init__()
|
||||
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
||||
self.fc = nn.Linear(style_dim, num_features*2)
|
||||
|
||||
def forward(self, x, s):
|
||||
h = self.fc(s)
|
||||
h = h.view(h.size(0), h.size(1), 1)
|
||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||
return (1 + gamma) * self.norm(x) + beta
|
||||
|
||||
class AdaINResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
||||
super(AdaINResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.adain1 = nn.ModuleList([
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
])
|
||||
|
||||
self.adain2 = nn.ModuleList([
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
])
|
||||
|
||||
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
||||
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
||||
|
||||
|
||||
def forward(self, x, s):
|
||||
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
||||
xt = n1(x, s)
|
||||
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
||||
xt = c1(xt)
|
||||
xt = n2(xt, s)
|
||||
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
class TorchSTFT(torch.nn.Module):
|
||||
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
||||
super().__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
|
||||
|
||||
def transform(self, input_data):
|
||||
forward_transform = torch.stft(
|
||||
input_data,
|
||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
||||
return_complex=True)
|
||||
|
||||
return torch.abs(forward_transform), torch.angle(forward_transform)
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
inverse_transform = torch.istft(
|
||||
magnitude * torch.exp(phase * 1j),
|
||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
||||
|
||||
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
||||
|
||||
def forward(self, input_data):
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
||||
|
||||
class SineGen(torch.nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0,
|
||||
flag_for_pulse=False):
|
||||
super(SineGen, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.dim = self.harmonic_num + 1
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
self.flag_for_pulse = flag_for_pulse
|
||||
self.upsample_scale = upsample_scale
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||
return uv
|
||||
|
||||
def _f02sine(self, f0_values):
|
||||
""" f0_values: (batchsize, length, dim)
|
||||
where dim indicates fundamental tone and overtones
|
||||
"""
|
||||
# convert to F0 in rad. The integer part n can be ignored
|
||||
# because 2 * np.pi * n doesn't affect phase
|
||||
rad_values = (f0_values / self.sampling_rate) % 1
|
||||
|
||||
# initial phase noise (no noise for fundamental component)
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
|
||||
device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
|
||||
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||
if not self.flag_for_pulse:
|
||||
# # for normal case
|
||||
|
||||
# # To prevent torch.cumsum numerical overflow,
|
||||
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
||||
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
||||
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
||||
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
||||
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
||||
# cumsum_shift = torch.zeros_like(rad_values)
|
||||
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||
|
||||
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||
scale_factor=1/self.upsample_scale,
|
||||
mode="linear").transpose(1, 2)
|
||||
|
||||
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
||||
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
||||
# cumsum_shift = torch.zeros_like(rad_values)
|
||||
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||
|
||||
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||
sines = torch.sin(phase)
|
||||
|
||||
else:
|
||||
# If necessary, make sure that the first time step of every
|
||||
# voiced segments is sin(pi) or cos(0)
|
||||
# This is used for pulse-train generation
|
||||
|
||||
# identify the last time step in unvoiced segments
|
||||
uv = self._f02uv(f0_values)
|
||||
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||
uv_1[:, -1, :] = 1
|
||||
u_loc = (uv < 1) * (uv_1 > 0)
|
||||
|
||||
# get the instantanouse phase
|
||||
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||
# different batch needs to be processed differently
|
||||
for idx in range(f0_values.shape[0]):
|
||||
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||
# stores the accumulation of i.phase within
|
||||
# each voiced segments
|
||||
tmp_cumsum[idx, :, :] = 0
|
||||
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||
|
||||
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||
# within the previous voiced segment.
|
||||
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||
|
||||
# get the sines
|
||||
sines = torch.cos(i_phase * 2 * np.pi)
|
||||
return sines
|
||||
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
|
||||
device=f0.device)
|
||||
# fundamental component
|
||||
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||
|
||||
# generate sine waveforms
|
||||
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||
|
||||
# generate uv signal
|
||||
# uv = torch.ones(f0.shape)
|
||||
# uv = uv * (f0 > self.voiced_threshold)
|
||||
uv = self._f02uv(f0)
|
||||
|
||||
# noise: for unvoiced should be similar to sine_amp
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# . for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
"""
|
||||
# source for harmonic branch
|
||||
with torch.no_grad():
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
|
||||
def padDiff(x):
|
||||
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
resblock = AdaINResBlock1
|
||||
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=24000,
|
||||
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
||||
harmonic_num=8, voiced_threshod=10)
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.noise_res = nn.ModuleList()
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(weight_norm(
|
||||
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(k-u)//2)))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel//(2**(i+1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d, style_dim))
|
||||
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
|
||||
if i + 1 < len(upsample_rates): #
|
||||
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||
self.noise_convs.append(Conv1d(
|
||||
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
||||
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
|
||||
else:
|
||||
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
||||
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
|
||||
|
||||
|
||||
self.post_n_fft = gen_istft_n_fft
|
||||
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
||||
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||
|
||||
|
||||
def forward(self, x, s, f0):
|
||||
with torch.no_grad():
|
||||
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
|
||||
har_source, noi_source, uv = self.m_source(f0)
|
||||
har_source = har_source.transpose(1, 2).squeeze(1)
|
||||
har_spec, har_phase = self.stft.transform(har_source)
|
||||
har = torch.cat([har_spec, har_phase], dim=1)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x_source = self.noise_convs[i](har)
|
||||
x_source = self.noise_res[i](x_source, s)
|
||||
|
||||
x = self.ups[i](x)
|
||||
if i == self.num_upsamples - 1:
|
||||
x = self.reflection_pad(x)
|
||||
|
||||
x = x + x_source
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
||||
else:
|
||||
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||
return self.stft.inverse(spec, phase)
|
||||
|
||||
def fw_phase(self, x, s):
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
||||
else:
|
||||
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.reflection_pad(x)
|
||||
x = self.conv_post(x)
|
||||
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||
return spec, phase
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class AdainResBlk1d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
||||
upsample='none', dropout_p=0.0):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.upsample_type = upsample
|
||||
self.upsample = UpSample1d(upsample)
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self._build_weights(dim_in, dim_out, style_dim)
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
if upsample == 'none':
|
||||
self.pool = nn.Identity()
|
||||
else:
|
||||
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||
|
||||
|
||||
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||
self.norm1 = AdaIN1d(style_dim, dim_in)
|
||||
self.norm2 = AdaIN1d(style_dim, dim_out)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
||||
|
||||
def _shortcut(self, x):
|
||||
x = self.upsample(x)
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
return x
|
||||
|
||||
def _residual(self, x, s):
|
||||
x = self.norm1(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv1(self.dropout(x))
|
||||
x = self.norm2(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(self.dropout(x))
|
||||
return x
|
||||
|
||||
def forward(self, x, s):
|
||||
out = self._residual(x, s)
|
||||
out = (out + self._shortcut(x)) / np.sqrt(2)
|
||||
return out
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, layer_type):
|
||||
super().__init__()
|
||||
self.layer_type = layer_type
|
||||
|
||||
def forward(self, x):
|
||||
if self.layer_type == 'none':
|
||||
return x
|
||||
else:
|
||||
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
|
||||
resblock_kernel_sizes = [3,7,11],
|
||||
upsample_rates = [10, 6],
|
||||
upsample_initial_channel=512,
|
||||
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
|
||||
upsample_kernel_sizes=[20, 12],
|
||||
gen_istft_n_fft=20, gen_istft_hop_size=5):
|
||||
super().__init__()
|
||||
|
||||
self.decode = nn.ModuleList()
|
||||
|
||||
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
||||
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
||||
|
||||
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||
|
||||
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||
|
||||
self.asr_res = nn.Sequential(
|
||||
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
||||
)
|
||||
|
||||
|
||||
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
||||
upsample_initial_channel, resblock_dilation_sizes,
|
||||
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
||||
|
||||
def forward(self, asr, F0_curve, N, s):
|
||||
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
||||
N = self.N_conv(N.unsqueeze(1))
|
||||
|
||||
x = torch.cat([asr, F0, N], axis=1)
|
||||
x = self.encode(x, s)
|
||||
|
||||
asr_res = self.asr_res(asr)
|
||||
|
||||
res = True
|
||||
for block in self.decode:
|
||||
if res:
|
||||
x = torch.cat([x, asr_res, F0, N], axis=1)
|
||||
x = block(x, s)
|
||||
if block.upsample_type != "none":
|
||||
res = False
|
||||
|
||||
x = self.generator(x, s, F0_curve)
|
||||
return x
|
||||
@@ -1,166 +0,0 @@
|
||||
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/kokoro.py
|
||||
import phonemizer
|
||||
import re
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def split_num(num):
|
||||
num = num.group()
|
||||
if '.' in num:
|
||||
return num
|
||||
elif ':' in num:
|
||||
h, m = [int(n) for n in num.split(':')]
|
||||
if m == 0:
|
||||
return f"{h} o'clock"
|
||||
elif m < 10:
|
||||
return f'{h} oh {m}'
|
||||
return f'{h} {m}'
|
||||
year = int(num[:4])
|
||||
if year < 1100 or year % 1000 < 10:
|
||||
return num
|
||||
left, right = num[:2], int(num[2:4])
|
||||
s = 's' if num.endswith('s') else ''
|
||||
if 100 <= year % 1000 <= 999:
|
||||
if right == 0:
|
||||
return f'{left} hundred{s}'
|
||||
elif right < 10:
|
||||
return f'{left} oh {right}{s}'
|
||||
return f'{left} {right}{s}'
|
||||
|
||||
def flip_money(m):
|
||||
m = m.group()
|
||||
bill = 'dollar' if m[0] == '$' else 'pound'
|
||||
if m[-1].isalpha():
|
||||
return f'{m[1:]} {bill}s'
|
||||
elif '.' not in m:
|
||||
s = '' if m[1:] == '1' else 's'
|
||||
return f'{m[1:]} {bill}{s}'
|
||||
b, c = m[1:].split('.')
|
||||
s = '' if b == '1' else 's'
|
||||
c = int(c.ljust(2, '0'))
|
||||
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
|
||||
return f'{b} {bill}{s} and {c} {coins}'
|
||||
|
||||
def point_num(num):
|
||||
a, b = num.group().split('.')
|
||||
return ' point '.join([a, ' '.join(b)])
|
||||
|
||||
def normalize_text(text):
|
||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||
text = text.replace('«', chr(8220)).replace('»', chr(8221))
|
||||
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
||||
text = text.replace('(', '«').replace(')', '»')
|
||||
for a, b in zip('、。!,:;?', ',.!,:;?'):
|
||||
text = text.replace(a, b+' ')
|
||||
text = re.sub(r'[^\S \n]', ' ', text)
|
||||
text = re.sub(r' +', ' ', text)
|
||||
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
|
||||
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
|
||||
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
|
||||
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
|
||||
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
|
||||
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
|
||||
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
|
||||
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
|
||||
text = re.sub(r'(?<=\d),(?=\d)', '', text)
|
||||
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
|
||||
text = re.sub(r'\d*\.\d+', point_num, text)
|
||||
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
|
||||
text = re.sub(r'(?<=\d)S', ' S', text)
|
||||
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
||||
text = re.sub(r"(?<=X')S\b", 's', text)
|
||||
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
|
||||
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
|
||||
return text.strip()
|
||||
|
||||
def get_vocab():
|
||||
_pad = "$"
|
||||
_punctuation = ';:,.!?¡¿—…"«»“” '
|
||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
||||
dicts = {}
|
||||
for i in range(len((symbols))):
|
||||
dicts[symbols[i]] = i
|
||||
return dicts
|
||||
|
||||
VOCAB = get_vocab()
|
||||
def tokenize(ps):
|
||||
return [i for i in map(VOCAB.get, ps) if i is not None]
|
||||
|
||||
phonemizers = dict(
|
||||
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
|
||||
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
|
||||
)
|
||||
def phonemize(text, lang, norm=True):
|
||||
if norm:
|
||||
text = normalize_text(text)
|
||||
ps = phonemizers[lang].phonemize([text])
|
||||
ps = ps[0] if ps else ''
|
||||
# https://en.wiktionary.org/wiki/kokoro#English
|
||||
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
||||
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
|
||||
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
|
||||
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
|
||||
if lang == 'a':
|
||||
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
|
||||
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
||||
return ps.strip()
|
||||
|
||||
def length_to_mask(lengths):
|
||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||
return mask
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(model, tokens, ref_s, speed):
|
||||
device = ref_s.device
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
s = ref_s[:, 128:]
|
||||
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
||||
c_frame += pred_dur[0,i].item()
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
||||
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||
|
||||
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
|
||||
ps = ps or phonemize(text, lang)
|
||||
tokens = tokenize(ps)
|
||||
if not tokens:
|
||||
return None
|
||||
elif len(tokens) > 510:
|
||||
tokens = tokens[:510]
|
||||
print('Truncated to 510 tokens')
|
||||
ref_s = voicepack[len(tokens)]
|
||||
out = forward(model, tokens, ref_s, speed)
|
||||
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
||||
return out, ps
|
||||
|
||||
def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
|
||||
ps = ps or phonemize(text, lang)
|
||||
tokens = tokenize(ps)
|
||||
if not tokens:
|
||||
return None
|
||||
outs = []
|
||||
loop_count = len(tokens)//510 + (1 if len(tokens) % 510 != 0 else 0)
|
||||
for i in range(loop_count):
|
||||
ref_s = voicepack[len(tokens[i*510:(i+1)*510])]
|
||||
out = forward(model, tokens[i*510:(i+1)*510], ref_s, speed)
|
||||
outs.append(out)
|
||||
outs = np.concatenate(outs)
|
||||
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
||||
return outs, ps
|
||||
@@ -1,373 +0,0 @@
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/models.py
|
||||
from istftnet import AdaIN1d, Decoder
|
||||
from munch import Munch
|
||||
from pathlib import Path
|
||||
from plbert import load_plbert
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LinearNorm(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(LinearNorm, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear_layer.weight,
|
||||
gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_layer(x)
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(channels))
|
||||
self.beta = nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, -1)
|
||||
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||
return x.transpose(1, -1)
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(n_symbols, channels)
|
||||
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.cnn = nn.ModuleList()
|
||||
for _ in range(depth):
|
||||
self.cnn.append(nn.Sequential(
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
||||
LayerNorm(channels),
|
||||
actv,
|
||||
nn.Dropout(0.2),
|
||||
))
|
||||
# self.cnn = nn.Sequential(*self.cnn)
|
||||
|
||||
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, x, input_lengths, m):
|
||||
x = self.embedding(x) # [B, T, emb]
|
||||
x = x.transpose(1, 2) # [B, emb, T]
|
||||
m = m.to(input_lengths.device).unsqueeze(1)
|
||||
x.masked_fill_(m, 0.0)
|
||||
|
||||
for c in self.cnn:
|
||||
x = c(x)
|
||||
x.masked_fill_(m, 0.0)
|
||||
|
||||
x = x.transpose(1, 2) # [B, T, chn]
|
||||
|
||||
input_lengths = input_lengths.cpu().numpy()
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x, input_lengths, batch_first=True, enforce_sorted=False)
|
||||
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
x, batch_first=True)
|
||||
|
||||
x = x.transpose(-1, -2)
|
||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
||||
|
||||
x_pad[:, :, :x.shape[-1]] = x
|
||||
x = x_pad.to(x.device)
|
||||
|
||||
x.masked_fill_(m, 0.0)
|
||||
|
||||
return x
|
||||
|
||||
def inference(self, x):
|
||||
x = self.embedding(x)
|
||||
x = x.transpose(1, 2)
|
||||
x = self.cnn(x)
|
||||
x = x.transpose(1, 2)
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
def length_to_mask(self, lengths):
|
||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||
return mask
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, layer_type):
|
||||
super().__init__()
|
||||
self.layer_type = layer_type
|
||||
|
||||
def forward(self, x):
|
||||
if self.layer_type == 'none':
|
||||
return x
|
||||
else:
|
||||
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
|
||||
class AdainResBlk1d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
||||
upsample='none', dropout_p=0.0):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.upsample_type = upsample
|
||||
self.upsample = UpSample1d(upsample)
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self._build_weights(dim_in, dim_out, style_dim)
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
if upsample == 'none':
|
||||
self.pool = nn.Identity()
|
||||
else:
|
||||
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||
|
||||
|
||||
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||
self.norm1 = AdaIN1d(style_dim, dim_in)
|
||||
self.norm2 = AdaIN1d(style_dim, dim_out)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
||||
|
||||
def _shortcut(self, x):
|
||||
x = self.upsample(x)
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
return x
|
||||
|
||||
def _residual(self, x, s):
|
||||
x = self.norm1(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv1(self.dropout(x))
|
||||
x = self.norm2(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(self.dropout(x))
|
||||
return x
|
||||
|
||||
def forward(self, x, s):
|
||||
out = self._residual(x, s)
|
||||
out = (out + self._shortcut(x)) / np.sqrt(2)
|
||||
return out
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, style_dim, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.fc = nn.Linear(style_dim, channels*2)
|
||||
|
||||
def forward(self, x, s):
|
||||
x = x.transpose(-1, -2)
|
||||
x = x.transpose(1, -1)
|
||||
|
||||
h = self.fc(s)
|
||||
h = h.view(h.size(0), h.size(1), 1)
|
||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
||||
|
||||
|
||||
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
||||
x = (1 + gamma) * x + beta
|
||||
return x.transpose(1, -1).transpose(-1, -2)
|
||||
|
||||
class ProsodyPredictor(nn.Module):
|
||||
|
||||
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.text_encoder = DurationEncoder(sty_dim=style_dim,
|
||||
d_model=d_hid,
|
||||
nlayers=nlayers,
|
||||
dropout=dropout)
|
||||
|
||||
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||
self.duration_proj = LinearNorm(d_hid, max_dur)
|
||||
|
||||
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||
self.F0 = nn.ModuleList()
|
||||
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||
|
||||
self.N = nn.ModuleList()
|
||||
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||
|
||||
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||
|
||||
|
||||
def forward(self, texts, style, text_lengths, alignment, m):
|
||||
d = self.text_encoder(texts, style, text_lengths, m)
|
||||
|
||||
batch_size = d.shape[0]
|
||||
text_size = d.shape[1]
|
||||
|
||||
# predict duration
|
||||
input_lengths = text_lengths.cpu().numpy()
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
d, input_lengths, batch_first=True, enforce_sorted=False)
|
||||
|
||||
m = m.to(text_lengths.device).unsqueeze(1)
|
||||
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
x, batch_first=True)
|
||||
|
||||
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
||||
|
||||
x_pad[:, :x.shape[1], :] = x
|
||||
x = x_pad.to(x.device)
|
||||
|
||||
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
|
||||
|
||||
en = (d.transpose(-1, -2) @ alignment)
|
||||
|
||||
return duration.squeeze(-1), en
|
||||
|
||||
def F0Ntrain(self, x, s):
|
||||
x, _ = self.shared(x.transpose(-1, -2))
|
||||
|
||||
F0 = x.transpose(-1, -2)
|
||||
for block in self.F0:
|
||||
F0 = block(F0, s)
|
||||
F0 = self.F0_proj(F0)
|
||||
|
||||
N = x.transpose(-1, -2)
|
||||
for block in self.N:
|
||||
N = block(N, s)
|
||||
N = self.N_proj(N)
|
||||
|
||||
return F0.squeeze(1), N.squeeze(1)
|
||||
|
||||
def length_to_mask(self, lengths):
|
||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||
return mask
|
||||
|
||||
class DurationEncoder(nn.Module):
|
||||
|
||||
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
||||
super().__init__()
|
||||
self.lstms = nn.ModuleList()
|
||||
for _ in range(nlayers):
|
||||
self.lstms.append(nn.LSTM(d_model + sty_dim,
|
||||
d_model // 2,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bidirectional=True,
|
||||
dropout=dropout))
|
||||
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
||||
|
||||
|
||||
self.dropout = dropout
|
||||
self.d_model = d_model
|
||||
self.sty_dim = sty_dim
|
||||
|
||||
def forward(self, x, style, text_lengths, m):
|
||||
masks = m.to(text_lengths.device)
|
||||
|
||||
x = x.permute(2, 0, 1)
|
||||
s = style.expand(x.shape[0], x.shape[1], -1)
|
||||
x = torch.cat([x, s], axis=-1)
|
||||
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
||||
|
||||
x = x.transpose(0, 1)
|
||||
input_lengths = text_lengths.cpu().numpy()
|
||||
x = x.transpose(-1, -2)
|
||||
|
||||
for block in self.lstms:
|
||||
if isinstance(block, AdaLayerNorm):
|
||||
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
||||
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
|
||||
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
||||
else:
|
||||
x = x.transpose(-1, -2)
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x, input_lengths, batch_first=True, enforce_sorted=False)
|
||||
block.flatten_parameters()
|
||||
x, _ = block(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
x, batch_first=True)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = x.transpose(-1, -2)
|
||||
|
||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
||||
|
||||
x_pad[:, :, :x.shape[-1]] = x
|
||||
x = x_pad.to(x.device)
|
||||
|
||||
return x.transpose(-1, -2)
|
||||
|
||||
def inference(self, x, style):
|
||||
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
|
||||
style = style.expand(x.shape[0], x.shape[1], -1)
|
||||
x = torch.cat([x, style], axis=-1)
|
||||
src = self.pos_encoder(x)
|
||||
output = self.transformer_encoder(src).transpose(0, 1)
|
||||
return output
|
||||
|
||||
def length_to_mask(self, lengths):
|
||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||
return mask
|
||||
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
|
||||
def recursive_munch(d):
|
||||
if isinstance(d, dict):
|
||||
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
||||
elif isinstance(d, list):
|
||||
return [recursive_munch(v) for v in d]
|
||||
else:
|
||||
return d
|
||||
|
||||
def build_model(path, device):
|
||||
config = Path(__file__).parent / 'config.json'
|
||||
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
|
||||
with open(config, 'r') as r:
|
||||
args = recursive_munch(json.load(r))
|
||||
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
|
||||
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
|
||||
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
|
||||
upsample_rates = args.decoder.upsample_rates,
|
||||
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
||||
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
||||
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
||||
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
|
||||
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
|
||||
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
|
||||
bert = load_plbert()
|
||||
bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
|
||||
for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
|
||||
for child in parent.children():
|
||||
if isinstance(child, nn.RNNBase):
|
||||
child.flatten_parameters()
|
||||
model = Munch(
|
||||
bert=bert.to(device).eval(),
|
||||
bert_encoder=bert_encoder.to(device).eval(),
|
||||
predictor=predictor.to(device).eval(),
|
||||
decoder=decoder.to(device).eval(),
|
||||
text_encoder=text_encoder.to(device).eval(),
|
||||
)
|
||||
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
|
||||
assert key in model, key
|
||||
try:
|
||||
model[key].load_state_dict(state_dict)
|
||||
except:
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
model[key].load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
@@ -1,16 +0,0 @@
|
||||
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/plbert.py
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
||||
from transformers import AlbertConfig, AlbertModel
|
||||
|
||||
class CustomAlbert(AlbertModel):
|
||||
def forward(self, *args, **kwargs):
|
||||
# Call the original forward method
|
||||
outputs = super().forward(*args, **kwargs)
|
||||
# Only return the last_hidden_state
|
||||
return outputs.last_hidden_state
|
||||
|
||||
def load_plbert():
|
||||
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
|
||||
albert_base_configuration = AlbertConfig(**plbert_config)
|
||||
bert = CustomAlbert(albert_base_configuration)
|
||||
return bert
|
||||
@@ -1,2 +1,6 @@
|
||||
torch==2.4.1
|
||||
transformers
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
transformers
|
||||
accelerate
|
||||
torch
|
||||
kokoro
|
||||
soundfile
|
||||
@@ -1,3 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.4.1+cu118
|
||||
transformers
|
||||
torch==2.7.1+cu118
|
||||
torchaudio==2.7.1+cu118
|
||||
transformers
|
||||
accelerate
|
||||
kokoro
|
||||
soundfile
|
||||
@@ -1,2 +1,6 @@
|
||||
torch==2.4.1
|
||||
transformers
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
transformers
|
||||
accelerate
|
||||
kokoro
|
||||
soundfile
|
||||
@@ -1,3 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||
torch==2.4.1+rocm6.0
|
||||
transformers
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
torchaudio==2.7.1+rocm6.3
|
||||
transformers
|
||||
accelerate
|
||||
kokoro
|
||||
soundfile
|
||||
@@ -1,5 +1,11 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
transformers
|
||||
intel-extension-for-pytorch==2.8.10+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
torchaudio==2.5.1+cxx11.abi
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
accelerate
|
||||
kokoro
|
||||
soundfile
|
||||
@@ -1,7 +1,6 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
phonemizer
|
||||
scipy
|
||||
munch
|
||||
setuptools
|
||||
soundfile
|
||||
certifi
|
||||
packaging==24.1
|
||||
pip
|
||||
chardet
|
||||
87
backend/python/kokoro/test.py
Normal file
87
backend/python/kokoro/test.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
This method sets up the gRPC service by starting the server
|
||||
"""
|
||||
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
|
||||
time.sleep(30)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
This method tears down the gRPC service by terminating the server
|
||||
"""
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
"""
|
||||
This method tests if the server starts up successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_load_model(self):
|
||||
"""
|
||||
This method tests if the Kokoro pipeline is loaded successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
|
||||
print(response)
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Kokoro TTS pipeline loaded successfully")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts(self):
|
||||
"""
|
||||
This method tests if the TTS generation works successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
|
||||
self.assertTrue(response.success)
|
||||
tts_request = backend_pb2.TTSRequest(
|
||||
text="Kokoro is an open-weight TTS model with 82 million parameters.",
|
||||
voice="af_heart",
|
||||
dst="test_output.wav"
|
||||
)
|
||||
tts_response = stub.TTS(tts_request)
|
||||
self.assertIsNotNone(tts_response)
|
||||
self.assertTrue(tts_response.success)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTS service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
20
backend/python/rfdetr/Makefile
Normal file
20
backend/python/rfdetr/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
bash install.sh
|
||||
$(MAKE) protogen
|
||||
|
||||
.PHONY: protogen
|
||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
backend_pb2_grpc.py backend_pb2.py:
|
||||
bash protogen.sh
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
174
backend/python/rfdetr/backend.py
Executable file
174
backend/python/rfdetr/backend.py
Executable file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
gRPC server for RFDETR object detection models.
|
||||
"""
|
||||
from concurrent import futures
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
|
||||
import requests
|
||||
|
||||
import supervision as sv
|
||||
from inference import get_model
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
A gRPC servicer for the RFDETR backend service.
|
||||
|
||||
This class implements the gRPC methods for object detection using RFDETR models.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.model_name = None
|
||||
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
A gRPC method that returns the health status of the backend service.
|
||||
|
||||
Args:
|
||||
request: A HealthMessage object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Reply object that contains the health status of the backend service.
|
||||
"""
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""
|
||||
A gRPC method that loads a RFDETR model into memory.
|
||||
|
||||
Args:
|
||||
request: A ModelOptions object that contains the model parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Result object that contains the result of the LoadModel operation.
|
||||
"""
|
||||
model_name = request.Model
|
||||
try:
|
||||
# Load the RFDETR model
|
||||
self.model = get_model(model_name)
|
||||
self.model_name = model_name
|
||||
print(f'Loaded RFDETR model: {model_name}')
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Failed to load model: {err}")
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def Detect(self, request, context):
|
||||
"""
|
||||
A gRPC method that performs object detection on an image.
|
||||
|
||||
Args:
|
||||
request: A DetectOptions object that contains the image source.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A DetectResponse object that contains the detection results.
|
||||
"""
|
||||
if self.model is None:
|
||||
print(f"Model is None")
|
||||
return backend_pb2.DetectResponse()
|
||||
print(f"Model is not None")
|
||||
try:
|
||||
print(f"Decoding image")
|
||||
# Decode the base64 image
|
||||
print(f"Image data: {request.src}")
|
||||
|
||||
image_data = base64.b64decode(request.src)
|
||||
image = Image.open(BytesIO(image_data))
|
||||
|
||||
# Perform inference
|
||||
predictions = self.model.infer(image, confidence=0.5)[0]
|
||||
|
||||
# Convert to proto format
|
||||
proto_detections = []
|
||||
for i in range(len(predictions.predictions)):
|
||||
pred = predictions.predictions[i]
|
||||
print(f"Prediction: {pred}")
|
||||
proto_detection = backend_pb2.Detection(
|
||||
x=float(pred.x),
|
||||
y=float(pred.y),
|
||||
width=float(pred.width),
|
||||
height=float(pred.height),
|
||||
confidence=float(pred.confidence),
|
||||
class_name=pred.class_name
|
||||
)
|
||||
proto_detections.append(proto_detection)
|
||||
|
||||
return backend_pb2.DetectResponse(Detections=proto_detections)
|
||||
except Exception as err:
|
||||
print(f"Detection error: {err}")
|
||||
return backend_pb2.DetectResponse()
|
||||
|
||||
def Status(self, request, context):
|
||||
"""
|
||||
A gRPC method that returns the status of the backend service.
|
||||
|
||||
Args:
|
||||
request: A HealthMessage object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A StatusResponse object that contains the status information.
|
||||
"""
|
||||
state = backend_pb2.StatusResponse.READY if self.model is not None else backend_pb2.StatusResponse.UNINITIALIZED
|
||||
return backend_pb2.StatusResponse(state=state)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("[RFDETR] Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("[RFDETR] Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the RFDETR gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(f"[RFDETR] startup: {args}", file=sys.stderr)
|
||||
serve(args.addr)
|
||||
|
||||
|
||||
|
||||
19
backend/python/rfdetr/install.sh
Executable file
19
backend/python/rfdetr/install.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
@@ -8,4 +8,6 @@ else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
ensureVenv
|
||||
|
||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
||||
7
backend/python/rfdetr/requirements-cpu.txt
Normal file
7
backend/python/rfdetr/requirements-cpu.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
rfdetr
|
||||
opencv-python
|
||||
accelerate
|
||||
peft
|
||||
inference
|
||||
torch==2.7.1
|
||||
optimum-quanto
|
||||
8
backend/python/rfdetr/requirements-cublas11.txt
Normal file
8
backend/python/rfdetr/requirements-cublas11.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.7.1+cu118
|
||||
rfdetr
|
||||
opencv-python
|
||||
accelerate
|
||||
inference
|
||||
peft
|
||||
optimum-quanto
|
||||
7
backend/python/rfdetr/requirements-cublas12.txt
Normal file
7
backend/python/rfdetr/requirements-cublas12.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
torch==2.7.1
|
||||
rfdetr
|
||||
opencv-python
|
||||
accelerate
|
||||
inference
|
||||
peft
|
||||
optimum-quanto
|
||||
9
backend/python/rfdetr/requirements-hipblas.txt
Normal file
9
backend/python/rfdetr/requirements-hipblas.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
torchvision==0.22.1+rocm6.3
|
||||
rfdetr
|
||||
opencv-python
|
||||
accelerate
|
||||
inference
|
||||
peft
|
||||
optimum-quanto
|
||||
13
backend/python/rfdetr/requirements-intel.txt
Normal file
13
backend/python/rfdetr/requirements-intel.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchvision==0.18.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
rfdetr
|
||||
inference
|
||||
opencv-python
|
||||
accelerate
|
||||
peft
|
||||
optimum-quanto
|
||||
3
backend/python/rfdetr/requirements.txt
Normal file
3
backend/python/rfdetr/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
9
backend/python/rfdetr/run.sh
Executable file
9
backend/python/rfdetr/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
11
backend/python/rfdetr/test.sh
Executable file
11
backend/python/rfdetr/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -22,7 +22,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
||||
from scipy.io import wavfile
|
||||
import outetts
|
||||
from sentence_transformers import SentenceTransformer
|
||||
@@ -90,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.CUDA = torch.cuda.is_available()
|
||||
self.OV=False
|
||||
self.OuteTTS=False
|
||||
self.DiaTTS=False
|
||||
self.SentenceTransformer = False
|
||||
|
||||
device_map="cpu"
|
||||
@@ -97,6 +98,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
quantization = None
|
||||
autoTokenizer = True
|
||||
|
||||
# Parse options from request.Options
|
||||
self.options = {}
|
||||
options = request.Options
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when generating
|
||||
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
# if value is a number, convert it to the appropriate type
|
||||
try:
|
||||
if "." in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
# Keep as string if conversion fails
|
||||
pass
|
||||
self.options[key] = value
|
||||
|
||||
print(f"Parsed options: {self.options}", file=sys.stderr)
|
||||
|
||||
if self.CUDA:
|
||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
||||
if request.MainGPU:
|
||||
@@ -202,6 +227,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
if self.CUDA:
|
||||
self.model = self.model.to("cuda")
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "OuteTTS":
|
||||
autoTokenizer = False
|
||||
options = request.Options
|
||||
@@ -262,7 +297,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
||||
self.max_tokens = self.model.config.max_position_embeddings
|
||||
else:
|
||||
self.max_tokens = 512
|
||||
self.max_tokens = self.options.get("max_new_tokens", 512)
|
||||
|
||||
if autoTokenizer:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
|
||||
@@ -485,16 +520,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
tokens = 256
|
||||
if request.HasField('duration'):
|
||||
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||
guidance = 3.0
|
||||
guidance = self.options.get("guidance_scale", 3.0)
|
||||
if request.HasField('temperature'):
|
||||
guidance = request.temperature
|
||||
dosample = True
|
||||
dosample = self.options.get("do_sample", True)
|
||||
if request.HasField('sample'):
|
||||
dosample = request.sample
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
||||
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
@@ -506,13 +540,59 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def OuteTTS(self, request, context):
|
||||
|
||||
def CallDiaTTS(self, request, context):
|
||||
"""
|
||||
Generates dialogue audio using the Dia model.
|
||||
|
||||
Args:
|
||||
request: A TTSRequest containing text dialogue and generation parameters
|
||||
context: The gRPC context
|
||||
|
||||
Returns:
|
||||
A Result object indicating success or failure
|
||||
"""
|
||||
try:
|
||||
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
|
||||
|
||||
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
|
||||
text = [request.text]
|
||||
|
||||
# Process the input
|
||||
inputs = self.processor(text=text, padding=True, return_tensors="pt")
|
||||
|
||||
# Generate audio with parameters from options or defaults
|
||||
generation_params = {
|
||||
**inputs,
|
||||
"max_new_tokens": self.max_tokens,
|
||||
"guidance_scale": self.options.get("guidance_scale", 3.0),
|
||||
"temperature": self.options.get("temperature", 1.8),
|
||||
"top_p": self.options.get("top_p", 0.90),
|
||||
"top_k": self.options.get("top_k", 45)
|
||||
}
|
||||
|
||||
outputs = self.model.generate(**generation_params)
|
||||
|
||||
# Decode and save audio
|
||||
outputs = self.processor.batch_decode(outputs)
|
||||
self.processor.save_audio(outputs, request.dst)
|
||||
|
||||
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
|
||||
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
|
||||
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def CallOuteTTS(self, request, context):
|
||||
try:
|
||||
print("[OuteTTS] generating TTS", file=sys.stderr)
|
||||
gen_cfg = outetts.GenerationConfig(
|
||||
text="Speech synthesis is the artificial production of human speech.",
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.1,
|
||||
temperature=self.options.get("temperature", 0.1),
|
||||
repetition_penalty=self.options.get("repetition_penalty", 1.1),
|
||||
max_length=self.max_tokens,
|
||||
speaker=self.speaker,
|
||||
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
|
||||
@@ -528,7 +608,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.OuteTTS:
|
||||
return self.OuteTTS(request, context)
|
||||
return self.CallOuteTTS(request, context)
|
||||
|
||||
if self.DiaTTS:
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
torch==2.4.1
|
||||
torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==3.4.1
|
||||
sentence-transformers==5.0.0
|
||||
protobuf==6.31.0
|
||||
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.4.1+cu118
|
||||
torch==2.7.1+cu118
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==4.1.0
|
||||
sentence-transformers==5.0.0
|
||||
protobuf==6.31.0
|
||||
@@ -1,9 +1,9 @@
|
||||
torch==2.4.1
|
||||
torch==2.7.1
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==4.1.0
|
||||
sentence-transformers==5.0.0
|
||||
protobuf==6.31.0
|
||||
@@ -1,5 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||
torch==2.4.1+rocm6.0
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
accelerate
|
||||
transformers
|
||||
llvmlite==0.43.0
|
||||
@@ -7,5 +7,5 @@ numba==0.60.0
|
||||
bitsandbytes
|
||||
outetts
|
||||
bitsandbytes
|
||||
sentence-transformers==4.1.0
|
||||
sentence-transformers==5.0.0
|
||||
protobuf==6.31.0
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
@@ -9,5 +9,5 @@ transformers
|
||||
intel-extension-for-transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==4.1.0
|
||||
sentence-transformers==5.0.0
|
||||
protobuf==6.31.0
|
||||
@@ -59,8 +59,10 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
log.Error().Err(err).Msg("error installing models")
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, options.ExternalBackends...); err != nil {
|
||||
log.Error().Err(err).Msg("error installing external backends")
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, backend, "", ""); err != nil {
|
||||
log.Error().Err(err).Msg("error installing external backend")
|
||||
}
|
||||
}
|
||||
|
||||
configLoaderOpts := options.ToConfigLoaderOptions()
|
||||
|
||||
34
core/backend/detection.go
Normal file
34
core/backend/detection.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func Detection(
|
||||
sourceFile string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig,
|
||||
) (*proto.DetectResponse, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
detectionModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer loader.Close()
|
||||
|
||||
if detectionModel == nil {
|
||||
return nil, fmt.Errorf("could not load detection model")
|
||||
}
|
||||
|
||||
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
||||
Src: sourceFile,
|
||||
})
|
||||
|
||||
return res, err
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(
|
||||
@@ -33,6 +33,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
Dst: dst,
|
||||
Src: src,
|
||||
EnableParameters: backendConfig.Diffusers.EnableParameters,
|
||||
RefImages: refImages,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -23,7 +23,9 @@ type BackendsList struct {
|
||||
}
|
||||
|
||||
type BackendsInstall struct {
|
||||
BackendArgs []string `arg:"" optional:"" name:"backends" help:"Backend configuration URLs to load"`
|
||||
BackendArgs string `arg:"" optional:"" name:"backend" help:"Backend configuration URL to load"`
|
||||
Name string `arg:"" optional:"" name:"name" help:"Name of the backend"`
|
||||
Alias string `arg:"" optional:"" name:"alias" help:"Alias of the backend"`
|
||||
|
||||
BackendsCMDFlags `embed:""`
|
||||
}
|
||||
@@ -66,27 +68,25 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
}
|
||||
|
||||
for _, backendName := range bi.BackendArgs {
|
||||
|
||||
progressBar := progressbar.NewOptions(
|
||||
1000,
|
||||
progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", backendName)),
|
||||
progressbar.OptionShowBytes(false),
|
||||
progressbar.OptionClearOnFinish(),
|
||||
)
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
v := int(percentage * 10)
|
||||
err := progressBar.Set(v)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar")
|
||||
}
|
||||
}
|
||||
|
||||
err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, backendName)
|
||||
progressBar := progressbar.NewOptions(
|
||||
1000,
|
||||
progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)),
|
||||
progressbar.OptionShowBytes(false),
|
||||
progressbar.OptionClearOnFinish(),
|
||||
)
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
v := int(percentage * 10)
|
||||
err := progressBar.Set(v)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar")
|
||||
}
|
||||
}
|
||||
|
||||
err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ type RunCMD struct {
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
|
||||
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
|
||||
ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
|
||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||
@@ -88,7 +87,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
|
||||
config.WithGeneratedContentDir(r.GeneratedContentPath),
|
||||
config.WithUploadDir(r.UploadPath),
|
||||
config.WithConfigsDir(r.ConfigPath),
|
||||
config.WithDynamicConfigDir(r.LocalaiConfigDir),
|
||||
config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval),
|
||||
config.WithF16(r.F16),
|
||||
|
||||
@@ -21,8 +21,7 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
GeneratedContentDir string
|
||||
|
||||
ConfigsDir string
|
||||
UploadDir string
|
||||
UploadDir string
|
||||
|
||||
DynamicConfigsDir string
|
||||
DynamicConfigsDirPollInterval time.Duration
|
||||
@@ -302,12 +301,6 @@ func WithUploadDir(uploadDir string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithConfigsDir(configsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ConfigsDir = configsDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
|
||||
@@ -458,6 +458,7 @@ const (
|
||||
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
|
||||
FLAG_VAD BackendConfigUsecases = 0b010000000000
|
||||
FLAG_VIDEO BackendConfigUsecases = 0b100000000000
|
||||
FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -479,6 +480,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,6 +574,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
||||
if c.Backend != "rfdetr" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
if c.Backend != "transformers-musicgen" {
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,7 @@ package gallery
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BackendMetadata represents the metadata stored in a JSON file for each installed backend
|
||||
@@ -23,6 +24,7 @@ type GalleryBackend struct {
|
||||
Metadata `json:",inline" yaml:",inline"`
|
||||
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
||||
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
||||
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
||||
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
@@ -33,9 +35,11 @@ func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.Syste
|
||||
|
||||
realBackend := backend.CapabilitiesMap[systemState.Capability(backend.CapabilitiesMap)]
|
||||
if realBackend == "" {
|
||||
log.Debug().Str("backend", backend.Name).Str("reportedCapability", systemState.Capability(backend.CapabilitiesMap)).Msg("No backend found for reported capability")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug().Str("backend", backend.Name).Str("reportedCapability", systemState.Capability(backend.CapabilitiesMap)).Msg("Found backend for reported capability")
|
||||
return backends.FindByName(realBackend)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
cp "github.com/otiai10/copy"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -145,8 +146,28 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
|
||||
}
|
||||
|
||||
uri := downloader.URI(config.URI)
|
||||
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
|
||||
// Check if it is a directory
|
||||
if uri.LooksLikeDir() {
|
||||
// It is a directory, we just copy it over in the backend folder
|
||||
if err := cp.Copy(config.URI, backendPath); err != nil {
|
||||
return fmt.Errorf("failed copying: %w", err)
|
||||
}
|
||||
} else {
|
||||
uri := downloader.URI(config.URI)
|
||||
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
success := false
|
||||
// Try to download from mirrors
|
||||
for _, mirror := range config.Mirrors {
|
||||
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !success {
|
||||
return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create metadata for the backend
|
||||
@@ -229,16 +250,22 @@ func ListSystemBackends(basePath string) (map[string]string, error) {
|
||||
for _, backend := range backends {
|
||||
if backend.IsDir() {
|
||||
runFile := filepath.Join(basePath, backend.Name(), runFile)
|
||||
// Skip if metadata file don't exist
|
||||
|
||||
var metadata *BackendMetadata
|
||||
|
||||
// If metadata file does not exist, we just use the directory name
|
||||
// and we do not fill the other metadata (such as potential backend Aliases)
|
||||
metadataFilePath := filepath.Join(basePath, backend.Name(), metadataFile)
|
||||
if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for alias in metadata
|
||||
metadata, err := readBackendMetadata(filepath.Join(basePath, backend.Name()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
metadata = &BackendMetadata{
|
||||
Name: backend.Name(),
|
||||
}
|
||||
} else {
|
||||
// Check for alias in metadata
|
||||
metadata, err = readBackendMetadata(filepath.Join(basePath, backend.Name()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if metadata == nil {
|
||||
|
||||
@@ -95,7 +95,7 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
|
||||
|
||||
if !strings.Contains(name, "@") {
|
||||
for _, m := range models {
|
||||
if strings.EqualFold(m.GetName(), name) {
|
||||
if strings.EqualFold(strings.ToLower(m.GetName()), strings.ToLower(name)) {
|
||||
model = m
|
||||
break
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
|
||||
|
||||
} else {
|
||||
for _, m := range models {
|
||||
if strings.EqualFold(name, fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName())) {
|
||||
if strings.EqualFold(strings.ToLower(name), strings.ToLower(fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName()))) {
|
||||
model = m
|
||||
break
|
||||
}
|
||||
|
||||
@@ -10,10 +10,8 @@ import (
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
@@ -199,11 +197,6 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
router.Use(csrf.New())
|
||||
}
|
||||
|
||||
// Load config jsons
|
||||
utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
||||
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
||||
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
||||
|
||||
galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader())
|
||||
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||
if err != nil {
|
||||
|
||||
@@ -34,7 +34,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, backendPath string
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.BackendOpStatus "Response"
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs/{uuid} [get]
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
@@ -48,7 +48,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.BackendOpStatus "Response"
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs [get]
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
@@ -58,7 +58,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
|
||||
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||
// @Summary Install backends to LocalAI.
|
||||
// @Param request body BackendModel true "query params"
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/apply [post]
|
||||
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
|
||||
|
||||
59
core/http/endpoints/localai/detection.go
Normal file
59
core/http/endpoints/localai/detection.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
|
||||
// @Summary Detects objects in the input image.
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
|
||||
|
||||
image, err := utils.GetContentURIAsBase64(input.Image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := backend.Detection(image, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := schema.DetectionResponse{
|
||||
Detections: make([]schema.Detection, len(res.Detections)),
|
||||
}
|
||||
for i, detection := range res.Detections {
|
||||
response.Detections[i] = schema.Detection{
|
||||
X: detection.X,
|
||||
Y: detection.Y,
|
||||
Width: detection.Width,
|
||||
Height: detection.Height,
|
||||
ClassName: detection.ClassName,
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
@@ -1,522 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ToolType defines a type for tool options
|
||||
type ToolType string
|
||||
|
||||
const (
|
||||
CodeInterpreter ToolType = "code_interpreter"
|
||||
Retrieval ToolType = "retrieval"
|
||||
Function ToolType = "function"
|
||||
|
||||
MaxCharacterInstructions = 32768
|
||||
MaxCharacterDescription = 512
|
||||
MaxCharacterName = 256
|
||||
MaxToolsSize = 128
|
||||
MaxFileIdSize = 20
|
||||
MaxCharacterMetadataKey = 64
|
||||
MaxCharacterMetadataValue = 512
|
||||
)
|
||||
|
||||
type Tool struct {
|
||||
Type ToolType `json:"type"`
|
||||
}
|
||||
|
||||
// Assistant represents the structure of an assistant object from the OpenAI API.
|
||||
type Assistant struct {
|
||||
ID string `json:"id"` // The unique identifier of the assistant.
|
||||
Object string `json:"object"` // Object type, which is "assistant".
|
||||
Created int64 `json:"created"` // The time at which the assistant was created.
|
||||
Model string `json:"model"` // The model ID used by the assistant.
|
||||
Name string `json:"name,omitempty"` // The name of the assistant.
|
||||
Description string `json:"description,omitempty"` // The description of the assistant.
|
||||
Instructions string `json:"instructions,omitempty"` // The system instructions that the assistant uses.
|
||||
Tools []Tool `json:"tools,omitempty"` // A list of tools enabled on the assistant.
|
||||
FileIDs []string `json:"file_ids,omitempty"` // A list of file IDs attached to this assistant.
|
||||
Metadata map[string]string `json:"metadata,omitempty"` // Set of key-value pairs attached to the assistant.
|
||||
}
|
||||
|
||||
var (
|
||||
Assistants = []Assistant{} // better to return empty array instead of "null"
|
||||
AssistantsConfigFile = "assistants.json"
|
||||
)
|
||||
|
||||
type AssistantRequest struct {
|
||||
Model string `json:"model"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
FileIDs []string `json:"file_ids,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// CreateAssistantEndpoint is the OpenAI Assistant API endpoint https://platform.openai.com/docs/api-reference/assistants/createAssistant
|
||||
// @Summary Create an assistant with a model and instructions.
|
||||
// @Param request body AssistantRequest true "query params"
|
||||
// @Success 200 {object} Assistant "Response"
|
||||
// @Router /v1/assistants [post]
|
||||
func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
request := new(AssistantRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
if !modelExists(cl, ml, request.Model) {
|
||||
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
|
||||
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Model %q not found", request.Model)))
|
||||
}
|
||||
|
||||
if request.Tools == nil {
|
||||
request.Tools = []Tool{}
|
||||
}
|
||||
|
||||
if request.FileIDs == nil {
|
||||
request.FileIDs = []string{}
|
||||
}
|
||||
|
||||
if request.Metadata == nil {
|
||||
request.Metadata = make(map[string]string)
|
||||
}
|
||||
|
||||
id := "asst_" + strconv.FormatInt(generateRandomID(), 10)
|
||||
|
||||
assistant := Assistant{
|
||||
ID: id,
|
||||
Object: "assistant",
|
||||
Created: time.Now().Unix(),
|
||||
Model: request.Model,
|
||||
Name: request.Name,
|
||||
Description: request.Description,
|
||||
Instructions: request.Instructions,
|
||||
Tools: request.Tools,
|
||||
FileIDs: request.FileIDs,
|
||||
Metadata: request.Metadata,
|
||||
}
|
||||
|
||||
Assistants = append(Assistants, assistant)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||
return c.Status(fiber.StatusOK).JSON(assistant)
|
||||
}
|
||||
}
|
||||
|
||||
var currentId int64 = 0
|
||||
|
||||
func generateRandomID() int64 {
|
||||
atomic.AddInt64(¤tId, 1)
|
||||
return currentId
|
||||
}
|
||||
|
||||
// ListAssistantsEndpoint is the OpenAI Assistant API endpoint to list assistents https://platform.openai.com/docs/api-reference/assistants/listAssistants
|
||||
// @Summary List available assistents
|
||||
// @Param limit query int false "Limit the number of assistants returned"
|
||||
// @Param order query string false "Order of assistants returned"
|
||||
// @Param after query string false "Return assistants created after the given ID"
|
||||
// @Param before query string false "Return assistants created before the given ID"
|
||||
// @Success 200 {object} []Assistant "Response"
|
||||
// @Router /v1/assistants [get]
|
||||
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Because we're altering the existing assistants list we should just duplicate it for now.
|
||||
returnAssistants := Assistants
|
||||
// Parse query parameters
|
||||
limitQuery := c.Query("limit", "20")
|
||||
orderQuery := c.Query("order", "desc")
|
||||
afterQuery := c.Query("after")
|
||||
beforeQuery := c.Query("before")
|
||||
|
||||
// Convert string limit to integer
|
||||
limit, err := strconv.Atoi(limitQuery)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Invalid limit query value: %s", limitQuery)))
|
||||
}
|
||||
|
||||
// Sort assistants
|
||||
sort.SliceStable(returnAssistants, func(i, j int) bool {
|
||||
if orderQuery == "asc" {
|
||||
return returnAssistants[i].Created < returnAssistants[j].Created
|
||||
}
|
||||
return returnAssistants[i].Created > returnAssistants[j].Created
|
||||
})
|
||||
|
||||
// After and before cursors
|
||||
if afterQuery != "" {
|
||||
returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery)
|
||||
}
|
||||
if beforeQuery != "" {
|
||||
returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery)
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if limit < len(returnAssistants) {
|
||||
returnAssistants = returnAssistants[:limit]
|
||||
}
|
||||
|
||||
return c.JSON(returnAssistants)
|
||||
}
|
||||
}
|
||||
|
||||
// FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID
|
||||
// We assume that the assistants are already sorted
|
||||
func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant {
|
||||
idInt, err := strconv.Atoi(id)
|
||||
if err != nil {
|
||||
return assistants // Return original slice if invalid id format is provided
|
||||
}
|
||||
|
||||
var filteredAssistants []Assistant
|
||||
|
||||
for _, assistant := range assistants {
|
||||
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||
if err != nil {
|
||||
continue // Skip if invalid id in assistant
|
||||
}
|
||||
|
||||
if aid < idInt {
|
||||
filteredAssistants = append(filteredAssistants, assistant)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredAssistants
|
||||
}
|
||||
|
||||
// FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID
|
||||
// We assume that the assistants are already sorted
|
||||
func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
|
||||
idInt, err := strconv.Atoi(id)
|
||||
if err != nil {
|
||||
return assistants // Return original slice if invalid id format is provided
|
||||
}
|
||||
|
||||
var filteredAssistants []Assistant
|
||||
|
||||
for _, assistant := range assistants {
|
||||
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||
if err != nil {
|
||||
continue // Skip if invalid id in assistant
|
||||
}
|
||||
|
||||
if aid > idInt {
|
||||
filteredAssistants = append(filteredAssistants, assistant)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredAssistants
|
||||
}
|
||||
|
||||
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
|
||||
found = false
|
||||
models, err := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == modelName {
|
||||
found = true
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteAssistantEndpoint is the OpenAI Assistant API endpoint to delete assistents https://platform.openai.com/docs/api-reference/assistants/deleteAssistant
|
||||
// @Summary Delete assistents
|
||||
// @Success 200 {object} schema.DeleteAssistantResponse "Response"
|
||||
// @Router /v1/assistants/{assistant_id} [delete]
|
||||
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
}
|
||||
|
||||
for i, assistant := range Assistants {
|
||||
if assistant.ID == assistantID {
|
||||
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantResponse{
|
||||
ID: assistantID,
|
||||
Object: "assistant.deleted",
|
||||
Deleted: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID)
|
||||
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantResponse{
|
||||
ID: assistantID,
|
||||
Object: "assistant.deleted",
|
||||
Deleted: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetAssistantEndpoint is the OpenAI Assistant API endpoint to get assistents https://platform.openai.com/docs/api-reference/assistants/getAssistant
|
||||
// @Summary Get assistent data
|
||||
// @Success 200 {object} Assistant "Response"
|
||||
// @Router /v1/assistants/{assistant_id} [get]
|
||||
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
}
|
||||
|
||||
for _, assistant := range Assistants {
|
||||
if assistant.ID == assistantID {
|
||||
return c.Status(fiber.StatusOK).JSON(assistant)
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
|
||||
}
|
||||
}
|
||||
|
||||
type AssistantFile struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
AssistantID string `json:"assistant_id"`
|
||||
}
|
||||
|
||||
var (
|
||||
AssistantFiles []AssistantFile
|
||||
AssistantsFileConfigFile = "assistantsFile.json"
|
||||
)
|
||||
|
||||
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
request := new(schema.AssistantFileRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
}
|
||||
|
||||
for _, assistant := range Assistants {
|
||||
if assistant.ID == assistantID {
|
||||
if len(assistant.FileIDs) > MaxFileIdSize {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name))
|
||||
}
|
||||
|
||||
for _, file := range UploadedFiles {
|
||||
if file.ID == request.FileID {
|
||||
assistant.FileIDs = append(assistant.FileIDs, request.FileID)
|
||||
assistantFile := AssistantFile{
|
||||
ID: file.ID,
|
||||
Object: "assistant.file",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
AssistantID: assistant.ID,
|
||||
}
|
||||
AssistantFiles = append(AssistantFiles, assistantFile)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find file_id: %s", request.FileID)))
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find %q", assistantID)))
|
||||
}
|
||||
}
|
||||
|
||||
func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
type ListAssistantFiles struct {
|
||||
Data []schema.File
|
||||
Object string
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
}
|
||||
|
||||
limitQuery := c.Query("limit", "20")
|
||||
order := c.Query("order", "desc")
|
||||
limit, err := strconv.Atoi(limitQuery)
|
||||
if err != nil || limit < 1 || limit > 100 {
|
||||
limit = 20 // Default to 20 if there's an error or the limit is out of bounds
|
||||
}
|
||||
|
||||
// Sort files by CreatedAt depending on the order query parameter
|
||||
if order == "asc" {
|
||||
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||
return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt
|
||||
})
|
||||
} else { // default to "desc"
|
||||
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||
return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt
|
||||
})
|
||||
}
|
||||
|
||||
// Limit the number of files returned
|
||||
var limitedFiles []AssistantFile
|
||||
hasMore := false
|
||||
if len(AssistantFiles) > limit {
|
||||
hasMore = true
|
||||
limitedFiles = AssistantFiles[:limit]
|
||||
} else {
|
||||
limitedFiles = AssistantFiles
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"object": "list",
|
||||
"data": limitedFiles,
|
||||
"first_id": func() string {
|
||||
if len(limitedFiles) > 0 {
|
||||
return limitedFiles[0].ID
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
"last_id": func() string {
|
||||
if len(limitedFiles) > 0 {
|
||||
return limitedFiles[len(limitedFiles)-1].ID
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
"has_more": hasMore,
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
}
|
||||
}
|
||||
|
||||
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
request := new(AssistantRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
}
|
||||
|
||||
for i, assistant := range Assistants {
|
||||
if assistant.ID == assistantID {
|
||||
newAssistant := Assistant{
|
||||
ID: assistantID,
|
||||
Object: assistant.Object,
|
||||
Created: assistant.Created,
|
||||
Model: request.Model,
|
||||
Name: request.Name,
|
||||
Description: request.Description,
|
||||
Instructions: request.Instructions,
|
||||
Tools: request.Tools,
|
||||
FileIDs: request.FileIDs, // todo: should probably verify fileids exist
|
||||
Metadata: request.Metadata,
|
||||
}
|
||||
|
||||
// Remove old one and replace with new one
|
||||
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||
Assistants = append(Assistants, newAssistant)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||
return c.Status(fiber.StatusOK).JSON(newAssistant)
|
||||
}
|
||||
}
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
assistantID := c.Params("assistant_id")
|
||||
fileId := c.Params("file_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||
}
|
||||
// First remove file from assistant
|
||||
for i, assistant := range Assistants {
|
||||
if assistant.ID == assistantID {
|
||||
for j, fileId := range assistant.FileIDs {
|
||||
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...)
|
||||
|
||||
// Check if the file exists in the assistantFiles slice
|
||||
for i, assistantFile := range AssistantFiles {
|
||||
if assistantFile.ID == fileId {
|
||||
// Remove the file from the assistantFiles slice
|
||||
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantFileResponse{
|
||||
ID: fileId,
|
||||
Object: "assistant.file.deleted",
|
||||
Deleted: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID)
|
||||
for i, assistantFile := range AssistantFiles {
|
||||
if assistantFile.AssistantID == assistantID {
|
||||
|
||||
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||
|
||||
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{
|
||||
ID: fileId,
|
||||
Object: "assistant.file.deleted",
|
||||
Deleted: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Warn().Msgf("Unable to find assistant: %s", assistantID)
|
||||
|
||||
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{
|
||||
ID: fileId,
|
||||
Object: "assistant.file.deleted",
|
||||
Deleted: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
assistantID := c.Params("assistant_id")
|
||||
fileId := c.Params("file_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||
}
|
||||
|
||||
for _, assistantFile := range AssistantFiles {
|
||||
if assistantFile.AssistantID == assistantID {
|
||||
if assistantFile.ID == fileId {
|
||||
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||
}
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)))
|
||||
}
|
||||
}
|
||||
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)))
|
||||
}
|
||||
}
|
||||
@@ -1,460 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var configsDir string = "/tmp/localai/configs"
|
||||
|
||||
type MockLoader struct {
|
||||
models []string
|
||||
}
|
||||
|
||||
func tearDown() func() {
|
||||
return func() {
|
||||
UploadedFiles = []schema.File{}
|
||||
Assistants = []Assistant{}
|
||||
AssistantFiles = []AssistantFile{}
|
||||
_ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile))
|
||||
_ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssistantEndpoints(t *testing.T) {
|
||||
// Preparing the mocked objects
|
||||
cl := &config.BackendConfigLoader{}
|
||||
//configsDir := "/tmp/localai/configs"
|
||||
modelPath := "/tmp/localai/model"
|
||||
var ml = model.NewModelLoader(modelPath, false)
|
||||
|
||||
appConfig := &config.ApplicationConfig{
|
||||
ConfigsDir: configsDir,
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
ModelPath: modelPath,
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(appConfig.ConfigsDir)
|
||||
_ = os.MkdirAll(appConfig.ConfigsDir, 0750)
|
||||
_ = os.MkdirAll(modelPath, 0750)
|
||||
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j"))
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||
|
||||
app.Post("/files", UploadFilesEndpoint(cl, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||
|
||||
t.Run("CreateAssistantEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, resp, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, 1, len(Assistants))
|
||||
//t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||
|
||||
assert.Equal(t, ar.Name, resultAssistant.Name)
|
||||
assert.Equal(t, ar.Model, resultAssistant.Model)
|
||||
assert.Equal(t, ar.Tools, resultAssistant.Tools)
|
||||
assert.Equal(t, ar.Description, resultAssistant.Description)
|
||||
assert.Equal(t, ar.Instructions, resultAssistant.Instructions)
|
||||
assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs)
|
||||
assert.Equal(t, ar.Metadata, resultAssistant.Metadata)
|
||||
})
|
||||
|
||||
t.Run("ListAssistantsEndpoint", func(t *testing.T) {
|
||||
var ids []string
|
||||
var resultAssistant []Assistant
|
||||
for i := 0; i < 4; i++ {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: fmt.Sprintf("3.5-turbo-%d", i),
|
||||
Description: fmt.Sprintf("Test Assistant - %d", i),
|
||||
Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i),
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: []string{"fid-1234"},
|
||||
Metadata: map[string]string{"meta": "data"},
|
||||
}
|
||||
|
||||
//var err error
|
||||
ra, _, err := createAssistant(app, *ar)
|
||||
// Because we create the assistants so fast all end up with the same created time.
|
||||
time.Sleep(time.Second)
|
||||
resultAssistant = append(resultAssistant, ra)
|
||||
assert.NoError(t, err)
|
||||
ids = append(ids, resultAssistant[i].ID)
|
||||
}
|
||||
|
||||
t.Cleanup(cleanupAllAssistants(t, app, ids))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
reqURL string
|
||||
expectedStatus int
|
||||
expectedResult []Assistant
|
||||
expectedStringResult string
|
||||
}{
|
||||
{
|
||||
name: "Valid Usage - limit only",
|
||||
reqURL: "/assistants?limit=2",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants[:2], // Expecting the first two assistants
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - order asc",
|
||||
reqURL: "/assistants?order=asc",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants, // Expecting all assistants in ascending order
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - order desc",
|
||||
reqURL: "/assistants?order=desc",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - after specific ID",
|
||||
reqURL: "/assistants?after=2",
|
||||
expectedStatus: http.StatusOK,
|
||||
// Note this is correct because it's put in descending order already
|
||||
expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - before specific ID",
|
||||
reqURL: "/assistants?before=4",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3.
|
||||
},
|
||||
{
|
||||
name: "Invalid Usage - non-integer limit",
|
||||
reqURL: "/assistants?limit=two",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStringResult: "Invalid limit query value: two",
|
||||
},
|
||||
{
|
||||
name: "Invalid Usage - non-existing id in after",
|
||||
reqURL: "/assistants?after=100",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
||||
if tt.expectedStatus != fiber.StatusOK {
|
||||
all, _ := io.ReadAll(response.Body)
|
||||
assert.Equal(t, tt.expectedStringResult, string(all))
|
||||
} else {
|
||||
var result []Assistant
|
||||
err = json.NewDecoder(response.Body).Decode(&result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
_, err = app.Test(deleteReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(Assistants))
|
||||
})
|
||||
|
||||
t.Run("GetAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var getAssistant Assistant
|
||||
err = json.NewDecoder(response.Body).Decode(&getAssistant)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, resultAssistant.ID, getAssistant.ID)
|
||||
})
|
||||
|
||||
t.Run("ModifyAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
|
||||
modifiedAr := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "4.0-turbo",
|
||||
Description: "Modified Test Assistant",
|
||||
Instructions: "You are math teacher answering student questions",
|
||||
Tools: []Tool{{Type: CodeInterpreter}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
modifiedArJson, err := json.Marshal(modifiedAr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
|
||||
modifyResponse, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
var getAssistant Assistant
|
||||
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
||||
|
||||
assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change
|
||||
assert.Equal(t, modifiedAr.Tools, getAssistant.Tools)
|
||||
assert.Equal(t, modifiedAr.Name, getAssistant.Name)
|
||||
assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions)
|
||||
assert.Equal(t, modifiedAr.Description, getAssistant.Description)
|
||||
})
|
||||
|
||||
t.Run("CreateAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||
})
|
||||
t.Run("ListAssistantFilesEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||
})
|
||||
t.Run("GetAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID))
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID)
|
||||
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var assistantFile AssistantFile
|
||||
err = json.NewDecoder(response.Body).Decode(&assistantFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, af.ID, assistantFile.ID)
|
||||
assert.Equal(t, af.AssistantID, assistantFile.AssistantID)
|
||||
})
|
||||
t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleanupAssistantFile(t, app, af.ID, af.AssistantID)()
|
||||
|
||||
assert.Empty(t, AssistantFiles)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (schema.File, Assistant, error) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
assistant, _, err := createAssistant(app, *ar)
|
||||
if err != nil {
|
||||
return schema.File{}, Assistant{}, err
|
||||
}
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID}))
|
||||
|
||||
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o)
|
||||
t.Cleanup(func() {
|
||||
_, err := CallFilesDeleteEndpoint(t, app, file.ID)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
return file, assistant, nil
|
||||
}
|
||||
|
||||
func createAssistantFile(app *fiber.App, afr schema.AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) {
|
||||
afrJson, err := json.Marshal(afr)
|
||||
if err != nil {
|
||||
return AssistantFile{}, nil, err
|
||||
}
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s/files", assistantId)
|
||||
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
if err != nil {
|
||||
return AssistantFile{}, resp, err
|
||||
}
|
||||
|
||||
var assistantFile AssistantFile
|
||||
all, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return AssistantFile{}, resp, err
|
||||
}
|
||||
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
||||
if err != nil {
|
||||
return AssistantFile{}, resp, err
|
||||
}
|
||||
|
||||
return assistantFile, resp, nil
|
||||
}
|
||||
|
||||
func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) {
|
||||
assistant, err := json.Marshal(ar)
|
||||
if err != nil {
|
||||
return Assistant{}, nil, err
|
||||
}
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
if err != nil {
|
||||
return Assistant{}, resp, err
|
||||
}
|
||||
|
||||
bodyString, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return Assistant{}, resp, err
|
||||
}
|
||||
|
||||
var resultAssistant Assistant
|
||||
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
||||
return resultAssistant, resp, err
|
||||
}
|
||||
|
||||
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
||||
return func() {
|
||||
for _, assistant := range ids {
|
||||
target := fmt.Sprintf("/assistants/%s", assistant)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
_, err := app.Test(deleteReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete assistant %s: %v", assistant, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() {
|
||||
return func() {
|
||||
target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId)
|
||||
request := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var dafr schema.DeleteAssistantFileResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&dafr)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, dafr.Deleted)
|
||||
}
|
||||
}
|
||||
@@ -305,7 +305,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
// If we are using the tokenizer template, we don't need to process the messages
|
||||
// unless we are processing functions
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
||||
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
|
||||
predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if config.Grammar != "" {
|
||||
|
||||
@@ -109,8 +109,10 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
Input: predInput,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: predInput,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
ReasoningEffort: input.ReasoningEffort,
|
||||
Metadata: input.Metadata,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
@@ -160,8 +162,10 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
||||
|
||||
for k, i := range config.PromptStrings {
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
ReasoningEffort: input.ReasoningEffort,
|
||||
Metadata: input.Metadata,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
|
||||
@@ -47,9 +47,11 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
|
||||
for _, i := range config.InputStrings {
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
ReasoningEffort: input.ReasoningEffort,
|
||||
Metadata: input.Metadata,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
var UploadedFiles []schema.File
|
||||
|
||||
const UploadedFilesFile = "uploadedFiles.json"
|
||||
|
||||
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check the file size
|
||||
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
|
||||
}
|
||||
|
||||
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
||||
if purpose == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
|
||||
}
|
||||
|
||||
// Sanitize the filename to prevent directory traversal
|
||||
filename := utils.SanitizeFileName(file.Filename)
|
||||
|
||||
savePath := filepath.Join(appConfig.UploadDir, filename)
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
|
||||
}
|
||||
|
||||
err = c.SaveFile(file, savePath)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
f := schema.File{
|
||||
ID: fmt.Sprintf("file-%d", getNextFileId()),
|
||||
Object: "file",
|
||||
Bytes: int(file.Size),
|
||||
CreatedAt: time.Now(),
|
||||
Filename: file.Filename,
|
||||
Purpose: purpose,
|
||||
}
|
||||
|
||||
UploadedFiles = append(UploadedFiles, f)
|
||||
utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
|
||||
return c.Status(fiber.StatusOK).JSON(f)
|
||||
}
|
||||
}
|
||||
|
||||
var currentFileId int64 = 0
|
||||
|
||||
func getNextFileId() int64 {
|
||||
atomic.AddInt64(¤tId, 1)
|
||||
return currentId
|
||||
}
|
||||
|
||||
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||
// @Summary List files.
|
||||
// @Success 200 {object} schema.ListFiles "Response"
|
||||
// @Router /v1/files [get]
|
||||
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
var listFiles schema.ListFiles
|
||||
|
||||
purpose := c.Query("purpose")
|
||||
if purpose == "" {
|
||||
listFiles.Data = UploadedFiles
|
||||
} else {
|
||||
for _, f := range UploadedFiles {
|
||||
if purpose == f.Purpose {
|
||||
listFiles.Data = append(listFiles.Data, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
listFiles.Object = "list"
|
||||
return c.Status(fiber.StatusOK).JSON(listFiles)
|
||||
}
|
||||
}
|
||||
|
||||
func getFileFromRequest(c *fiber.Ctx) (*schema.File, error) {
|
||||
id := c.Params("file_id")
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("file_id parameter is required")
|
||||
}
|
||||
|
||||
for _, f := range UploadedFiles {
|
||||
if id == f.ID {
|
||||
return &f, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to find file id %s", id)
|
||||
}
|
||||
|
||||
// GetFilesEndpoint is the OpenAI API endpoint to get files https://platform.openai.com/docs/api-reference/files/retrieve
|
||||
// @Summary Returns information about a specific file.
|
||||
// @Success 200 {object} schema.File "Response"
|
||||
// @Router /v1/files/{file_id} [get]
|
||||
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
return c.JSON(file)
|
||||
}
|
||||
}
|
||||
|
||||
type DeleteStatus struct {
|
||||
Id string
|
||||
Object string
|
||||
Deleted bool
|
||||
}
|
||||
|
||||
// DeleteFilesEndpoint is the OpenAI API endpoint to delete files https://platform.openai.com/docs/api-reference/files/delete
|
||||
// @Summary Delete a file.
|
||||
// @Success 200 {object} DeleteStatus "Response"
|
||||
// @Router /v1/files/{file_id} [delete]
|
||||
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||
if err != nil {
|
||||
// If the file doesn't exist then we should just continue to remove it
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err)))
|
||||
}
|
||||
}
|
||||
|
||||
// Remove upload from list
|
||||
for i, f := range UploadedFiles {
|
||||
if f.ID == file.ID {
|
||||
UploadedFiles = append(UploadedFiles[:i], UploadedFiles[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
|
||||
return c.JSON(DeleteStatus{
|
||||
Id: file.ID,
|
||||
Object: "file",
|
||||
Deleted: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetFilesContentsEndpoint is the OpenAI API endpoint to get files content https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||
// @Summary Returns information about a specific file.
|
||||
// @Success 200 {string} binary "file"
|
||||
// @Router /v1/files/{file_id}/content [get]
|
||||
// GetFilesContentsEndpoint
|
||||
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
return c.Send(fileContents)
|
||||
}
|
||||
}
|
||||
@@ -1,301 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
utils2 "github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
|
||||
// Preparing the mocked objects
|
||||
loader = &config.BackendConfigLoader{}
|
||||
|
||||
option = &config.ApplicationConfig{
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(option.UploadDir)
|
||||
|
||||
app = fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||
// Preparing the mocked objects
|
||||
loader := &config.BackendConfigLoader{}
|
||||
|
||||
option := &config.ApplicationConfig{
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(option.UploadDir)
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||
|
||||
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||
fmt.Println(f1)
|
||||
fmt.Printf("ERror: %v\n", err)
|
||||
fmt.Printf("resp: %+v\n", resp)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
// Check if file exists in the disk
|
||||
testName := strings.Split(t.Name(), "/")[1]
|
||||
fileName := testName + "-test.txt"
|
||||
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName(fileName))
|
||||
_, err := os.Stat(filePath)
|
||||
|
||||
assert.False(t, os.IsNotExist(err))
|
||||
assert.Equal(t, file.Bytes, 5242880)
|
||||
assert.NotEmpty(t, file.CreatedAt)
|
||||
assert.Equal(t, file.Filename, fileName)
|
||||
assert.Equal(t, file.Purpose, "fine-tune")
|
||||
})
|
||||
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
resp, err := CallListFilesEndpoint(t, app, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
if len(listFiles.Data) != len(UploadedFiles) {
|
||||
t.Errorf("Expected %v files, got %v files", len(UploadedFiles), len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
||||
assert.NoError(t, err)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
if len(listFiles.Data) != 1 {
|
||||
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
|
||||
if len(listFiles.Data) != 0 {
|
||||
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
req := httptest.NewRequest("GET", "/files", nil)
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var listFiles schema.ListFiles
|
||||
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
|
||||
t.Errorf("Failed to decode response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(listFiles.Data) != 0 {
|
||||
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
|
||||
var target string
|
||||
if purpose != "" {
|
||||
target = fmt.Sprintf("/files?purpose=%s", purpose)
|
||||
} else {
|
||||
target = "/files"
|
||||
}
|
||||
req := httptest.NewRequest("GET", target, nil)
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
|
||||
return app.Test(request)
|
||||
}
|
||||
|
||||
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
||||
testName := strings.Split(t.Name(), "/")[1]
|
||||
|
||||
// Create a file that exceeds the limit
|
||||
file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)
|
||||
|
||||
// Creating a new HTTP Request
|
||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) schema.File {
|
||||
// Create a file that exceeds the limit
|
||||
testName := strings.Split(t.Name(), "/")[1]
|
||||
file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)
|
||||
|
||||
// Creating a new HTTP Request
|
||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
f := responseToFile(t, resp)
|
||||
|
||||
//id := f.ID
|
||||
//t.Cleanup(func() {
|
||||
// _, err := CallFilesDeleteEndpoint(t, app, id)
|
||||
// assert.NoError(t, err)
|
||||
// assert.Empty(t, UploadedFiles)
|
||||
//})
|
||||
|
||||
return f
|
||||
|
||||
}
|
||||
|
||||
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||
target := fmt.Sprintf("/files/%s", fileId)
|
||||
req := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
// Helper to create multi-part file
|
||||
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
|
||||
body := new(strings.Builder)
|
||||
writer := multipart.NewWriter(body)
|
||||
file, _ := os.Open(filePath)
|
||||
defer file.Close()
|
||||
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
|
||||
io.Copy(part, file)
|
||||
|
||||
if purpose != "" {
|
||||
_ = writer.WriteField("purpose", purpose)
|
||||
}
|
||||
|
||||
writer.Close()
|
||||
return strings.NewReader(body.String()), writer
|
||||
}
|
||||
|
||||
// Helper to create test files
|
||||
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
|
||||
err := os.MkdirAll(option.UploadDir, 0750)
|
||||
if err != nil {
|
||||
|
||||
t.Fatalf("Error MKDIR: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Create(name)
|
||||
assert.NoError(t, err)
|
||||
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(name)
|
||||
os.RemoveAll(option.UploadDir)
|
||||
})
|
||||
return file
|
||||
}
|
||||
|
||||
func bodyToString(resp *http.Response, t *testing.T) string {
|
||||
return string(bodyToByteArray(resp, t))
|
||||
}
|
||||
|
||||
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return bodyBytes
|
||||
}
|
||||
|
||||
func responseToFile(t *testing.T, resp *http.Response) schema.File {
|
||||
var file schema.File
|
||||
responseToString := bodyToString(resp, t)
|
||||
|
||||
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode response: %s", err)
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
|
||||
func responseToListFile(t *testing.T, resp *http.Response) schema.ListFiles {
|
||||
var listFiles schema.ListFiles
|
||||
responseToString := bodyToString(resp, t)
|
||||
|
||||
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to decode response")
|
||||
}
|
||||
|
||||
return listFiles
|
||||
}
|
||||
@@ -79,49 +79,37 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// Process input images (for img2img/inpainting)
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
src = processImageFile(input.File, appConfig.GeneratedContentDir)
|
||||
if src != "" {
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
}
|
||||
|
||||
fileData := []byte{}
|
||||
var err error
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
out, err := downloadFile(input.File)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed downloading file:%w", err)
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading file:%w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
||||
if err != nil {
|
||||
return err
|
||||
// Process multiple input images
|
||||
var inputImages []string
|
||||
if len(input.Files) > 0 {
|
||||
for _, file := range input.Files {
|
||||
processedFile := processImageFile(file, appConfig.GeneratedContentDir)
|
||||
if processedFile != "" {
|
||||
inputImages = append(inputImages, processedFile)
|
||||
defer os.RemoveAll(processedFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
// Process reference images
|
||||
var refImages []string
|
||||
if len(input.RefImages) > 0 {
|
||||
for _, file := range input.RefImages {
|
||||
processedFile := processImageFile(file, appConfig.GeneratedContentDir)
|
||||
if processedFile != "" {
|
||||
refImages = append(refImages, processedFile)
|
||||
defer os.RemoveAll(processedFile)
|
||||
}
|
||||
}
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
src = outputFile.Name()
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
@@ -202,7 +190,13 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
||||
// Use the first input image as src if available, otherwise use the original src
|
||||
inputSrc := src
|
||||
if len(inputImages) > 0 {
|
||||
inputSrc = inputImages[0]
|
||||
}
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -243,3 +237,51 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// processImageFile handles a single image file (URL or base64) and returns the path to the temporary file
|
||||
func processImageFile(file string, generatedContentDir string) string {
|
||||
fileData := []byte{}
|
||||
var err error
|
||||
|
||||
// check if file is an URL, if so download it and save it to a temporary file
|
||||
if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
|
||||
out, err := downloadFile(file)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed downloading file: %s", file)
|
||||
return ""
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed reading downloaded file: %s", out)
|
||||
return ""
|
||||
}
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(file)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed decoding base64 file")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(generatedContentDir, "b64")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed creating temporary file")
|
||||
return ""
|
||||
}
|
||||
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
log.Error().Err(err).Msg("Failed writing to temporary file")
|
||||
return ""
|
||||
}
|
||||
outputFile.Close()
|
||||
|
||||
return outputFile.Name()
|
||||
}
|
||||
|
||||
@@ -41,6 +41,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
}
|
||||
|
||||
router.Post("/v1/detection",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
|
||||
localai.DetectionEndpoint(cl, ml, appConfig))
|
||||
|
||||
router.Post("/tts",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
|
||||
@@ -54,38 +54,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
||||
app.Post("/completions", completionChain...)
|
||||
app.Post("/v1/engines/:model/completions", completionChain...)
|
||||
|
||||
// assistant
|
||||
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
// files
|
||||
app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
|
||||
// embeddings
|
||||
embeddingChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -11,6 +11,7 @@ async function promptDallE(input) {
|
||||
document.getElementById("input").disabled = true;
|
||||
|
||||
const model = document.getElementById("image-model").value;
|
||||
const size = document.getElementById("image-size").value;
|
||||
const response = await fetch("v1/images/generations", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -21,7 +22,7 @@ async function promptDallE(input) {
|
||||
steps: 10,
|
||||
prompt: input,
|
||||
n: 1,
|
||||
size: "512x512",
|
||||
size: size,
|
||||
}),
|
||||
});
|
||||
const json = await response.json();
|
||||
@@ -48,4 +49,13 @@ async function promptDallE(input) {
|
||||
|
||||
document.getElementById("input").focus();
|
||||
document.getElementById("genimage").addEventListener("submit", genImage);
|
||||
|
||||
// Handle Enter key press in the prompt input
|
||||
document.getElementById("input").addEventListener("keypress", function(event) {
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
genImage(event);
|
||||
}
|
||||
});
|
||||
|
||||
document.getElementById("loader").style.display = "none";
|
||||
|
||||
@@ -90,6 +90,14 @@
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-headphones mr-2"></i>Whisper
|
||||
</button>
|
||||
<button hx-post="browse/search/backends"
|
||||
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
|
||||
hx-target="#search-results"
|
||||
hx-vals='{"search": "object-detection"}'
|
||||
onclick="hidePagination()"
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-eye mr-2"></i>Object detection
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -115,6 +115,14 @@
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-headphones mr-2"></i>Audio transcription
|
||||
</button>
|
||||
<button hx-post="browse/search/models"
|
||||
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
|
||||
hx-target="#search-results"
|
||||
hx-vals='{"search": "object-detection"}'
|
||||
onclick="hidePagination()"
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-eye mr-2"></i>Object detection
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -91,6 +91,30 @@
|
||||
</svg>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Size Selection -->
|
||||
<div class="mt-4">
|
||||
<label for="image-size" class="block text-sm font-medium text-gray-300 mb-2">
|
||||
<i class="fas fa-expand-arrows-alt mr-2"></i>Image Size:
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="image-size"
|
||||
value="256x256"
|
||||
placeholder="e.g., 256x256, 512x512, 1024x1024"
|
||||
class="bg-gray-900 text-white border border-gray-700 focus:border-blue-500 focus:ring focus:ring-blue-500 focus:ring-opacity-50 rounded-lg shadow-sm p-2.5 w-full max-w-xs transition-colors duration-200"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Submit Button -->
|
||||
<div class="mt-6">
|
||||
<button
|
||||
type="submit"
|
||||
class="w-full bg-gradient-to-r from-blue-600 to-indigo-600 hover:from-blue-700 hover:to-indigo-700 text-white font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-lg focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50"
|
||||
>
|
||||
<i class="fas fa-magic mr-2"></i>Generate Image
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<!-- Image Results Container -->
|
||||
|
||||
@@ -120,3 +120,20 @@ type SystemInformationResponse struct {
|
||||
Backends []string `json:"backends"`
|
||||
Models []SysInfoModel `json:"loaded_models"`
|
||||
}
|
||||
|
||||
type DetectionRequest struct {
|
||||
BasicModelRequest
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
type DetectionResponse struct {
|
||||
Detections []Detection `json:"detections"`
|
||||
}
|
||||
|
||||
type Detection struct {
|
||||
X float32 `json:"x"`
|
||||
Y float32 `json:"y"`
|
||||
Width float32 `json:"width"`
|
||||
Height float32 `json:"height"`
|
||||
ClassName string `json:"class_name"`
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
functions "github.com/mudler/LocalAI/pkg/functions"
|
||||
)
|
||||
@@ -115,37 +114,6 @@ type OpenAIModel struct {
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type DeleteAssistantResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
// File represents the structure of a file object from the OpenAI API.
|
||||
type File struct {
|
||||
ID string `json:"id"` // Unique identifier for the file
|
||||
Object string `json:"object"` // Type of the object (e.g., "file")
|
||||
Bytes int `json:"bytes"` // Size of the file in bytes
|
||||
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
|
||||
Filename string `json:"filename"` // The name of the file
|
||||
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
||||
}
|
||||
|
||||
type ListFiles struct {
|
||||
Data []File
|
||||
Object string
|
||||
}
|
||||
|
||||
type AssistantFileRequest struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
type DeleteAssistantFileResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
type ImageGenerationResponseFormat string
|
||||
|
||||
type ChatCompletionResponseFormatType string
|
||||
@@ -173,6 +141,10 @@ type OpenAIRequest struct {
|
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"`
|
||||
// Multiple input images for img2img or inpainting
|
||||
Files []string `json:"files,omitempty"`
|
||||
// Reference images for models that support them (e.g., Flux Kontext)
|
||||
RefImages []string `json:"ref_images,omitempty"`
|
||||
//whisper/image
|
||||
ResponseFormat interface{} `json:"response_format,omitempty"`
|
||||
// image
|
||||
@@ -211,6 +183,10 @@ type OpenAIRequest struct {
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
|
||||
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
|
||||
|
||||
ReasoningEffort string `json:"reasoning_effort" yaml:"reasoning_effort"`
|
||||
|
||||
Metadata map[string]string `json:"metadata" yaml:"metadata"`
|
||||
}
|
||||
|
||||
type ModelsDataResponse struct {
|
||||
|
||||
@@ -24,6 +24,7 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend], s
|
||||
g.modelLoader.DeleteExternalBackend(op.GalleryElementName)
|
||||
} else {
|
||||
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
|
||||
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
|
||||
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback, true)
|
||||
if err == nil {
|
||||
err = gallery.RegisterBackends(g.appConfig.BackendsPath, g.modelLoader)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package startup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -13,38 +12,68 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func InstallExternalBackends(galleries []config.Gallery, backendPath string, downloadStatus func(string, string, string, float64), backends ...string) error {
|
||||
var errs error
|
||||
func InstallExternalBackends(galleries []config.Gallery, backendPath string, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get system state: %w", err)
|
||||
}
|
||||
for _, backend := range backends {
|
||||
uri := downloader.URI(backend)
|
||||
switch {
|
||||
case uri.LooksLikeOCI():
|
||||
name, err := uri.FilenameFromUrl()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get filename from URL: %w", err)
|
||||
}
|
||||
// strip extension if any
|
||||
name = strings.TrimSuffix(name, filepath.Ext(name))
|
||||
uri := downloader.URI(backend)
|
||||
switch {
|
||||
case uri.LooksLikeDir():
|
||||
if name == "" { // infer it from the path
|
||||
name = filepath.Base(backend)
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
Alias: alias,
|
||||
URI: backend,
|
||||
}, downloadStatus); err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
case uri.LooksLikeOCI() && !uri.LooksLikeOCIFile():
|
||||
if name == "" {
|
||||
return fmt.Errorf("specifying a name is required for OCI images")
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
Alias: alias,
|
||||
URI: backend,
|
||||
}, downloadStatus); err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
case uri.LooksLikeOCIFile():
|
||||
name, err := uri.FilenameFromUrl()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get filename from URL: %w", err)
|
||||
}
|
||||
// strip extension if any
|
||||
name = strings.TrimSuffix(name, filepath.Ext(name))
|
||||
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
URI: backend,
|
||||
}, downloadStatus); err != nil {
|
||||
errs = errors.Join(err, fmt.Errorf("error installing backend %s", backend))
|
||||
}
|
||||
default:
|
||||
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true)
|
||||
if err != nil {
|
||||
errs = errors.Join(err, fmt.Errorf("error installing backend %s", backend))
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
Alias: alias,
|
||||
URI: backend,
|
||||
}, downloadStatus); err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
default:
|
||||
if name != "" || alias != "" {
|
||||
return fmt.Errorf("specifying a name or alias is not supported for this backend")
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ type PromptTemplateData struct {
|
||||
Instruction string
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
ReasoningEffort string
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
type ChatMessageTemplateData struct {
|
||||
@@ -133,7 +135,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
|
||||
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
var messageData []ChatMessageTemplateData
|
||||
@@ -283,6 +285,8 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
ReasoningEffort: input.ReasoningEffort,
|
||||
Metadata: input.Metadata,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
|
||||
@@ -219,7 +219,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
@@ -232,7 +232,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
@@ -245,7 +245,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range jinjaTest {
|
||||
foo := jinjaTest[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ Specifying a `config-file` via CLI allows to declare models in a single file as
|
||||
chat: chat
|
||||
```
|
||||
|
||||
See also [chatbot-ui](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui) as an example on how to use config files.
|
||||
See also [chatbot-ui](https://github.com/mudler/LocalAI-examples/tree/main/chatbot-ui) as an example on how to use config files.
|
||||
|
||||
It is possible to specify a full URL or a short-hand URL to a YAML model configuration file and use it on start with local-ai, for example to use phi-2:
|
||||
|
||||
@@ -341,7 +341,7 @@ Below is an instruction that describes a task, paired with an input that provide
|
||||
|
||||
Instead of installing models manually, you can use the LocalAI API endpoints and a model definition to install programmatically via API models in runtime.
|
||||
|
||||
A curated collection of model files is in the [model-gallery](https://github.com/go-skynet/model-gallery) (work in progress!). The files of the model gallery are different from the model files used to configure LocalAI models. The model gallery files contains information about the model setup, and the files necessary to run the model locally.
|
||||
A curated collection of model files is in the [model-gallery](https://github.com/mudler/LocalAI/tree/master/gallery). The files of the model gallery are different from the model files used to configure LocalAI models. The model gallery files contains information about the model setup, and the files necessary to run the model locally.
|
||||
|
||||
To install for example `lunademo`, you can send a POST call to the `/models/apply` endpoint with the model definition url (`url`) and the name of the model should have in LocalAI (`name`, optional):
|
||||
|
||||
@@ -445,15 +445,17 @@ make -C backend/python/vllm
|
||||
When LocalAI runs in a container,
|
||||
there are additional environment variables available that modify the behavior of LocalAI on startup:
|
||||
|
||||
{{< table "table-responsive" >}}
|
||||
| Environment variable | Default | Description |
|
||||
|----------------------------|---------|------------------------------------------------------------------------------------------------------------|
|
||||
| `REBUILD` | `false` | Rebuild LocalAI on startup |
|
||||
| `BUILD_TYPE` | | Build type. Available: `cublas`, `openblas`, `clblas` |
|
||||
| `BUILD_TYPE` | | Build type. Available: `cublas`, `openblas`, `clblas`, `intel` (intel core), `sycl_f16`, `sycl_f32` (intel backends) |
|
||||
| `GO_TAGS` | | Go tags. Available: `stablediffusion` |
|
||||
| `HUGGINGFACEHUB_API_TOKEN` | | Special token for interacting with HuggingFace Inference API, required only when using the `langchain-huggingface` backend |
|
||||
| `EXTRA_BACKENDS` | | A space separated list of backends to prepare. For example `EXTRA_BACKENDS="backend/python/diffusers backend/python/transformers"` prepares the python environment on start |
|
||||
| `DISABLE_AUTODETECT` | `false` | Disable autodetect of CPU flagset on start |
|
||||
| `LLAMACPP_GRPC_SERVERS` | | A list of llama.cpp workers to distribute the workload. For example `LLAMACPP_GRPC_SERVERS="address1:port,address2:port"` |
|
||||
{{< /table >}}
|
||||
|
||||
Here is how to configure these variables:
|
||||
|
||||
@@ -471,12 +473,15 @@ You can control LocalAI with command line arguments, to specify a binding addres
|
||||
In the help text below, BASEPATH is the location that local-ai is being executed from
|
||||
|
||||
#### Global Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| -h, --help | | Show context-sensitive help. |
|
||||
| --log-level | info | Set the level of logs to output [error,warn,info,debug] | $LOCALAI_LOG_LEVEL |
|
||||
{{< /table >}}
|
||||
|
||||
#### Storage Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| --models-path | BASEPATH/models | Path containing models used for inferencing | $LOCALAI_MODELS_PATH |
|
||||
@@ -487,8 +492,10 @@ In the help text below, BASEPATH is the location that local-ai is being executed
|
||||
| --localai-config-dir | BASEPATH/configuration | Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json) | $LOCALAI_CONFIG_DIR |
|
||||
| --localai-config-dir-poll-interval | | Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to a time duration to poll the LocalAI Config Dir (example: 1m) | $LOCALAI_CONFIG_DIR_POLL_INTERVAL |
|
||||
| --models-config-file | STRING | YAML file containing a list of model backend configs | $LOCALAI_MODELS_CONFIG_FILE |
|
||||
{{< /table >}}
|
||||
|
||||
#### Models Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| --galleries | STRING | JSON list of galleries | $LOCALAI_GALLERIES |
|
||||
@@ -497,15 +504,19 @@ In the help text below, BASEPATH is the location that local-ai is being executed
|
||||
| --preload-models | STRING | A List of models to apply in JSON at start |$LOCALAI_PRELOAD_MODELS |
|
||||
| --models | MODELS,... | A List of model configuration URLs to load | $LOCALAI_MODELS |
|
||||
| --preload-models-config | STRING | A List of models to apply at startup. Path to a YAML config file | $LOCALAI_PRELOAD_MODELS_CONFIG |
|
||||
{{< /table >}}
|
||||
|
||||
#### Performance Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| --f16 | | Enable GPU acceleration | $LOCALAI_F16 |
|
||||
| -t, --threads | 4 | Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested | $LOCALAI_THREADS |
|
||||
| --context-size | 512 | Default context size for models | $LOCALAI_CONTEXT_SIZE |
|
||||
{{< /table >}}
|
||||
|
||||
#### API Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| --address | ":8080" | Bind address for the API server | $LOCALAI_ADDRESS |
|
||||
@@ -516,8 +527,10 @@ In the help text below, BASEPATH is the location that local-ai is being executed
|
||||
| --disable-welcome | | Disable welcome pages | $LOCALAI_DISABLE_WELCOME |
|
||||
| --disable-webui | false | Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface | $LOCALAI_DISABLE_WEBUI |
|
||||
| --machine-tag | | If not empty - put that string to Machine-Tag header in each response. Useful to track response from different machines using multiple P2P federated nodes | $LOCALAI_MACHINE_TAG |
|
||||
{{< /table >}}
|
||||
|
||||
#### Backend Flags
|
||||
{{< table "table-responsive" >}}
|
||||
| Parameter | Default | Description | Environment Variable |
|
||||
|-----------|---------|-------------|----------------------|
|
||||
| --parallel-requests | | Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm) | $LOCALAI_PARALLEL_REQUESTS |
|
||||
@@ -528,6 +541,7 @@ In the help text below, BASEPATH is the location that local-ai is being executed
|
||||
| --watchdog-idle-timeout | 15m | Threshold beyond which an idle backend should be stopped | $LOCALAI_WATCHDOG_IDLE_TIMEOUT, $WATCHDOG_IDLE_TIMEOUT |
|
||||
| --enable-watchdog-busy | | Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout | $LOCALAI_WATCHDOG_BUSY |
|
||||
| --watchdog-busy-timeout | 5m | Threshold beyond which a busy backend should be stopped | $LOCALAI_WATCHDOG_BUSY_TIMEOUT |
|
||||
{{< /table >}}
|
||||
|
||||
### .env files
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ There is the availability of localai-webui and chatbot-ui in the examples sectio
|
||||
|
||||
### Does it work with AutoGPT?
|
||||
|
||||
Yes, see the [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/)!
|
||||
Yes, see the [examples](https://github.com/mudler/LocalAI-examples)!
|
||||
|
||||
### How can I troubleshoot when something is wrong?
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user