Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-06 12:43:04 -05:00)

Compare commits: v3.4.0 ... llama_cpp/ (1 commit, SHA1 3f52776a1c)

.github/workflows/backend.yml (vendored), 330 changes
@@ -87,18 +87,6 @@ jobs:
  backend: "diffusers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'l4t'
- cuda-major-version: "12"
- cuda-minor-version: "0"
- platforms: 'linux/arm64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-l4t-diffusers'
- runs-on: 'ubuntu-24.04-arm'
- base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
- skip-drivers: 'true'
- backend: "diffusers"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
  # CUDA 11 additional backends
  - build-type: 'cublas'
  cuda-major-version: "11"
@@ -325,7 +313,7 @@ jobs:
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-rocm-hipblas-transformers'
- runs-on: 'arc-runner-set'
+ runs-on: 'ubuntu-latest'
  base-image: "rocm/dev-ubuntu-22.04:6.1"
  skip-drivers: 'false'
  backend: "transformers"
@@ -337,7 +325,7 @@ jobs:
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-rocm-hipblas-diffusers'
- runs-on: 'arc-runner-set'
+ runs-on: 'ubuntu-latest'
  base-image: "rocm/dev-ubuntu-22.04:6.1"
  skip-drivers: 'false'
  backend: "diffusers"
@@ -350,7 +338,7 @@ jobs:
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-rocm-hipblas-kokoro'
- runs-on: 'arc-runner-set'
+ runs-on: 'ubuntu-latest'
  base-image: "rocm/dev-ubuntu-22.04:6.1"
  skip-drivers: 'false'
  backend: "kokoro"
@@ -386,19 +374,31 @@ jobs:
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-rocm-hipblas-bark'
- runs-on: 'arc-runner-set'
+ runs-on: 'ubuntu-latest'
  base-image: "rocm/dev-ubuntu-22.04:6.1"
  skip-drivers: 'false'
  backend: "bark"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
  # sycl builds
- - build-type: 'intel'
+ - build-type: 'sycl_f32'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-rerankers'
+ tag-suffix: '-gpu-intel-sycl-f32-rerankers'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "rerankers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f16'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f16-rerankers'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
@@ -429,36 +429,60 @@ jobs:
  backend: "llama-cpp"
  dockerfile: "./backend/Dockerfile.llama-cpp"
  context: "./"
- - build-type: 'intel'
+ - build-type: 'sycl_f32'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-vllm'
+ tag-suffix: '-gpu-intel-sycl-f32-vllm'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "vllm"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-transformers'
+ tag-suffix: '-gpu-intel-sycl-f16-vllm'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "vllm"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-transformers'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "transformers"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-diffusers'
+ tag-suffix: '-gpu-intel-sycl-f16-transformers'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "transformers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-diffusers'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
@@ -466,48 +490,96 @@ jobs:
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
  # SYCL additional backends
- - build-type: 'intel'
+ - build-type: 'sycl_f32'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-kokoro'
+ tag-suffix: '-gpu-intel-sycl-f32-kokoro'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "kokoro"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-faster-whisper'
+ tag-suffix: '-gpu-intel-sycl-f16-kokoro'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-faster-whisper'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "faster-whisper"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-coqui'
+ tag-suffix: '-gpu-intel-sycl-f16-faster-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "faster-whisper"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-coqui'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "coqui"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-intel-bark'
+ tag-suffix: '-gpu-intel-sycl-f16-coqui'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "coqui"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-bark'
+ runs-on: 'ubuntu-latest'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ skip-drivers: 'false'
+ backend: "bark"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./backend"
+ - build-type: 'sycl_f16'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f16-bark'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
@@ -525,7 +597,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "piper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  # bark-cpp
  - build-type: ''
@@ -538,7 +610,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "bark-cpp"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: ''
  cuda-major-version: ""
@@ -587,7 +659,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "12"
@@ -599,7 +671,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "11"
@@ -611,7 +683,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'sycl_f32'
  cuda-major-version: ""
@@ -623,7 +695,7 @@ jobs:
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'sycl_f16'
  cuda-major-version: ""
@@ -635,7 +707,7 @@ jobs:
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'vulkan'
  cuda-major-version: ""
@@ -647,7 +719,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "12"
@@ -659,7 +731,7 @@ jobs:
  base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
  runs-on: 'ubuntu-24.04-arm'
  backend: "stablediffusion-ggml"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  # whisper
  - build-type: ''
@@ -672,7 +744,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "12"
@@ -684,7 +756,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "11"
@@ -696,7 +768,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'sycl_f32'
  cuda-major-version: ""
@@ -708,7 +780,7 @@ jobs:
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'sycl_f16'
  cuda-major-version: ""
@@ -720,7 +792,7 @@ jobs:
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'vulkan'
  cuda-major-version: ""
@@ -732,7 +804,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'cublas'
  cuda-major-version: "12"
@@ -744,7 +816,7 @@ jobs:
  base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
  runs-on: 'ubuntu-24.04-arm'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  - build-type: 'hipblas'
  cuda-major-version: ""
@@ -756,7 +828,7 @@ jobs:
  runs-on: 'ubuntu-latest'
  skip-drivers: 'false'
  backend: "whisper"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  #silero-vad
  - build-type: ''
@@ -769,7 +841,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "silero-vad"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  # local-store
  - build-type: ''
@@ -782,7 +854,7 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "local-store"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
  # huggingface
  - build-type: ''
@@ -795,156 +867,8 @@ jobs:
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  backend: "huggingface"
- dockerfile: "./backend/Dockerfile.golang"
+ dockerfile: "./backend/Dockerfile.go"
  context: "./"
- # rfdetr
- - build-type: ''
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64,linux/arm64'
- tag-latest: 'auto'
- tag-suffix: '-cpu-rfdetr'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "rfdetr"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'cublas'
- cuda-major-version: "12"
- cuda-minor-version: "0"
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-12-rfdetr'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "rfdetr"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'cublas'
- cuda-major-version: "11"
- cuda-minor-version: "7"
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-11-rfdetr'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "rfdetr"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'intel'
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-intel-rfdetr'
- runs-on: 'ubuntu-latest'
- base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
- skip-drivers: 'false'
- backend: "rfdetr"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'cublas'
- cuda-major-version: "12"
- cuda-minor-version: "0"
- platforms: 'linux/arm64'
- skip-drivers: 'true'
- tag-latest: 'auto'
- tag-suffix: '-nvidia-l4t-arm64-rfdetr'
- base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
- runs-on: 'ubuntu-24.04-arm'
- backend: "rfdetr"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- # exllama2
- - build-type: ''
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-cpu-exllama2'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "exllama2"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'cublas'
- cuda-major-version: "12"
- cuda-minor-version: "0"
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-12-exllama2'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "exllama2"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'cublas'
- cuda-major-version: "11"
- cuda-minor-version: "7"
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-11-exllama2'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "exllama2"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'intel'
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64'
- tag-latest: 'auto'
- tag-suffix: '-gpu-intel-exllama2'
- runs-on: 'ubuntu-latest'
- base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
- skip-drivers: 'false'
- backend: "exllama2"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- - build-type: 'hipblas'
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64'
- skip-drivers: 'true'
- tag-latest: 'auto'
- tag-suffix: '-gpu-hipblas-exllama2'
- base-image: "rocm/dev-ubuntu-22.04:6.1"
- runs-on: 'ubuntu-latest'
- backend: "exllama2"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
- # runs out of space on the runner
- # - build-type: 'hipblas'
- # cuda-major-version: ""
- # cuda-minor-version: ""
- # platforms: 'linux/amd64'
- # tag-latest: 'auto'
- # tag-suffix: '-gpu-hipblas-rfdetr'
- # base-image: "rocm/dev-ubuntu-22.04:6.1"
- # runs-on: 'ubuntu-latest'
- # skip-drivers: 'false'
- # backend: "rfdetr"
- # dockerfile: "./backend/Dockerfile.python"
- # context: "./backend"
- # kitten-tts
- - build-type: ''
- cuda-major-version: ""
- cuda-minor-version: ""
- platforms: 'linux/amd64,linux/arm64'
- tag-latest: 'auto'
- tag-suffix: '-kitten-tts'
- runs-on: 'ubuntu-latest'
- base-image: "ubuntu:22.04"
- skip-drivers: 'false'
- backend: "kitten-tts"
- dockerfile: "./backend/Dockerfile.python"
- context: "./backend"
  llama-cpp-darwin:
  runs-on: macOS-14
  strategy:
@@ -980,7 +904,6 @@ jobs:
  path: build/llama-cpp.tar
  llama-cpp-darwin-publish:
  needs: llama-cpp-darwin
- if: github.event_name != 'pull_request'
  runs-on: ubuntu-latest
  steps:
  - name: Download llama-cpp.tar
@@ -1069,7 +992,6 @@ jobs:
  name: llama-cpp-tar-x86
  path: build/llama-cpp.tar
  llama-cpp-darwin-x86-publish:
- if: github.event_name != 'pull_request'
  needs: llama-cpp-darwin-x86
  runs-on: ubuntu-latest
  steps:
@@ -1123,4 +1045,4 @@ jobs:
  run: |
  for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do
  crane push llama-cpp.tar $tag
  done
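For reference, one of the new Intel entries in the backend.yml build matrix reads roughly as follows once the keys above are nested back into YAML (a sketch: the indentation and the surrounding include list are assumed, the values are taken from the hunks above):

```yaml
# Sketch of a single matrix entry after the SYCL split (indentation assumed)
- build-type: 'sycl_f32'
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-intel-sycl-f32-vllm'
  runs-on: 'ubuntu-latest'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  skip-drivers: 'false'
  backend: "vllm"
  dockerfile: "./backend/Dockerfile.python"
  context: "./backend"
```

Each `sycl_f32` entry is paired with a `sycl_f16` twin that differs only in `build-type` and the `tag-suffix`.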
.github/workflows/bump_deps.yaml (vendored), 2 changes

@@ -21,7 +21,7 @@ jobs:
  variable: "BARKCPP_VERSION"
  branch: "main"
  file: "Makefile"
- - repository: "leejet/stable-diffusion.cpp"
+ - repository: "richiejp/stable-diffusion.cpp"
  variable: "STABLEDIFFUSION_GGML_VERSION"
  branch: "master"
  file: "backend/go/stablediffusion-ggml/Makefile"
.github/workflows/image-pr.yml (vendored), 6 changes

@@ -39,7 +39,7 @@ jobs:
  cuda-minor-version: "0"
  platforms: 'linux/amd64'
  tag-latest: 'false'
- tag-suffix: '-gpu-nvidia-cuda-12'
+ tag-suffix: '-gpu-nvidia-cuda12'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  makeflags: "--jobs=3 --output-sync=target"
@@ -51,12 +51,12 @@ jobs:
  grpc-base-image: "ubuntu:22.04"
  runs-on: 'ubuntu-latest'
  makeflags: "--jobs=3 --output-sync=target"
- - build-type: 'sycl'
+ - build-type: 'sycl_f16'
  platforms: 'linux/amd64'
  tag-latest: 'false'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  grpc-base-image: "ubuntu:22.04"
- tag-suffix: 'sycl'
+ tag-suffix: 'sycl-f16'
  runs-on: 'ubuntu-latest'
  makeflags: "--jobs=3 --output-sync=target"
  - build-type: 'vulkan'
.github/workflows/image.yml (vendored), 21 changes

@@ -83,7 +83,7 @@ jobs:
  cuda-minor-version: "7"
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-11'
+ tag-suffix: '-gpu-nvidia-cuda11'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  makeflags: "--jobs=4 --output-sync=target"
@@ -94,7 +94,7 @@ jobs:
  cuda-minor-version: "0"
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-nvidia-cuda-12'
+ tag-suffix: '-gpu-nvidia-cuda12'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
@@ -103,21 +103,30 @@ jobs:
  - build-type: 'vulkan'
  platforms: 'linux/amd64'
  tag-latest: 'auto'
- tag-suffix: '-gpu-vulkan'
+ tag-suffix: '-vulkan'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:22.04"
  skip-drivers: 'false'
  makeflags: "--jobs=4 --output-sync=target"
  aio: "-aio-gpu-vulkan"
- - build-type: 'intel'
+ - build-type: 'sycl_f16'
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
  grpc-base-image: "ubuntu:22.04"
- tag-suffix: '-gpu-intel'
+ tag-suffix: '-gpu-intel-f16'
  runs-on: 'ubuntu-latest'
  makeflags: "--jobs=3 --output-sync=target"
- aio: "-aio-gpu-intel"
+ aio: "-aio-gpu-intel-f16"
+ - build-type: 'sycl_f32'
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ base-image: "quay.io/go-skynet/intel-oneapi-base:latest"
+ grpc-base-image: "ubuntu:22.04"
+ tag-suffix: '-gpu-intel-f32'
+ runs-on: 'ubuntu-latest'
+ makeflags: "--jobs=3 --output-sync=target"
+ aio: "-aio-gpu-intel-f32"

  gh-runner:
  uses: ./.github/workflows/image_build.yml
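The single `-gpu-intel` core image is replaced by FP16 and FP32 SYCL variants. Assuming the usual `latest` channel (the same tags appear in the README change further down), pulling them would look like this sketch:

```bash
# SYCL core images (tag names assumed from the new suffixes)
docker pull localai/localai:latest-gpu-intel-f16
docker pull localai/localai:latest-gpu-intel-f32
# AIO counterparts, from the new aio suffixes
docker pull localai/localai:latest-aio-gpu-intel-f16
docker pull localai/localai:latest-aio-gpu-intel-f32
```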
.github/workflows/test.yml (vendored), 14 changes

@@ -23,20 +23,6 @@ jobs:
  matrix:
  go-version: ['1.21.x']
  steps:
- - name: Free Disk Space (Ubuntu)
- uses: jlumbroso/free-disk-space@main
- with:
- # this might remove tools that are actually needed,
- # if set to "true" but frees about 6 GB
- tool-cache: true
- # all of these default to true, but feel free to set to
- # "false" if necessary for your workflow
- android: true
- dotnet: true
- haskell: true
- large-packages: true
- docker-images: true
- swap-storage: true
  - name: Release space from worker
  run: |
  echo "Listing top largest packages"
.gitignore (vendored), 1 change

@@ -12,7 +12,6 @@ prepare-sources
  /backends
  /backend-images
  /result.yaml
- protoc

  *.log

Dockerfile, 10 changes

@@ -9,7 +9,7 @@ ENV DEBIAN_FRONTEND=noninteractive
  RUN apt-get update && \
  apt-get install -y --no-install-recommends \
  ca-certificates curl wget espeak-ng libgomp1 \
- python3 python-is-python3 ffmpeg libopenblas-base libopenblas-dev && \
+ python3 python-is-python3 ffmpeg && \
  apt-get clean && \
  rm -rf /var/lib/apt/lists/*

@@ -72,12 +72,6 @@ RUN <<EOT bash
  fi
  EOT

- RUN <<EOT bash
- if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
- echo "nvidia-l4t" > /run/localai/capability
- fi
- EOT
-
  # If we are building with clblas support, we need the libraries for the builds
  RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
  apt-get update && \
@@ -100,8 +94,6 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
  ldconfig \
  ; fi

- RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
-
  # Cuda
  ENV PATH=/usr/local/cuda/bin:${PATH}

Makefile, 178 changes

@@ -5,6 +5,8 @@ BINARY_NAME=local-ai

  GORELEASER?=

+ ONEAPI_VERSION?=2025.2
+
  export BUILD_TYPE?=

  GO_TAGS?=
@@ -132,9 +134,6 @@ test: test-models/testmodel.ggml protogen-go
  $(MAKE) test-tts
  $(MAKE) test-stablediffusion

- backends/diffusers: docker-build-diffusers docker-save-diffusers build
- ./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)"
-
  backends/llama-cpp: docker-build-llama-cpp docker-save-llama-cpp build
  ./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"

@@ -146,7 +145,7 @@ backends/stablediffusion-ggml: docker-build-stablediffusion-ggml docker-save-sta

  backends/whisper: docker-build-whisper docker-save-whisper build
  ./local-ai backends install "ocifile://$(abspath ./backend-images/whisper.tar)"

  backends/silero-vad: docker-build-silero-vad docker-save-silero-vad build
  ./local-ai backends install "ocifile://$(abspath ./backend-images/silero-vad.tar)"

@@ -156,15 +155,6 @@ backends/local-store: docker-build-local-store docker-save-local-store build
  backends/huggingface: docker-build-huggingface docker-save-huggingface build
  ./local-ai backends install "ocifile://$(abspath ./backend-images/huggingface.tar)"

- backends/rfdetr: docker-build-rfdetr docker-save-rfdetr build
- ./local-ai backends install "ocifile://$(abspath ./backend-images/rfdetr.tar)"
-
- backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
- ./local-ai backends install "ocifile://$(abspath ./backend-images/kitten-tts.tar)"
-
- backends/kokoro: docker-build-kokoro docker-save-kokoro build
- ./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
-
  ########################################################
  ## AIO tests
  ########################################################
@@ -252,7 +242,10 @@ help: ## Show this help.
  ########################################################

  .PHONY: protogen
- protogen: protogen-go
+ protogen: protogen-go protogen-python

+ .PHONY: protogen-clean
+ protogen-clean: protogen-go-clean protogen-python-clean
+
  protoc:
  @OS_NAME=$$(uname -s | tr '[:upper:]' '[:lower:]'); \
@@ -297,6 +290,93 @@ protogen-go-clean:
  $(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
  $(RM) bin/*

+ .PHONY: protogen-python
+ protogen-python: bark-protogen coqui-protogen chatterbox-protogen diffusers-protogen exllama2-protogen rerankers-protogen transformers-protogen kokoro-protogen vllm-protogen faster-whisper-protogen
+
+ .PHONY: protogen-python-clean
+ protogen-python-clean: bark-protogen-clean coqui-protogen-clean chatterbox-protogen-clean diffusers-protogen-clean exllama2-protogen-clean rerankers-protogen-clean transformers-protogen-clean kokoro-protogen-clean vllm-protogen-clean faster-whisper-protogen-clean
+
+ .PHONY: bark-protogen
+ bark-protogen:
+ $(MAKE) -C backend/python/bark protogen
+
+ .PHONY: bark-protogen-clean
+ bark-protogen-clean:
+ $(MAKE) -C backend/python/bark protogen-clean
+
+ .PHONY: coqui-protogen
+ coqui-protogen:
+ $(MAKE) -C backend/python/coqui protogen
+
+ .PHONY: coqui-protogen-clean
+ coqui-protogen-clean:
+ $(MAKE) -C backend/python/coqui protogen-clean
+
+ .PHONY: diffusers-protogen
+ diffusers-protogen:
+ $(MAKE) -C backend/python/diffusers protogen
+
+ .PHONY: chatterbox-protogen
+ chatterbox-protogen:
+ $(MAKE) -C backend/python/chatterbox protogen
+
+ .PHONY: diffusers-protogen-clean
+ diffusers-protogen-clean:
+ $(MAKE) -C backend/python/diffusers protogen-clean
+
+ .PHONY: chatterbox-protogen-clean
+ chatterbox-protogen-clean:
+ $(MAKE) -C backend/python/chatterbox protogen-clean
+
+ .PHONY: faster-whisper-protogen
+ faster-whisper-protogen:
+ $(MAKE) -C backend/python/faster-whisper protogen
+
+ .PHONY: faster-whisper-protogen-clean
+ faster-whisper-protogen-clean:
+ $(MAKE) -C backend/python/faster-whisper protogen-clean
+
+ .PHONY: exllama2-protogen
+ exllama2-protogen:
+ $(MAKE) -C backend/python/exllama2 protogen
+
+ .PHONY: exllama2-protogen-clean
+ exllama2-protogen-clean:
+ $(MAKE) -C backend/python/exllama2 protogen-clean
+
+ .PHONY: rerankers-protogen
+ rerankers-protogen:
+ $(MAKE) -C backend/python/rerankers protogen
+
+ .PHONY: rerankers-protogen-clean
+ rerankers-protogen-clean:
+ $(MAKE) -C backend/python/rerankers protogen-clean
+
+ .PHONY: transformers-protogen
+ transformers-protogen:
+ $(MAKE) -C backend/python/transformers protogen
+
+ .PHONY: transformers-protogen-clean
+ transformers-protogen-clean:
+ $(MAKE) -C backend/python/transformers protogen-clean
+
+ .PHONY: kokoro-protogen
+ kokoro-protogen:
+ $(MAKE) -C backend/python/kokoro protogen
+
+ .PHONY: kokoro-protogen-clean
+ kokoro-protogen-clean:
+ $(MAKE) -C backend/python/kokoro protogen-clean
+
+ .PHONY: vllm-protogen
+ vllm-protogen:
+ $(MAKE) -C backend/python/vllm protogen
+
+ .PHONY: vllm-protogen-clean
+ vllm-protogen-clean:
+ $(MAKE) -C backend/python/vllm protogen-clean
+
+
  prepare-test-extra: protogen-python
  $(MAKE) -C backend/python/transformers
  $(MAKE) -C backend/python/diffusers
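With the Python protobuf targets back in the top-level Makefile, regenerating the gRPC stubs could be driven like this (a sketch; the targets are exactly the ones defined in the hunk above):

```bash
# Go and Python stubs in one pass
make protogen
# Only the Python backends, or a single backend
make protogen-python
make -C backend/python/vllm protogen
# Remove generated stubs
make protogen-clean
```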
@@ -332,7 +412,7 @@ docker-cuda11:
  --build-arg GO_TAGS="$(GO_TAGS)" \
  --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
  --build-arg BUILD_TYPE=$(BUILD_TYPE) \
- -t $(DOCKER_IMAGE)-cuda-11 .
+ -t $(DOCKER_IMAGE)-cuda11 .

  docker-aio:
  @echo "Building AIO image with base $(BASE_IMAGE) as $(DOCKER_AIO_IMAGE)"
@@ -347,11 +427,19 @@ docker-aio-all:

  docker-image-intel:
  docker build \
- --build-arg BASE_IMAGE=quay.io/go-skynet/intel-oneapi-base:latest \
+ --build-arg BASE_IMAGE=intel/oneapi-basekit:${ONEAPI_VERSION}.0-0-devel-ubuntu24.04 \
  --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
  --build-arg GO_TAGS="$(GO_TAGS)" \
  --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
- --build-arg BUILD_TYPE=intel -t $(DOCKER_IMAGE) .
+ --build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .

+ docker-image-intel-xpu:
+ docker build \
+ --build-arg BASE_IMAGE=intel/oneapi-basekit:${ONEAPI_VERSION}.0-0-devel-ubuntu22.04 \
+ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ --build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
+
  ########################################################
  ## Backends
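The Intel image targets now build straight from an intel/oneapi-basekit base whose tag is derived from the new `ONEAPI_VERSION` variable (default `2025.2`). A hedged usage sketch:

```bash
# SYCL (f32) core image on the Ubuntu 24.04 oneAPI base
make docker-image-intel
# Ubuntu 22.04 variant, optionally pinning the oneAPI version
make docker-image-intel-xpu ONEAPI_VERSION=2025.2
```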
@@ -361,37 +449,19 @@ backend-images:
  mkdir -p backend-images

  docker-build-llama-cpp:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:llama-cpp -f backend/Dockerfile.llama-cpp .
+ docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg IMAGE_BASE=$(IMAGE_BASE) -t local-ai-backend:llama-cpp -f backend/Dockerfile.llama-cpp .

  docker-build-bark-cpp:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark-cpp -f backend/Dockerfile.golang --build-arg BACKEND=bark-cpp .
+ docker build -t local-ai-backend:bark-cpp -f backend/Dockerfile.go --build-arg BACKEND=bark-cpp .

  docker-build-piper:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:piper -f backend/Dockerfile.golang --build-arg BACKEND=piper .
+ docker build -t local-ai-backend:piper -f backend/Dockerfile.go --build-arg BACKEND=piper .

  docker-build-local-store:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:local-store -f backend/Dockerfile.golang --build-arg BACKEND=local-store .
+ docker build -t local-ai-backend:local-store -f backend/Dockerfile.go --build-arg BACKEND=local-store .

  docker-build-huggingface:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:huggingface -f backend/Dockerfile.golang --build-arg BACKEND=huggingface .
+ docker build -t local-ai-backend:huggingface -f backend/Dockerfile.go --build-arg BACKEND=huggingface .

- docker-build-rfdetr:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:rfdetr -f backend/Dockerfile.python --build-arg BACKEND=rfdetr ./backend
-
- docker-build-kitten-tts:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kitten-tts -f backend/Dockerfile.python --build-arg BACKEND=kitten-tts ./backend
-
- docker-save-kitten-tts: backend-images
- docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar
-
- docker-build-kokoro:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend
-
- docker-save-kokoro: backend-images
- docker save local-ai-backend:kokoro -o backend-images/kokoro.tar
-
- docker-save-rfdetr: backend-images
- docker save local-ai-backend:rfdetr -o backend-images/rfdetr.tar
-
  docker-save-huggingface: backend-images
  docker save local-ai-backend:huggingface -o backend-images/huggingface.tar
@@ -400,7 +470,7 @@ docker-save-local-store: backend-images
  docker save local-ai-backend:local-store -o backend-images/local-store.tar

  docker-build-silero-vad:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:silero-vad -f backend/Dockerfile.golang --build-arg BACKEND=silero-vad .
+ docker build -t local-ai-backend:silero-vad -f backend/Dockerfile.go --build-arg BACKEND=silero-vad .

  docker-save-silero-vad: backend-images
  docker save local-ai-backend:silero-vad -o backend-images/silero-vad.tar
@@ -415,46 +485,46 @@ docker-save-bark-cpp: backend-images
  docker save local-ai-backend:bark-cpp -o backend-images/bark-cpp.tar

  docker-build-stablediffusion-ggml:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:stablediffusion-ggml -f backend/Dockerfile.golang --build-arg BACKEND=stablediffusion-ggml .
+ docker build -t local-ai-backend:stablediffusion-ggml -f backend/Dockerfile.go --build-arg BACKEND=stablediffusion-ggml .

  docker-save-stablediffusion-ggml: backend-images
  docker save local-ai-backend:stablediffusion-ggml -o backend-images/stablediffusion-ggml.tar

  docker-build-rerankers:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:rerankers -f backend/Dockerfile.python --build-arg BACKEND=rerankers .
+ docker build -t local-ai-backend:rerankers -f backend/Dockerfile.python --build-arg BACKEND=rerankers .

  docker-build-vllm:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:vllm -f backend/Dockerfile.python --build-arg BACKEND=vllm .
+ docker build -t local-ai-backend:vllm -f backend/Dockerfile.python --build-arg BACKEND=vllm .

  docker-build-transformers:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:transformers -f backend/Dockerfile.python --build-arg BACKEND=transformers .
+ docker build -t local-ai-backend:transformers -f backend/Dockerfile.python --build-arg BACKEND=transformers .

  docker-build-diffusers:
- docker build --progress=plain --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:diffusers -f backend/Dockerfile.python --build-arg BACKEND=diffusers ./backend
+ docker build -t local-ai-backend:diffusers -f backend/Dockerfile.python --build-arg BACKEND=diffusers .

- docker-save-diffusers: backend-images
+ docker-build-kokoro:
- docker save local-ai-backend:diffusers -o backend-images/diffusers.tar
+ docker build -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro .

  docker-build-whisper:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
+ docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.go --build-arg BACKEND=whisper .

  docker-save-whisper: backend-images
  docker save local-ai-backend:whisper -o backend-images/whisper.tar

  docker-build-faster-whisper:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:faster-whisper -f backend/Dockerfile.python --build-arg BACKEND=faster-whisper .
+ docker build -t local-ai-backend:faster-whisper -f backend/Dockerfile.python --build-arg BACKEND=faster-whisper .

  docker-build-coqui:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:coqui -f backend/Dockerfile.python --build-arg BACKEND=coqui .
+ docker build -t local-ai-backend:coqui -f backend/Dockerfile.python --build-arg BACKEND=coqui .

  docker-build-bark:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .
+ docker build -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .

  docker-build-chatterbox:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .
+ docker build -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .

  docker-build-exllama2:
- docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .
+ docker build -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .

  docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-exllama2

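Taken together with the `backends/<name>` install targets earlier in the Makefile, a local build-and-install round trip for a backend would look roughly like this (a sketch; it assumes the `local-ai` binary produced by the `build` prerequisite sits in the working directory):

```bash
# One-shot: build the whisper backend image, save it to backend-images/whisper.tar,
# and install it into LocalAI from the local OCI tarball
make backends/whisper

# Equivalent manual steps
make docker-build-whisper docker-save-whisper
./local-ai backends install "ocifile://$PWD/backend-images/whisper.tar"
```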
README.md (13 changed lines)

@@ -140,7 +140,11 @@ docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri
 ### Intel GPU Images (oneAPI):
 
 ```bash
-docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
+# Intel GPU with FP16 support
+docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel-f16
+
+# Intel GPU with FP32 support
+docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel-f32
 ```
 
 ### Vulkan GPU Images:

@@ -162,7 +166,7 @@ docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-ai
 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-aio-gpu-nvidia-cuda-11
 
 # Intel GPU version
-docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-gpu-intel
+docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-gpu-intel-f16
 
 # AMD GPU version
 docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-aio-gpu-hipblas

@@ -185,14 +189,10 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
 local-ai run oci://localai/phi-2:latest
 ```
 
-> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
-
 For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
 
 ## 📰 Latest project news
 
-- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
-- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
 - June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
 - May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
 - May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)

@@ -225,7 +225,6 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
 - ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
 - 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
 - 🥽 [Vision API](https://localai.io/features/gpt-vision/)
-- 🔍 [Object Detection](https://localai.io/features/object-detection/)
 - 📈 [Reranker API](https://localai.io/features/reranker/)
 - 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
 - [Agentic capabilities](https://github.com/mudler/LocalAGI)
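For context on the README commands above: the images are all OpenAI-API compatible, so a typical smoke test is to start a container and query the chat endpoint with curl. A minimal sketch (the image tag and model name here are illustrative and not part of this diff):

```bash
# start the CPU image
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest

# once a model (e.g. phi-2 from the gallery) is loaded, call the OpenAI-compatible API
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "phi-2", "messages": [{"role": "user", "content": "Hello"}]}'
```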
@@ -96,6 +96,17 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
     ldconfig \
     ; fi
 
+# Intel oneAPI requirements
+RUN <<EOT bash
+    if [[ "${BUILD_TYPE}" == sycl* ]] && [ "${SKIP_DRIVERS}" = "false" ]; then
+        apt-get update && \
+        apt-get install -y --no-install-recommends \
+            intel-oneapi-runtime-libs && \
+        apt-get clean && \
+        rm -rf /var/lib/apt/lists/*
+    fi
+EOT
+
 # Install Go
 RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
 ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin:/usr/local/bin
@@ -11,6 +11,7 @@ ARG GRPC_MAKEFLAGS="-j4 -Otarget"
 ARG GRPC_VERSION=v1.65.0
 ARG CMAKE_FROM_SOURCE=false
 ARG CMAKE_VERSION=3.26.4
+ARG PROTOBUF_VERSION=v21.12
 
 ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
 

@@ -49,6 +50,14 @@ RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shall
     make install && \
     rm -rf /build
 
+RUN git clone --recurse-submodules --branch ${PROTOBUF_VERSION} https://github.com/protocolbuffers/protobuf.git && \
+    mkdir -p /build/protobuf/build && \
+    cd /build/protobuf/build && \
+    cmake -Dprotobuf_BUILD_SHARED_LIBS=ON -Dprotobuf_BUILD_TESTS=OFF .. && \
+    make && \
+    make install && \
+    rm -rf /build
+
 FROM ${BASE_IMAGE} AS builder
 ARG BACKEND=rerankers
 ARG BUILD_TYPE

@@ -180,21 +189,9 @@ COPY --from=grpc /opt/grpc /usr/local
 
 COPY . /LocalAI
 
-## Otherwise just run the normal build
-RUN <<EOT bash
-if [ "${TARGETARCH}" = "arm64" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then \
-    cd /LocalAI/backend/cpp/llama-cpp && make llama-cpp-fallback && \
-    make llama-cpp-grpc && make llama-cpp-rpc-server; \
-else \
-    cd /LocalAI/backend/cpp/llama-cpp && make llama-cpp-avx && \
-    make llama-cpp-avx2 && \
-    make llama-cpp-avx512 && \
-    make llama-cpp-fallback && \
-    make llama-cpp-grpc && \
-    make llama-cpp-rpc-server; \
-fi
-EOT
-
+RUN make -C /LocalAI/backend/cpp/llama-cpp llama-cpp
+RUN make -C /LocalAI/backend/cpp/llama-cpp llama-cpp-grpc
+RUN make -C /LocalAI/backend/cpp/llama-cpp llama-cpp-rpc-server
 
 # Copy libraries using a script to handle architecture differences
 RUN make -C /LocalAI/backend/cpp/llama-cpp package
@@ -20,7 +20,6 @@ service Backend {
   rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
   rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
   rpc Status(HealthMessage) returns (StatusResponse) {}
-  rpc Detect(DetectOptions) returns (DetectResponse) {}
 
   rpc StoresSet(StoresSetOptions) returns (Result) {}
   rpc StoresDelete(StoresDeleteOptions) returns (Result) {}

@@ -305,9 +304,6 @@ message GenerateImageRequest {
   // Diffusers
   string EnableParameters = 10;
   int32 CLIPSkip = 11;
-
-  // Reference images for models that support them (e.g., Flux Kontext)
-  repeated string ref_images = 12;
 }
 
 message GenerateVideoRequest {

@@ -380,20 +376,3 @@ message Message {
   string role = 1;
   string content = 2;
 }
-
-message DetectOptions {
-  string src = 1;
-}
-
-message Detection {
-  float x = 1;
-  float y = 2;
-  float width = 3;
-  float height = 4;
-  float confidence = 5;
-  string class_name = 6;
-}
-
-message DetectResponse {
-  repeated Detection Detections = 1;
-}
@@ -17,6 +17,8 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     include_directories("${HOMEBREW_DEFAULT_PREFIX}/include")
 endif()
 
+set(Protobuf_USE_STATIC_LIBS OFF)
+set(gRPC_USE_STATIC_LIBS OFF)
 find_package(absl CONFIG REQUIRED)
 find_package(Protobuf CONFIG REQUIRED)
 find_package(gRPC CONFIG REQUIRED)
@@ -1,5 +1,5 @@
 
-LLAMA_VERSION?=be48528b068111304e4a0bb82c028558b5705f05
+LLAMA_VERSION?=acd6cb1c41676f6bbb25c2a76fa5abeb1719301e
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
 
 CMAKE_ARGS?=

@@ -7,10 +7,9 @@ BUILD_TYPE?=
 NATIVE?=false
 ONEAPI_VARS?=/opt/intel/oneapi/setvars.sh
 TARGET?=--target grpc-server
-JOBS?=$(shell nproc)
 
 # Disable Shared libs as we are linking on static gRPC and we can't mix shared and static
-CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
+CMAKE_ARGS+=-DBUILD_SHARED_LIBS=ON -DLLAMA_CURL=OFF -DGGML_CPU_ALL_VARIANTS=ON -DGGML_BACKEND_DL=ON
 
 CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
 ifeq ($(NATIVE),false)

@@ -26,7 +25,7 @@ else ifeq ($(BUILD_TYPE),openblas)
 # If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
 else ifeq ($(BUILD_TYPE),clblas)
 	CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
 # If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
 else ifeq ($(BUILD_TYPE),hipblas)
 	ROCM_HOME ?= /opt/rocm
 	ROCM_PATH ?= /opt/rocm

@@ -90,33 +89,12 @@ else
 	LLAMA_VERSION=$(LLAMA_VERSION) $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(VARIANT) grpc-server
 endif
 
-llama-cpp-avx2: llama.cpp
-	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build
-	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build purge
-	$(info ${GREEN}I llama-cpp build info:avx2${RESET})
-	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx2-build" build-llama-cpp-grpc-server
-	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build/grpc-server llama-cpp-avx2
-
-llama-cpp-avx512: llama.cpp
-	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build
-	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build purge
-	$(info ${GREEN}I llama-cpp build info:avx512${RESET})
-	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx512-build" build-llama-cpp-grpc-server
-	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build/grpc-server llama-cpp-avx512
-
-llama-cpp-avx: llama.cpp
-	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build
-	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build purge
-	$(info ${GREEN}I llama-cpp build info:avx${RESET})
-	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) VARIANT="llama-cpp-avx-build" build-llama-cpp-grpc-server
-	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build/grpc-server llama-cpp-avx
-
-llama-cpp-fallback: llama.cpp
-	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build
-	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build purge
-	$(info ${GREEN}I llama-cpp build info:fallback${RESET})
-	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) VARIANT="llama-cpp-fallback-build" build-llama-cpp-grpc-server
-	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build/grpc-server llama-cpp-fallback
-
+llama-cpp: llama.cpp
+	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-build
+	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-build purge
+	$(info ${GREEN}I llama-cpp build info:${RESET})
+	CMAKE_ARGS="$(CMAKE_ARGS)" $(MAKE) VARIANT="llama-cpp-build" build-llama-cpp-grpc-server
+	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-build/grpc-server llama-cpp
 
 llama-cpp-grpc: llama.cpp
 	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build

@@ -161,8 +139,8 @@ grpc-server: llama.cpp llama.cpp/tools/grpc-server
 	@echo "Building grpc-server with $(BUILD_TYPE) build type and $(CMAKE_ARGS)"
 ifneq (,$(findstring sycl,$(BUILD_TYPE)))
 	+bash -c "source $(ONEAPI_VARS); \
-		cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)"
+		cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release $(TARGET)"
 else
-	+cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)
+	+cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release $(TARGET)
 endif
 	cp llama.cpp/build/bin/grpc-server .
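The Makefile change above drops the per-CPU-variant binaries (avx/avx2/avx512/fallback) in favour of a single `llama-cpp` build configured with `-DGGML_CPU_ALL_VARIANTS=ON -DGGML_BACKEND_DL=ON`, which builds the CPU variants as loadable ggml backends selected at runtime instead of at image-build time. A rough sketch of the equivalent manual configure step, using the same flags and target that appear in the diff (this is illustrative, not a verbatim recipe from the repo):

```bash
cd llama.cpp && mkdir -p build && cd build
cmake .. -DBUILD_SHARED_LIBS=ON -DLLAMA_CURL=OFF \
         -DGGML_CPU_ALL_VARIANTS=ON -DGGML_BACKEND_DL=ON
cmake --build . --config Release --target grpc-server
```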
@@ -313,11 +313,9 @@ static void params_parse(const backend::ModelOptions* request,
         params.pooling_type = LLAMA_POOLING_TYPE_RANK;
     }
 
-
     if (request->ropescaling() == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
     else if (request->ropescaling() == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
-    else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
+    else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
 
     if ( request->yarnextfactor() != 0.0f ) {
         params.yarn_ext_factor = request->yarnextfactor();
     }
@@ -6,34 +6,9 @@ CURDIR=$(dirname "$(realpath $0)")
 
 cd /
 
-echo "CPU info:"
-grep -e "model\sname" /proc/cpuinfo | head -1
-grep -e "flags" /proc/cpuinfo | head -1
-
-BINARY=llama-cpp-fallback
-
-if grep -q -e "\savx\s" /proc/cpuinfo ; then
-    echo "CPU: AVX found OK"
-    if [ -e $CURDIR/llama-cpp-avx ]; then
-        BINARY=llama-cpp-avx
-    fi
-fi
-
-if grep -q -e "\savx2\s" /proc/cpuinfo ; then
-    echo "CPU: AVX2 found OK"
-    if [ -e $CURDIR/llama-cpp-avx2 ]; then
-        BINARY=llama-cpp-avx2
-    fi
-fi
-
-# Check avx 512
-if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
-    echo "CPU: AVX512F found OK"
-    if [ -e $CURDIR/llama-cpp-avx512 ]; then
-        BINARY=llama-cpp-avx512
-    fi
-fi
-
+BINARY=llama-cpp
+
+## P2P/GRPC mode
 if [ -n "$LLAMACPP_GRPC_SERVERS" ]; then
     if [ -e $CURDIR/llama-cpp-grpc ]; then
         BINARY=llama-cpp-grpc

@@ -56,6 +31,3 @@ fi
 
 echo "Using binary: $BINARY"
 exec $CURDIR/$BINARY "$@"
-
-# In case we fail execing, just run fallback
-exec $CURDIR/llama-cpp-fallback "$@"
@@ -18,11 +18,11 @@ GO_TAGS?=
 LD_FLAGS?=
 
 # stablediffusion.cpp (ggml)
-STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
-STABLEDIFFUSION_GGML_VERSION?=5900ef6605c6fbf7934239f795c13c97bc993853
+STABLEDIFFUSION_GGML_REPO?=https://github.com/richiejp/stable-diffusion.cpp
+STABLEDIFFUSION_GGML_VERSION?=53e3b17eb3d0b5760ced06a1f98320b68b34aaae
 
 # Disable Shared libs as we are linking on static gRPC and we can't mix shared and static
-CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DGGML_MAX_NAME=128 -DSD_USE_SYSTEM_GGML=OFF
+CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
 
 ifeq ($(NATIVE),false)
 	CMAKE_ARGS+=-DGGML_NATIVE=OFF

@@ -91,18 +91,23 @@ endif
 # (ggml can have different backends cpu, cuda, etc., each backend generates a .a archive)
 GGML_ARCHIVE_DIR := build/ggml/src/
 ALL_ARCHIVES := $(shell find $(GGML_ARCHIVE_DIR) -type f -name '*.a')
-ALL_OBJS := $(shell find $(GGML_ARCHIVE_DIR) -type f -name '*.o')
 
 # Name of the single merged library
 COMBINED_LIB := libggmlall.a
 
-# Instead of using the archives generated by GGML, use the object files directly to avoid overwriting objects with the same base name
+# Rule to merge all the .a files into one
 $(COMBINED_LIB): $(ALL_ARCHIVES)
-	@echo "Merging all .o into $(COMBINED_LIB): $(ALL_OBJS)"
+	@echo "Merging all .a into $(COMBINED_LIB)"
 	rm -f $@
-	ar -qc $@ $(ALL_OBJS)
+	mkdir -p merge-tmp
+	for a in $(ALL_ARCHIVES); do \
+	  ( cd merge-tmp && ar x ../$$a ); \
+	done
+	( cd merge-tmp && ar rcs ../$@ *.o )
 	# Ensure we have a proper index
 	ranlib $@
+	# Clean up
+	rm -rf merge-tmp
 
 build/libstable-diffusion.a:
 	@echo "Building SD with $(BUILD_TYPE) build type and $(CMAKE_ARGS)"
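The two recipes above both produce a single `libggmlall.a` from the per-backend ggml archives; the removed comment explains why one side packs `.o` files directly: extracting archives with `ar x` into one flat directory lets members with the same base name (for example a `ggml.o` present in two backend archives) overwrite each other. A hedged sketch of that failure mode and the object-file workaround, outside of make (paths illustrative):

```bash
# extracting two archives that both contain ggml.o keeps only the last one
mkdir -p merge-tmp && cd merge-tmp
ar x ../libggml-cpu.a && ar x ../libggml-cuda.a   # second ggml.o clobbers the first

# packing the compiler's object files directly avoids the collision
find build/ggml/src -type f -name '*.o' | xargs ar -qc libggmlall.a
ranlib libggmlall.a
```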
@@ -1,5 +1,3 @@
-#define GGML_MAX_NAME 128
-
 #include <stdio.h>
 #include <string.h>
 #include <time.h>

@@ -7,7 +5,6 @@
 #include <random>
 #include <string>
 #include <vector>
-#include <filesystem>
 #include "gosd.h"
 
 // #include "preprocessing.hpp"
@@ -56,43 +53,9 @@ sd_ctx_t* sd_c;
 
 sample_method_t sample_method;
 
-// Copied from the upstream CLI
-void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
-    //SDParams* params = (SDParams*)data;
-    const char* level_str;
-
-    if (!log /*|| (!params->verbose && level <= SD_LOG_DEBUG)*/) {
-        return;
-    }
-
-    switch (level) {
-        case SD_LOG_DEBUG:
-            level_str = "DEBUG";
-            break;
-        case SD_LOG_INFO:
-            level_str = "INFO";
-            break;
-        case SD_LOG_WARN:
-            level_str = "WARN";
-            break;
-        case SD_LOG_ERROR:
-            level_str = "ERROR";
-            break;
-        default: /* Potential future-proofing */
-            level_str = "?????";
-            break;
-    }
-
-    fprintf(stderr, "[%-5s] ", level_str);
-    fputs(log, stderr);
-    fflush(stderr);
-}
-
-int load_model(char *model, char *model_path, char* options[], int threads, int diff) {
+int load_model(char *model, char* options[], int threads, int diff) {
     fprintf (stderr, "Loading model!\n");
 
-    sd_set_log_callback(sd_log_cb, NULL);
-
     char *stableDiffusionModel = "";
     if (diff == 1 ) {
         stableDiffusionModel = model;

@@ -106,10 +69,6 @@ int load_model(char *model, char *model_path, char* options[], int threads, int
     char *vae_path = "";
     char *scheduler = "";
     char *sampler = "";
-    char *lora_dir = model_path;
-    bool lora_dir_allocated = false;
 
-    fprintf(stderr, "parsing options\n");
-
     // If options is not NULL, parse options
     for (int i = 0; options[i] != NULL; i++) {

@@ -137,29 +96,12 @@ int load_model(char *model, char *model_path, char* options[], int threads, int
         if (!strcmp(optname, "sampler")) {
             sampler = optval;
         }
-        if (!strcmp(optname, "lora_dir")) {
-            // Path join with model dir
-            if (model_path && strlen(model_path) > 0) {
-                std::filesystem::path model_path_str(model_path);
-                std::filesystem::path lora_path(optval);
-                std::filesystem::path full_lora_path = model_path_str / lora_path;
-                lora_dir = strdup(full_lora_path.string().c_str());
-                lora_dir_allocated = true;
-                fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir);
-            } else {
-                lora_dir = optval;
-                fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
-            }
-        }
     }
 
-    fprintf(stderr, "parsed options\n");
-
     int sample_method_found = -1;
-    for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
+    for (int m = 0; m < N_SAMPLE_METHODS; m++) {
         if (!strcmp(sampler, sample_method_str[m])) {
             sample_method_found = m;
-            fprintf(stderr, "Found sampler: %s\n", sampler);
         }
     }
     if (sample_method_found == -1) {

@@ -169,7 +111,7 @@ int load_model(char *model, char *model_path, char* options[], int threads, int
     sample_method = (sample_method_t)sample_method_found;
 
     int schedule_found = -1;
-    for (int d = 0; d < SCHEDULE_COUNT; d++) {
+    for (int d = 0; d < N_SCHEDULES; d++) {
         if (!strcmp(scheduler, schedule_str[d])) {
             schedule_found = d;
             fprintf (stderr, "Found scheduler: %s\n", scheduler);

@@ -183,50 +125,43 @@ int load_model(char *model, char *model_path, char* options[], int threads, int
     }
 
     schedule_t schedule = (schedule_t)schedule_found;
 
     fprintf (stderr, "Creating context\n");
-    sd_ctx_params_t ctx_params;
-    sd_ctx_params_init(&ctx_params);
-    ctx_params.model_path = model;
-    ctx_params.clip_l_path = clip_l_path;
-    ctx_params.clip_g_path = clip_g_path;
-    ctx_params.t5xxl_path = t5xxl_path;
-    ctx_params.diffusion_model_path = stableDiffusionModel;
-    ctx_params.vae_path = vae_path;
-    ctx_params.taesd_path = "";
-    ctx_params.control_net_path = "";
-    ctx_params.lora_model_dir = lora_dir;
-    ctx_params.embedding_dir = "";
-    ctx_params.stacked_id_embed_dir = "";
-    ctx_params.vae_decode_only = false;
-    ctx_params.vae_tiling = false;
-    ctx_params.free_params_immediately = false;
-    ctx_params.n_threads = threads;
-    ctx_params.rng_type = STD_DEFAULT_RNG;
-    ctx_params.schedule = schedule;
-    sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
+    sd_ctx_t* sd_ctx = new_sd_ctx(model,
+                                  clip_l_path,
+                                  clip_g_path,
+                                  t5xxl_path,
+                                  stableDiffusionModel,
+                                  vae_path,
+                                  "",
+                                  "",
+                                  "",
+                                  "",
+                                  "",
+                                  false,
+                                  false,
+                                  false,
+                                  threads,
+                                  SD_TYPE_COUNT,
+                                  STD_DEFAULT_RNG,
+                                  schedule,
+                                  false,
+                                  false,
+                                  false,
+                                  false);
 
     if (sd_ctx == NULL) {
         fprintf (stderr, "failed loading model (generic error)\n");
-        // Clean up allocated memory
-        if (lora_dir_allocated && lora_dir) {
-            free(lora_dir);
-        }
         return 1;
     }
     fprintf (stderr, "Created context: OK\n");
 
     sd_c = sd_ctx;
 
-    // Clean up allocated memory
-    if (lora_dir_allocated && lora_dir) {
-        free(lora_dir);
-    }
-
     return 0;
 }
 
-int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
+int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
 
     sd_image_t* results;
 

@@ -234,202 +169,37 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
 
     fprintf (stderr, "Generating image\n");
 
-    sd_img_gen_params_t p;
-    sd_img_gen_params_init(&p);
-
-    p.prompt = text;
-    p.negative_prompt = negativeText;
-    p.guidance.txt_cfg = cfg_scale;
-    p.guidance.slg.layers = skip_layers.data();
-    p.guidance.slg.layer_count = skip_layers.size();
-    p.width = width;
-    p.height = height;
-    p.sample_method = sample_method;
-    p.sample_steps = steps;
-    p.seed = seed;
-    p.input_id_images_path = "";
-
-    // Handle input image for img2img
-    bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
-    bool has_mask_image = (mask_image != NULL && strlen(mask_image) > 0);
-
-    uint8_t* input_image_buffer = NULL;
-    uint8_t* mask_image_buffer = NULL;
-    std::vector<uint8_t> default_mask_image_vec;
-
-    if (has_input_image) {
-        fprintf(stderr, "Loading input image: %s\n", src_image);
-
-        int c = 0;
-        int img_width = 0;
-        int img_height = 0;
-        input_image_buffer = stbi_load(src_image, &img_width, &img_height, &c, 3);
-        if (input_image_buffer == NULL) {
-            fprintf(stderr, "Failed to load input image from '%s'\n", src_image);
-            return 1;
-        }
-        if (c < 3) {
-            fprintf(stderr, "Input image must have at least 3 channels, got %d\n", c);
-            free(input_image_buffer);
-            return 1;
-        }
-
-        // Resize input image if dimensions don't match
-        if (img_width != width || img_height != height) {
-            fprintf(stderr, "Resizing input image from %dx%d to %dx%d\n", img_width, img_height, width, height);
-
-            uint8_t* resized_image_buffer = (uint8_t*)malloc(height * width * 3);
-            if (resized_image_buffer == NULL) {
-                fprintf(stderr, "Failed to allocate memory for resized image\n");
-                free(input_image_buffer);
-                return 1;
-            }
-
-            stbir_resize(input_image_buffer, img_width, img_height, 0,
-                         resized_image_buffer, width, height, 0, STBIR_TYPE_UINT8,
-                         3, STBIR_ALPHA_CHANNEL_NONE, 0,
-                         STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
-                         STBIR_FILTER_BOX, STBIR_FILTER_BOX,
-                         STBIR_COLORSPACE_SRGB, nullptr);
-
-            free(input_image_buffer);
-            input_image_buffer = resized_image_buffer;
-        }
-
-        p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
-        p.strength = strength;
-        fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
-    } else {
-        // No input image, use empty image for text-to-image
-        p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
-        p.strength = 0.0f;
-    }
-
-    // Handle mask image for inpainting
-    if (has_mask_image) {
-        fprintf(stderr, "Loading mask image: %s\n", mask_image);
-
-        int c = 0;
-        int mask_width = 0;
-        int mask_height = 0;
-        mask_image_buffer = stbi_load(mask_image, &mask_width, &mask_height, &c, 1);
-        if (mask_image_buffer == NULL) {
-            fprintf(stderr, "Failed to load mask image from '%s'\n", mask_image);
-            if (input_image_buffer) free(input_image_buffer);
-            return 1;
-        }
-
-        // Resize mask if dimensions don't match
-        if (mask_width != width || mask_height != height) {
-            fprintf(stderr, "Resizing mask image from %dx%d to %dx%d\n", mask_width, mask_height, width, height);
-
-            uint8_t* resized_mask_buffer = (uint8_t*)malloc(height * width);
-            if (resized_mask_buffer == NULL) {
-                fprintf(stderr, "Failed to allocate memory for resized mask\n");
-                free(mask_image_buffer);
-                if (input_image_buffer) free(input_image_buffer);
-                return 1;
-            }
-
-            stbir_resize(mask_image_buffer, mask_width, mask_height, 0,
-                         resized_mask_buffer, width, height, 0, STBIR_TYPE_UINT8,
-                         1, STBIR_ALPHA_CHANNEL_NONE, 0,
-                         STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
-                         STBIR_FILTER_BOX, STBIR_FILTER_BOX,
-                         STBIR_COLORSPACE_SRGB, nullptr);
-
-            free(mask_image_buffer);
-            mask_image_buffer = resized_mask_buffer;
-        }
-
-        p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
-        fprintf(stderr, "Using inpainting with mask\n");
-    } else {
-        // No mask image, create default full mask
-        default_mask_image_vec.resize(width * height, 255);
-        p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
-    }
-
-    // Handle reference images
-    std::vector<sd_image_t> ref_images_vec;
-    std::vector<uint8_t*> ref_image_buffers;
-
-    if (ref_images_count > 0 && ref_images != NULL) {
-        fprintf(stderr, "Loading %d reference images\n", ref_images_count);
-
-        for (int i = 0; i < ref_images_count; i++) {
-            if (ref_images[i] == NULL || strlen(ref_images[i]) == 0) {
-                continue;
-            }
-
-            fprintf(stderr, "Loading reference image %d: %s\n", i + 1, ref_images[i]);
-
-            int c = 0;
-            int ref_width = 0;
-            int ref_height = 0;
-            uint8_t* ref_image_buffer = stbi_load(ref_images[i], &ref_width, &ref_height, &c, 3);
-            if (ref_image_buffer == NULL) {
-                fprintf(stderr, "Failed to load reference image from '%s'\n", ref_images[i]);
-                continue;
-            }
-            if (c < 3) {
-                fprintf(stderr, "Reference image must have at least 3 channels, got %d\n", c);
-                free(ref_image_buffer);
-                continue;
-            }
-
-            // Resize reference image if dimensions don't match
-            if (ref_width != width || ref_height != height) {
-                fprintf(stderr, "Resizing reference image from %dx%d to %dx%d\n", ref_width, ref_height, width, height);
-
-                uint8_t* resized_ref_buffer = (uint8_t*)malloc(height * width * 3);
-                if (resized_ref_buffer == NULL) {
-                    fprintf(stderr, "Failed to allocate memory for resized reference image\n");
-                    free(ref_image_buffer);
-                    continue;
-                }
-
-                stbir_resize(ref_image_buffer, ref_width, ref_height, 0,
-                             resized_ref_buffer, width, height, 0, STBIR_TYPE_UINT8,
-                             3, STBIR_ALPHA_CHANNEL_NONE, 0,
-                             STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
-                             STBIR_FILTER_BOX, STBIR_FILTER_BOX,
-                             STBIR_COLORSPACE_SRGB, nullptr);
-
-                free(ref_image_buffer);
-                ref_image_buffer = resized_ref_buffer;
-            }
-
-            ref_image_buffers.push_back(ref_image_buffer);
-            ref_images_vec.push_back({(uint32_t)width, (uint32_t)height, 3, ref_image_buffer});
-        }
-
-        if (!ref_images_vec.empty()) {
-            p.ref_images = ref_images_vec.data();
-            p.ref_images_count = ref_images_vec.size();
-            fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
-        }
-    }
-
-    results = generate_image(sd_c, &p);
-
+    results = txt2img(sd_c,
+                      text,
+                      negativeText,
+                      -1, //clip_skip
+                      cfg_scale, // sfg_scale
+                      3.5f,
+                      0, // eta
+                      width,
+                      height,
+                      sample_method,
+                      steps,
+                      seed,
+                      1,
+                      NULL,
+                      0.9f,
+                      20.f,
+                      false,
+                      "",
+                      skip_layers.data(),
+                      skip_layers.size(),
+                      0,
+                      0.01,
+                      0.2);
+
     if (results == NULL) {
         fprintf (stderr, "NO results\n");
-        if (input_image_buffer) free(input_image_buffer);
-        if (mask_image_buffer) free(mask_image_buffer);
-        for (auto buffer : ref_image_buffers) {
-            if (buffer) free(buffer);
-        }
         return 1;
     }
 
     if (results[0].data == NULL) {
         fprintf (stderr, "Results with no data\n");
-        if (input_image_buffer) free(input_image_buffer);
-        if (mask_image_buffer) free(mask_image_buffer);
-        for (auto buffer : ref_image_buffers) {
-            if (buffer) free(buffer);
-        }
         return 1;
     }
 

@@ -445,15 +215,11 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
                    results[0].data, 0, NULL);
     fprintf (stderr, "Saved resulting image to '%s'\n", dst);
 
-    // Clean up
+    // TODO: free results. Why does it crash?
 
     free(results[0].data);
     results[0].data = NULL;
     free(results);
-    if (input_image_buffer) free(input_image_buffer);
-    if (mask_image_buffer) free(mask_image_buffer);
-    for (auto buffer : ref_image_buffers) {
-        if (buffer) free(buffer);
-    }
     fprintf (stderr, "gen_image is done", dst);
 
     return 0;
@@ -29,21 +29,16 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
 
 	sd.threads = int(opts.Threads)
 
-	modelPath := opts.ModelPath
-
 	modelFile := C.CString(opts.ModelFile)
 	defer C.free(unsafe.Pointer(modelFile))
 
-	modelPathC := C.CString(modelPath)
-	defer C.free(unsafe.Pointer(modelPathC))
-
 	var options **C.char
 	// prepare the options array to pass to C
 
 	size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
 	length := C.size_t(len(opts.Options))
-	options = (**C.char)(C.malloc((length + 1) * size))
-	view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0 : len(opts.Options)+1 : len(opts.Options)+1]
+	options = (**C.char)(C.malloc(length * size))
+	view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0:len(opts.Options):len(opts.Options)]
 
 	var diffusionModel int
 

@@ -71,11 +66,10 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
 	for i, x := range oo {
 		view[i] = C.CString(x)
 	}
-	view[len(oo)] = nil
 
 	sd.cfgScale = opts.CFGScale
 
-	ret := C.load_model(modelFile, modelPathC, options, C.int(opts.Threads), C.int(diffusionModel))
+	ret := C.load_model(modelFile, options, C.int(opts.Threads), C.int(diffusionModel))
 	if ret != 0 {
 		return fmt.Errorf("could not load model")
 	}

@@ -93,56 +87,7 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
 	negative := C.CString(opts.NegativePrompt)
 	defer C.free(unsafe.Pointer(negative))
 
-	// Handle source image path
-	var srcImage *C.char
-	if opts.Src != "" {
-		srcImage = C.CString(opts.Src)
-		defer C.free(unsafe.Pointer(srcImage))
-	}
-
-	// Handle mask image path
-	var maskImage *C.char
-	if opts.EnableParameters != "" {
-		// Parse EnableParameters for mask path if provided
-		// This is a simple approach - in a real implementation you might want to parse JSON
-		if strings.Contains(opts.EnableParameters, "mask:") {
-			parts := strings.Split(opts.EnableParameters, "mask:")
-			if len(parts) > 1 {
-				maskPath := strings.TrimSpace(parts[1])
-				if maskPath != "" {
-					maskImage = C.CString(maskPath)
-					defer C.free(unsafe.Pointer(maskImage))
-				}
-			}
-		}
-	}
-
-	// Handle reference images
-	var refImages **C.char
-	var refImagesCount C.int
-	if len(opts.RefImages) > 0 {
-		refImagesCount = C.int(len(opts.RefImages))
-		// Allocate array of C strings
-		size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
-		refImages = (**C.char)(C.malloc((C.size_t(len(opts.RefImages)) + 1) * size))
-		view := (*[1 << 30]*C.char)(unsafe.Pointer(refImages))[0 : len(opts.RefImages)+1 : len(opts.RefImages)+1]
-
-		for i, refImagePath := range opts.RefImages {
-			view[i] = C.CString(refImagePath)
-			defer C.free(unsafe.Pointer(view[i]))
-		}
-		view[len(opts.RefImages)] = nil
-	}
-
-	// Default strength for img2img (0.75 is a good default)
-	strength := C.float(0.75)
-	if opts.Src != "" {
-		// If we have a source image, use img2img mode
-		// You could also parse strength from EnableParameters if needed
-		strength = C.float(0.75)
-	}
-
-	ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale), srcImage, strength, maskImage, refImages, refImagesCount)
+	ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale))
 	if ret != 0 {
 		return fmt.Errorf("inference failed")
 	}
@@ -1,8 +1,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-int load_model(char *model, char *model_path, char* options[], int threads, int diffusionModel);
-int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
+int load_model(char *model, char* options[], int threads, int diffusionModel);
+int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale);
 #ifdef __cplusplus
 }
 #endif
@@ -6,7 +6,7 @@ CMAKE_ARGS?=
 
 # whisper.cpp version
 WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
-WHISPER_CPP_VERSION?=b02242d0adb5c6c4896d59ac86d9ec9fe0d0fe33
+WHISPER_CPP_VERSION?=1f5cf0b2888402d57bb17b2029b2caa97e5f3baf
 
 export WHISPER_CMAKE_ARGS?=-DBUILD_SHARED_LIBS=OFF
 export WHISPER_DIR=$(abspath ./sources/whisper.cpp)
File diff suppressed because it is too large
@@ -57,11 +57,6 @@ function init() {
 # - hipblas
 # - intel
 function getBuildProfile() {
-    if [ "x${BUILD_TYPE}" == "xl4t" ]; then
-        echo "l4t"
-        return 0
-    fi
-
     # First check if we are a cublas build, and if so report the correct build profile
     if [ x"${BUILD_TYPE}" == "xcublas" ]; then
         if [ ! -z ${CUDA_MAJOR_VERSION} ]; then

@@ -116,7 +111,7 @@ function ensureVenv() {
 # - requirements-${BUILD_TYPE}.txt
 # - requirements-${BUILD_PROFILE}.txt
 #
-# BUILD_PROFILE is a pore specific version of BUILD_TYPE, ex: cuda-11 or cuda-12
+# BUILD_PROFILE is a pore specific version of BUILD_TYPE, ex: cuda11 or cuda12
 # it can also include some options that we do not have BUILD_TYPES for, ex: intel
 #
 # NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index.
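The comments above describe how a backend's virtualenv is populated: a base `requirements.txt`, then optional `requirements-${BUILD_TYPE}.txt` and `requirements-${BUILD_PROFILE}.txt` overlays for the specific build flavour. A minimal sketch of that lookup order (the function name and installer invocation here are illustrative, not the actual libbackend.sh implementation; the install.sh removed later in this diff suggests uv is used as the installer):

```bash
install_requirements_sketch() {
    local candidates=(
        "requirements.txt"
        "requirements-${BUILD_TYPE}.txt"
        "requirements-${BUILD_PROFILE}.txt"
    )
    for f in "${candidates[@]}"; do
        # only install the overlay files that exist for this build flavour
        [ -f "$f" ] && uv pip install ${EXTRA_PIP_INSTALL_FLAGS} -r "$f"
    done
}
```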
@@ -8,6 +8,4 @@ else
     source $backend_dir/../common/libbackend.sh
 fi
 
-ensureVenv
-
 python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
@@ -65,19 +65,6 @@ from diffusers.schedulers import (
     UniPCMultistepScheduler,
 )
 
-def is_float(s):
-    try:
-        float(s)
-        return True
-    except ValueError:
-        return False
-
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles

@@ -182,24 +169,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if ":" not in opt:
                 continue
             key, value = opt.split(":")
-            # if value is a number, convert it to the appropriate type
-            if is_float(value):
-                value = float(value)
-            elif is_int(value):
-                value = int(value)
             self.options[key] = value
 
-        # From options, extract if present "torch_dtype" and set it to the appropriate type
-        if "torch_dtype" in self.options:
-            if self.options["torch_dtype"] == "fp16":
-                torchType = torch.float16
-            elif self.options["torch_dtype"] == "bf16":
-                torchType = torch.bfloat16
-            elif self.options["torch_dtype"] == "fp32":
-                torchType = torch.float32
-            # remove it from options
-            del self.options["torch_dtype"]
-
         print(f"Options: {self.options}", file=sys.stderr)
 
         local = False
|||||||
@@ -1,9 +1,9 @@
|
|||||||
git+https://github.com/huggingface/diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
||||||
compel
|
compel
|
||||||
peft
|
peft
|
||||||
sentencepiece
|
sentencepiece
|
||||||
torch==2.7.1
|
torch==2.4.1
|
||||||
optimum-quanto
|
optimum-quanto
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||||
torch==2.7.1+cu118
|
torch==2.4.1+cu118
|
||||||
git+https://github.com/huggingface/diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
torch==2.7.1
|
torch==2.4.1
|
||||||
git+https://github.com/huggingface/diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||||
torch==2.7.1+rocm6.3
|
torch==2.3.1+rocm6.0
|
||||||
torchvision==0.22.1+rocm6.3
|
torchvision==0.18.1+rocm6.0
|
||||||
git+https://github.com/huggingface/diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||||
intel-extension-for-pytorch==2.3.110+xpu
|
intel-extension-for-pytorch==2.3.110+xpu
|
||||||
torch==2.5.1+cxx11.abi
|
torch==2.3.1+cxx11.abi
|
||||||
torchvision==0.20.1+cxx11.abi
|
torchvision==0.18.1+cxx11.abi
|
||||||
oneccl_bind_pt==2.8.0+xpu
|
oneccl_bind_pt==2.3.100+xpu
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools
|
setuptools
|
||||||
git+https://github.com/huggingface/diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
|
|
||||||
torch
|
|
||||||
diffusers
|
|
||||||
transformers
|
|
||||||
accelerate
|
|
||||||
compel
|
|
||||||
peft
|
|
||||||
optimum-quanto
|
|
||||||
numpy<2
|
|
||||||
sentencepiece
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
.PHONY: kitten-tts
|
|
||||||
kitten-tts: protogen
|
|
||||||
bash install.sh
|
|
||||||
|
|
||||||
.PHONY: run
|
|
||||||
run: protogen
|
|
||||||
@echo "Running kitten-tts..."
|
|
||||||
bash run.sh
|
|
||||||
@echo "kitten-tts run."
|
|
||||||
|
|
||||||
.PHONY: test
|
|
||||||
test: protogen
|
|
||||||
@echo "Testing kitten-tts..."
|
|
||||||
bash test.sh
|
|
||||||
@echo "kitten-tts tested."
|
|
||||||
|
|
||||||
.PHONY: protogen
|
|
||||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
|
||||||
|
|
||||||
.PHONY: protogen-clean
|
|
||||||
protogen-clean:
|
|
||||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
|
||||||
|
|
||||||
backend_pb2_grpc.py backend_pb2.py:
|
|
||||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
|
||||||
|
|
||||||
.PHONY: clean
|
|
||||||
clean: protogen-clean
|
|
||||||
rm -rf venv __pycache__
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
This is an extra gRPC server of LocalAI for Kitten TTS
|
|
||||||
"""
|
|
||||||
from concurrent import futures
|
|
||||||
import time
|
|
||||||
import argparse
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import backend_pb2
|
|
||||||
import backend_pb2_grpc
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from kittentts import KittenTTS
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
import grpc
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
|
||||||
KITTEN_LANGUAGE = os.environ.get('KITTEN_LANGUAGE', None)
|
|
||||||
|
|
||||||
# Implement the BackendServicer class with the service methods
|
|
||||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|
||||||
"""
|
|
||||||
BackendServicer is the class that implements the gRPC service
|
|
||||||
"""
|
|
||||||
def Health(self, request, context):
|
|
||||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
|
||||||
def LoadModel(self, request, context):
|
|
||||||
|
|
||||||
# Get device
|
|
||||||
# device = "cuda" if request.CUDA else "cpu"
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
print("CUDA is available", file=sys.stderr)
|
|
||||||
device = "cuda"
|
|
||||||
else:
|
|
||||||
print("CUDA is not available", file=sys.stderr)
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
if not torch.cuda.is_available() and request.CUDA:
|
|
||||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
|
||||||
|
|
||||||
self.AudioPath = None
|
|
||||||
# List available KittenTTS models
|
|
||||||
print("Available KittenTTS voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f")
|
|
||||||
if os.path.isabs(request.AudioPath):
|
|
||||||
self.AudioPath = request.AudioPath
|
|
||||||
elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
|
|
||||||
# get base path of modelFile
|
|
||||||
modelFileBase = os.path.dirname(request.ModelFile)
|
|
||||||
# modify LoraAdapter to be relative to modelFileBase
|
|
||||||
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
|
||||||
|
|
||||||
try:
|
|
||||||
print("Preparing KittenTTS model, please wait", file=sys.stderr)
|
|
||||||
# Use the model name from request.Model, defaulting to "KittenML/kitten-tts-nano-0.1" if not specified
|
|
||||||
model_name = request.Model if request.Model else "KittenML/kitten-tts-nano-0.1"
|
|
||||||
self.tts = KittenTTS(model_name)
|
|
||||||
except Exception as err:
|
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
|
||||||
# Implement your logic here for the LoadModel service
|
|
||||||
# Replace this with your desired response
|
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
|
||||||
|
|
||||||
    def TTS(self, request, context):
        try:
            # KittenTTS doesn't use language parameter like TTS, so we ignore it
            # For multi-speaker models, use voice parameter
            voice = request.voice if request.voice else "expr-voice-2-f"

            # Generate audio using KittenTTS
            audio = self.tts.generate(request.text, voice=voice)

            # Save the audio using soundfile
            sf.write(request.dst, audio, 24000)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)
def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
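For reference, here is a minimal client-side sketch of how this servicer is driven over gRPC. It assumes the generated `backend_pb2`/`backend_pb2_grpc` stubs are importable and that the server above is listening on `localhost:50051`; it simply mirrors the calls exercised by the test script further down rather than documenting an official LocalAI client API.

```python
import grpc
import backend_pb2
import backend_pb2_grpc

# Assumes the KittenTTS backend above is already running on localhost:50051.
with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)

    # Model falls back to "KittenML/kitten-tts-nano-0.1" inside LoadModel when left empty.
    load = stub.LoadModel(backend_pb2.ModelOptions(Model="KittenML/kitten-tts-nano-0.1"))
    assert load.success, load.message

    # voice falls back to "expr-voice-2-f" inside TTS when left empty; dst is the output wav path.
    reply = stub.TTS(backend_pb2.TTSRequest(
        text="Hello from the KittenTTS backend",
        voice="expr-voice-4-m",
        dst="/tmp/kitten.wav",
    ))
    print(reply.success, reply.message)
```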
@@ -1,19 +0,0 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

installRequirements
@@ -1,5 +0,0 @@
grpcio==1.71.0
protobuf
certifi
packaging==24.1
https://github.com/KittenML/KittenTTS/releases/download/0.1/kittentts-0.1.0-py3-none-any.whl
@@ -1,9 +0,0 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
@@ -1,82 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if the audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
@@ -1,11 +0,0 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests
@@ -1,18 +1,9 @@
-.PHONY: kokoro
-kokoro: protogen
-	bash install.sh
+.DEFAULT_GOAL := install
 
-.PHONY: run
-run: protogen
-	@echo "Running kokoro..."
-	bash run.sh
-	@echo "kokoro run."
-
-.PHONY: test
-test: protogen
-	@echo "Testing kokoro..."
-	bash test.sh
-	@echo "kokoro tested."
+.PHONY: install
+install:
+	bash install.sh
+	$(MAKE) protogen
 
 .PHONY: protogen
 protogen: backend_pb2_grpc.py backend_pb2.py
@@ -22,7 +13,7 @@ protogen-clean:
 	$(RM) backend_pb2_grpc.py backend_pb2.py
 
 backend_pb2_grpc.py backend_pb2.py:
-	python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
+	bash protogen.sh
 
 .PHONY: clean
 clean: protogen-clean
@@ -1,23 +0,0 @@
# Kokoro TTS Backend for LocalAI

This is a gRPC server backend for LocalAI that uses the Kokoro TTS pipeline.

## Creating a separate environment for kokoro project

```bash
make kokoro
```

## Testing the gRPC server

```bash
make test
```

## Features

- Lightweight TTS model with 82 million parameters
- Apache-licensed weights
- Fast and cost-efficient
- Multi-language support
- Multiple voice options
115
backend/python/kokoro/backend.py
Normal file → Executable file
@@ -1,92 +1,101 @@
 #!/usr/bin/env python3
 """
-This is an extra gRPC server of LocalAI for Kokoro TTS
+Extra gRPC server for Kokoro models.
 """
 from concurrent import futures
-import time
 import argparse
 import signal
 import sys
 import os
+import time
 import backend_pb2
 import backend_pb2_grpc
 
-import torch
-from kokoro import KPipeline
 import soundfile as sf
 
 import grpc
 
+from models import build_model
+from kokoro import generate
+import torch
 
+SAMPLE_RATE = 22050
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
-KOKORO_LANG_CODE = os.environ.get('KOKORO_LANG_CODE', 'a')
 
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     """
-    BackendServicer is the class that implements the gRPC service
+    A gRPC servicer for the backend service.
+
+    This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
     """
     def Health(self, request, context):
+        """
+        A gRPC method that returns the health status of the backend service.
+
+        Args:
+            request: A HealthRequest object that contains the request parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A Reply object that contains the health status of the backend service.
+        """
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
     def LoadModel(self, request, context):
-        # Get device
-        if torch.cuda.is_available():
-            print("CUDA is available", file=sys.stderr)
-            device = "cuda"
-        else:
-            print("CUDA is not available", file=sys.stderr)
-            device = "cpu"
-
-        if not torch.cuda.is_available() and request.CUDA:
-            return backend_pb2.Result(success=False, message="CUDA is not available")
-
+        """
+        A gRPC method that loads a model into memory.
+
+        Args:
+            request: A LoadModelRequest object that contains the request parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A Result object that contains the result of the LoadModel operation.
+        """
+        model_name = request.Model
         try:
-            print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)
-            # empty dict
-            self.options = {}
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            self.MODEL = build_model(request.ModelFile, device)
             options = request.Options
-            # The options are a list of strings in this form optname:optvalue
-            # We are storing all the options in a dict so we can use it later when
-            # generating the images
+            # Find the voice from the options, options are a list of strings in this form optname:optvalue:
+            VOICE_NAME = None
             for opt in options:
-                if ":" not in opt:
-                    continue
-                key, value = opt.split(":")
-                self.options[key] = value
-
-            # Initialize Kokoro pipeline with language code
-            lang_code = self.options.get("lang_code", KOKORO_LANG_CODE)
-            self.pipeline = KPipeline(lang_code=lang_code)
-            print(f"Kokoro TTS pipeline loaded with language code: {lang_code}", file=sys.stderr)
+                if opt.startswith("voice:"):
+                    VOICE_NAME = opt.split(":")[1]
+                    break
+            if VOICE_NAME is None:
+                return backend_pb2.Result(success=False, message=f"No voice specified in options")
+            MODELPATH = request.ModelPath
+            # If voice name contains a plus, split it and load the two models and combine them
+            if "+" in VOICE_NAME:
+                voice1, voice2 = VOICE_NAME.split("+")
+                voice1 = torch.load(f'{MODELPATH}/{voice1}.pt', weights_only=True).to(device)
+                voice2 = torch.load(f'{MODELPATH}/{voice2}.pt', weights_only=True).to(device)
+                self.VOICEPACK = torch.mean(torch.stack([voice1, voice2]), dim=0)
+            else:
+                self.VOICEPACK = torch.load(f'{MODELPATH}/{VOICE_NAME}.pt', weights_only=True).to(device)
 
+            self.VOICE_NAME = VOICE_NAME
+
+            print(f'Loaded voice: {VOICE_NAME}')
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
 
-        return backend_pb2.Result(message="Kokoro TTS pipeline loaded successfully", success=True)
+        return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     def TTS(self, request, context):
+        model_name = request.model
+        if model_name == "":
+            return backend_pb2.Result(success=False, message="request.model is required")
         try:
-            # Get voice from request, default to 'af_heart' if not specified
-            voice = request.voice if request.voice else 'af_heart'
-
-            # Generate audio using Kokoro pipeline
-            generator = self.pipeline(request.text, voice=voice)
-
-            # Get the first (and typically only) audio segment
-            for i, (gs, ps, audio) in enumerate(generator):
-                # Save audio to the destination file
-                sf.write(request.dst, audio, 24000)
-                print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
-                # For now, we only process the first segment
-                # If you need to handle multiple segments, you might want to modify this
-                break
-
+            audio, out_ps = generate(self.MODEL, request.text, self.VOICEPACK, lang=self.VOICE_NAME)
+            print(out_ps)
+            sf.write(request.dst, audio, SAMPLE_RATE)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
 
         return backend_pb2.Result(success=True)
 
 def serve(address):
@@ -99,11 +108,11 @@ def serve(address):
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
     server.add_insecure_port(address)
     server.start()
-    print("Server started. Listening on: " + address, file=sys.stderr)
+    print("[Kokoro] Server started. Listening on: " + address, file=sys.stderr)
 
     # Define the signal handler function
    def signal_handler(sig, frame):
-        print("Received termination signal. Shutting down...")
+        print("[Kokoro] Received termination signal. Shutting down...")
         server.stop(0)
         sys.exit(0)
 
@@ -123,5 +132,5 @@ if __name__ == "__main__":
         "--addr", default="localhost:50051", help="The address to bind the server to."
     )
     args = parser.parse_args()
-
+    print(f"[Kokoro] startup: {args}", file=sys.stderr)
     serve(args.addr)
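The voice-mixing branch of the new LoadModel (a `+` in the configured voice name loads two voicepacks and averages them) is easy to miss inside the diff above. Below is a small standalone sketch of that step, assuming the voicepack `.pt` files live under a `model_path` directory and share the same tensor shape; `af_sarah` and `af_nicole` are purely illustrative voice names, and the helper is slightly generalized to any number of `+`-separated voices rather than exactly two as in the backend.

```python
import torch

def load_voicepack(model_path: str, voice_name: str, device: str = "cpu") -> torch.Tensor:
    # "af_sarah+af_nicole" -> load each voicepack and average them element-wise,
    # the same equal-weight blend the backend stores in self.VOICEPACK.
    names = voice_name.split("+")
    packs = [torch.load(f"{model_path}/{n}.pt", weights_only=True).to(device) for n in names]
    if len(packs) == 1:
        return packs[0]
    return torch.mean(torch.stack(packs), dim=0)

# Example (hypothetical files): an equal-weight blend of two voices.
blended = load_voicepack("voices", "af_sarah+af_nicole")
```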
524
backend/python/kokoro/istftnet.py
Normal file
@@ -0,0 +1,524 @@
|
|||||||
|
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/istftnet.py
|
||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
||||||
|
from scipy.signal import get_window
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
|
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size*dilation - dilation)/2)
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
class AdaIN1d(nn.Module):
|
||||||
|
def __init__(self, style_dim, num_features):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
||||||
|
self.fc = nn.Linear(style_dim, num_features*2)
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
h = self.fc(s)
|
||||||
|
h = h.view(h.size(0), h.size(1), 1)
|
||||||
|
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||||
|
return (1 + gamma) * self.norm(x) + beta
|
||||||
|
|
||||||
|
class AdaINResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
||||||
|
super(AdaINResBlock1, self).__init__()
|
||||||
|
self.convs1 = nn.ModuleList([
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2])))
|
||||||
|
])
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList([
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1)))
|
||||||
|
])
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
self.adain1 = nn.ModuleList([
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.adain2 = nn.ModuleList([
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
AdaIN1d(style_dim, channels),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
||||||
|
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
||||||
|
xt = n1(x, s)
|
||||||
|
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = n2(xt, s)
|
||||||
|
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs1:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.convs2:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
class TorchSTFT(torch.nn.Module):
|
||||||
|
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
||||||
|
super().__init__()
|
||||||
|
self.filter_length = filter_length
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
|
||||||
|
|
||||||
|
def transform(self, input_data):
|
||||||
|
forward_transform = torch.stft(
|
||||||
|
input_data,
|
||||||
|
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
||||||
|
return_complex=True)
|
||||||
|
|
||||||
|
return torch.abs(forward_transform), torch.angle(forward_transform)
|
||||||
|
|
||||||
|
def inverse(self, magnitude, phase):
|
||||||
|
inverse_transform = torch.istft(
|
||||||
|
magnitude * torch.exp(phase * 1j),
|
||||||
|
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
||||||
|
|
||||||
|
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
||||||
|
|
||||||
|
def forward(self, input_data):
|
||||||
|
self.magnitude, self.phase = self.transform(input_data)
|
||||||
|
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||||
|
return reconstruction
|
||||||
|
|
||||||
|
class SineGen(torch.nn.Module):
|
||||||
|
""" Definition of sine generator
|
||||||
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
|
sine_amp = 0.1, noise_std = 0.003,
|
||||||
|
voiced_threshold = 0,
|
||||||
|
flag_for_pulse=False)
|
||||||
|
samp_rate: sampling rate in Hz
|
||||||
|
harmonic_num: number of harmonic overtones (default 0)
|
||||||
|
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||||
|
noise_std: std of Gaussian noise (default 0.003)
|
||||||
|
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||||
|
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||||
|
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||||
|
segment is always sin(np.pi) or cos(0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||||
|
sine_amp=0.1, noise_std=0.003,
|
||||||
|
voiced_threshold=0,
|
||||||
|
flag_for_pulse=False):
|
||||||
|
super(SineGen, self).__init__()
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = noise_std
|
||||||
|
self.harmonic_num = harmonic_num
|
||||||
|
self.dim = self.harmonic_num + 1
|
||||||
|
self.sampling_rate = samp_rate
|
||||||
|
self.voiced_threshold = voiced_threshold
|
||||||
|
self.flag_for_pulse = flag_for_pulse
|
||||||
|
self.upsample_scale = upsample_scale
|
||||||
|
|
||||||
|
def _f02uv(self, f0):
|
||||||
|
# generate uv signal
|
||||||
|
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||||
|
return uv
|
||||||
|
|
||||||
|
def _f02sine(self, f0_values):
|
||||||
|
""" f0_values: (batchsize, length, dim)
|
||||||
|
where dim indicates fundamental tone and overtones
|
||||||
|
"""
|
||||||
|
# convert to F0 in rad. The integer part n can be ignored
|
||||||
|
# because 2 * np.pi * n doesn't affect phase
|
||||||
|
rad_values = (f0_values / self.sampling_rate) % 1
|
||||||
|
|
||||||
|
# initial phase noise (no noise for fundamental component)
|
||||||
|
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
|
||||||
|
device=f0_values.device)
|
||||||
|
rand_ini[:, 0] = 0
|
||||||
|
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||||
|
|
||||||
|
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||||
|
if not self.flag_for_pulse:
|
||||||
|
# # for normal case
|
||||||
|
|
||||||
|
# # To prevent torch.cumsum numerical overflow,
|
||||||
|
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
||||||
|
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
||||||
|
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
||||||
|
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
||||||
|
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
||||||
|
# cumsum_shift = torch.zeros_like(rad_values)
|
||||||
|
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||||
|
|
||||||
|
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||||
|
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||||
|
scale_factor=1/self.upsample_scale,
|
||||||
|
mode="linear").transpose(1, 2)
|
||||||
|
|
||||||
|
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
||||||
|
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
||||||
|
# cumsum_shift = torch.zeros_like(rad_values)
|
||||||
|
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||||
|
|
||||||
|
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||||
|
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||||
|
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||||
|
sines = torch.sin(phase)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If necessary, make sure that the first time step of every
|
||||||
|
# voiced segments is sin(pi) or cos(0)
|
||||||
|
# This is used for pulse-train generation
|
||||||
|
|
||||||
|
# identify the last time step in unvoiced segments
|
||||||
|
uv = self._f02uv(f0_values)
|
||||||
|
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||||
|
uv_1[:, -1, :] = 1
|
||||||
|
u_loc = (uv < 1) * (uv_1 > 0)
|
||||||
|
|
||||||
|
# get the instantanouse phase
|
||||||
|
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||||
|
# different batch needs to be processed differently
|
||||||
|
for idx in range(f0_values.shape[0]):
|
||||||
|
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||||
|
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||||
|
# stores the accumulation of i.phase within
|
||||||
|
# each voiced segments
|
||||||
|
tmp_cumsum[idx, :, :] = 0
|
||||||
|
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||||
|
|
||||||
|
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||||
|
# within the previous voiced segment.
|
||||||
|
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||||
|
|
||||||
|
# get the sines
|
||||||
|
sines = torch.cos(i_phase * 2 * np.pi)
|
||||||
|
return sines
|
||||||
|
|
||||||
|
def forward(self, f0):
|
||||||
|
""" sine_tensor, uv = forward(f0)
|
||||||
|
input F0: tensor(batchsize=1, length, dim=1)
|
||||||
|
f0 for unvoiced steps should be 0
|
||||||
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
|
"""
|
||||||
|
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
|
||||||
|
device=f0.device)
|
||||||
|
# fundamental component
|
||||||
|
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||||
|
|
||||||
|
# generate sine waveforms
|
||||||
|
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||||
|
|
||||||
|
# generate uv signal
|
||||||
|
# uv = torch.ones(f0.shape)
|
||||||
|
# uv = uv * (f0 > self.voiced_threshold)
|
||||||
|
uv = self._f02uv(f0)
|
||||||
|
|
||||||
|
# noise: for unvoiced should be similar to sine_amp
|
||||||
|
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||||
|
# . for voiced regions is self.noise_std
|
||||||
|
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||||
|
noise = noise_amp * torch.randn_like(sine_waves)
|
||||||
|
|
||||||
|
# first: set the unvoiced part to 0 by uv
|
||||||
|
# then: additive noise
|
||||||
|
sine_waves = sine_waves * uv + noise
|
||||||
|
return sine_waves, uv, noise
|
||||||
|
|
||||||
|
|
||||||
|
class SourceModuleHnNSF(torch.nn.Module):
|
||||||
|
""" SourceModule for hn-nsf
|
||||||
|
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0)
|
||||||
|
sampling_rate: sampling_rate in Hz
|
||||||
|
harmonic_num: number of harmonic above F0 (default: 0)
|
||||||
|
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||||
|
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||||
|
note that amplitude of noise in unvoiced is decided
|
||||||
|
by sine_amp
|
||||||
|
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||||
|
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||||
|
F0_sampled (batchsize, length, 1)
|
||||||
|
Sine_source (batchsize, length, 1)
|
||||||
|
noise_source (batchsize, length 1)
|
||||||
|
uv (batchsize, length, 1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0):
|
||||||
|
super(SourceModuleHnNSF, self).__init__()
|
||||||
|
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = add_noise_std
|
||||||
|
|
||||||
|
# to produce sine waveforms
|
||||||
|
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
||||||
|
sine_amp, add_noise_std, voiced_threshod)
|
||||||
|
|
||||||
|
# to merge source harmonics into a single excitation
|
||||||
|
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||||
|
self.l_tanh = torch.nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||||
|
F0_sampled (batchsize, length, 1)
|
||||||
|
Sine_source (batchsize, length, 1)
|
||||||
|
noise_source (batchsize, length 1)
|
||||||
|
"""
|
||||||
|
# source for harmonic branch
|
||||||
|
with torch.no_grad():
|
||||||
|
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||||
|
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||||
|
|
||||||
|
# source for noise branch, in the same shape as uv
|
||||||
|
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||||
|
return sine_merge, noise, uv
|
||||||
|
def padDiff(x):
|
||||||
|
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(torch.nn.Module):
|
||||||
|
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
||||||
|
super(Generator, self).__init__()
|
||||||
|
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
resblock = AdaINResBlock1
|
||||||
|
|
||||||
|
self.m_source = SourceModuleHnNSF(
|
||||||
|
sampling_rate=24000,
|
||||||
|
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
||||||
|
harmonic_num=8, voiced_threshod=10)
|
||||||
|
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
|
||||||
|
self.noise_convs = nn.ModuleList()
|
||||||
|
self.noise_res = nn.ModuleList()
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(weight_norm(
|
||||||
|
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||||
|
k, u, padding=(k-u)//2)))
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel//(2**(i+1))
|
||||||
|
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock(ch, k, d, style_dim))
|
||||||
|
|
||||||
|
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
|
||||||
|
if i + 1 < len(upsample_rates): #
|
||||||
|
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||||
|
self.noise_convs.append(Conv1d(
|
||||||
|
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
||||||
|
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
|
||||||
|
else:
|
||||||
|
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
||||||
|
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
|
||||||
|
|
||||||
|
|
||||||
|
self.post_n_fft = gen_istft_n_fft
|
||||||
|
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
||||||
|
self.ups.apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
||||||
|
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, s, f0):
|
||||||
|
with torch.no_grad():
|
||||||
|
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
|
|
||||||
|
har_source, noi_source, uv = self.m_source(f0)
|
||||||
|
har_source = har_source.transpose(1, 2).squeeze(1)
|
||||||
|
har_spec, har_phase = self.stft.transform(har_source)
|
||||||
|
har = torch.cat([har_spec, har_phase], dim=1)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x_source = self.noise_convs[i](har)
|
||||||
|
x_source = self.noise_res[i](x_source, s)
|
||||||
|
|
||||||
|
x = self.ups[i](x)
|
||||||
|
if i == self.num_upsamples - 1:
|
||||||
|
x = self.reflection_pad(x)
|
||||||
|
|
||||||
|
x = x + x_source
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
||||||
|
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||||
|
return self.stft.inverse(spec, phase)
|
||||||
|
|
||||||
|
def fw_phase(self, x, s):
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.reflection_pad(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
||||||
|
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||||
|
return spec, phase
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print('Removing weight norm...')
|
||||||
|
for l in self.ups:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_weight_norm(self.conv_pre)
|
||||||
|
remove_weight_norm(self.conv_post)
|
||||||
|
|
||||||
|
|
||||||
|
class AdainResBlk1d(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
||||||
|
upsample='none', dropout_p=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.actv = actv
|
||||||
|
self.upsample_type = upsample
|
||||||
|
self.upsample = UpSample1d(upsample)
|
||||||
|
self.learned_sc = dim_in != dim_out
|
||||||
|
self._build_weights(dim_in, dim_out, style_dim)
|
||||||
|
self.dropout = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
|
if upsample == 'none':
|
||||||
|
self.pool = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||||
|
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||||
|
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||||
|
self.norm1 = AdaIN1d(style_dim, dim_in)
|
||||||
|
self.norm2 = AdaIN1d(style_dim, dim_out)
|
||||||
|
if self.learned_sc:
|
||||||
|
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
||||||
|
|
||||||
|
def _shortcut(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
if self.learned_sc:
|
||||||
|
x = self.conv1x1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _residual(self, x, s):
|
||||||
|
x = self.norm1(x, s)
|
||||||
|
x = self.actv(x)
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.conv1(self.dropout(x))
|
||||||
|
x = self.norm2(x, s)
|
||||||
|
x = self.actv(x)
|
||||||
|
x = self.conv2(self.dropout(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
out = self._residual(x, s)
|
||||||
|
out = (out + self._shortcut(x)) / np.sqrt(2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, layer_type):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_type = layer_type
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.layer_type == 'none':
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
|
||||||
|
resblock_kernel_sizes = [3,7,11],
|
||||||
|
upsample_rates = [10, 6],
|
||||||
|
upsample_initial_channel=512,
|
||||||
|
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
upsample_kernel_sizes=[20, 12],
|
||||||
|
gen_istft_n_fft=20, gen_istft_hop_size=5):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.decode = nn.ModuleList()
|
||||||
|
|
||||||
|
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
||||||
|
|
||||||
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
||||||
|
|
||||||
|
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||||
|
|
||||||
|
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||||
|
|
||||||
|
self.asr_res = nn.Sequential(
|
||||||
|
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
||||||
|
upsample_initial_channel, resblock_dilation_sizes,
|
||||||
|
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
||||||
|
|
||||||
|
def forward(self, asr, F0_curve, N, s):
|
||||||
|
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
||||||
|
N = self.N_conv(N.unsqueeze(1))
|
||||||
|
|
||||||
|
x = torch.cat([asr, F0, N], axis=1)
|
||||||
|
x = self.encode(x, s)
|
||||||
|
|
||||||
|
asr_res = self.asr_res(asr)
|
||||||
|
|
||||||
|
res = True
|
||||||
|
for block in self.decode:
|
||||||
|
if res:
|
||||||
|
x = torch.cat([x, asr_res, F0, N], axis=1)
|
||||||
|
x = block(x, s)
|
||||||
|
if block.upsample_type != "none":
|
||||||
|
res = False
|
||||||
|
|
||||||
|
x = self.generator(x, s, F0_curve)
|
||||||
|
return x
|
||||||
166
backend/python/kokoro/kokoro.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/kokoro.py
|
||||||
|
import phonemizer
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def split_num(num):
|
||||||
|
num = num.group()
|
||||||
|
if '.' in num:
|
||||||
|
return num
|
||||||
|
elif ':' in num:
|
||||||
|
h, m = [int(n) for n in num.split(':')]
|
||||||
|
if m == 0:
|
||||||
|
return f"{h} o'clock"
|
||||||
|
elif m < 10:
|
||||||
|
return f'{h} oh {m}'
|
||||||
|
return f'{h} {m}'
|
||||||
|
year = int(num[:4])
|
||||||
|
if year < 1100 or year % 1000 < 10:
|
||||||
|
return num
|
||||||
|
left, right = num[:2], int(num[2:4])
|
||||||
|
s = 's' if num.endswith('s') else ''
|
||||||
|
if 100 <= year % 1000 <= 999:
|
||||||
|
if right == 0:
|
||||||
|
return f'{left} hundred{s}'
|
||||||
|
elif right < 10:
|
||||||
|
return f'{left} oh {right}{s}'
|
||||||
|
return f'{left} {right}{s}'
|
||||||
|
|
||||||
|
def flip_money(m):
|
||||||
|
m = m.group()
|
||||||
|
bill = 'dollar' if m[0] == '$' else 'pound'
|
||||||
|
if m[-1].isalpha():
|
||||||
|
return f'{m[1:]} {bill}s'
|
||||||
|
elif '.' not in m:
|
||||||
|
s = '' if m[1:] == '1' else 's'
|
||||||
|
return f'{m[1:]} {bill}{s}'
|
||||||
|
b, c = m[1:].split('.')
|
||||||
|
s = '' if b == '1' else 's'
|
||||||
|
c = int(c.ljust(2, '0'))
|
||||||
|
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
|
||||||
|
return f'{b} {bill}{s} and {c} {coins}'
|
||||||
|
|
||||||
|
def point_num(num):
|
||||||
|
a, b = num.group().split('.')
|
||||||
|
return ' point '.join([a, ' '.join(b)])
|
||||||
|
|
||||||
|
def normalize_text(text):
|
||||||
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||||
|
text = text.replace('«', chr(8220)).replace('»', chr(8221))
|
||||||
|
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
||||||
|
text = text.replace('(', '«').replace(')', '»')
|
||||||
|
for a, b in zip('、。!,:;?', ',.!,:;?'):
|
||||||
|
text = text.replace(a, b+' ')
|
||||||
|
text = re.sub(r'[^\S \n]', ' ', text)
|
||||||
|
text = re.sub(r' +', ' ', text)
|
||||||
|
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
|
||||||
|
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
|
||||||
|
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
|
||||||
|
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
|
||||||
|
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
|
||||||
|
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
|
||||||
|
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
|
||||||
|
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
|
||||||
|
text = re.sub(r'(?<=\d),(?=\d)', '', text)
|
||||||
|
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
|
||||||
|
text = re.sub(r'\d*\.\d+', point_num, text)
|
||||||
|
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
|
||||||
|
text = re.sub(r'(?<=\d)S', ' S', text)
|
||||||
|
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
||||||
|
text = re.sub(r"(?<=X')S\b", 's', text)
|
||||||
|
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
|
||||||
|
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
def get_vocab():
|
||||||
|
_pad = "$"
|
||||||
|
_punctuation = ';:,.!?¡¿—…"«»“” '
|
||||||
|
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||||
|
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||||
|
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
||||||
|
dicts = {}
|
||||||
|
for i in range(len((symbols))):
|
||||||
|
dicts[symbols[i]] = i
|
||||||
|
return dicts
|
||||||
|
|
||||||
|
VOCAB = get_vocab()
|
||||||
|
def tokenize(ps):
|
||||||
|
return [i for i in map(VOCAB.get, ps) if i is not None]
|
||||||
|
|
||||||
|
phonemizers = dict(
|
||||||
|
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
|
||||||
|
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
|
||||||
|
)
|
||||||
|
def phonemize(text, lang, norm=True):
|
||||||
|
if norm:
|
||||||
|
text = normalize_text(text)
|
||||||
|
ps = phonemizers[lang].phonemize([text])
|
||||||
|
ps = ps[0] if ps else ''
|
||||||
|
# https://en.wiktionary.org/wiki/kokoro#English
|
||||||
|
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
||||||
|
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
|
||||||
|
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
|
||||||
|
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
|
||||||
|
if lang == 'a':
|
||||||
|
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
|
||||||
|
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
||||||
|
return ps.strip()
|
||||||
|
|
||||||
|
def length_to_mask(lengths):
|
||||||
|
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||||
|
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(model, tokens, ref_s, speed):
|
||||||
|
device = ref_s.device
|
||||||
|
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||||
|
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||||
|
text_mask = length_to_mask(input_lengths).to(device)
|
||||||
|
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||||
|
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||||
|
s = ref_s[:, 128:]
|
||||||
|
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||||
|
x, _ = model.predictor.lstm(d)
|
||||||
|
duration = model.predictor.duration_proj(x)
|
||||||
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||||
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||||
|
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||||
|
c_frame = 0
|
||||||
|
for i in range(pred_aln_trg.size(0)):
|
||||||
|
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
||||||
|
c_frame += pred_dur[0,i].item()
|
||||||
|
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
||||||
|
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
||||||
|
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||||
|
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
||||||
|
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||||
|
|
||||||
|
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
|
||||||
|
ps = ps or phonemize(text, lang)
|
||||||
|
tokens = tokenize(ps)
|
||||||
|
if not tokens:
|
||||||
|
return None
|
||||||
|
elif len(tokens) > 510:
|
||||||
|
tokens = tokens[:510]
|
||||||
|
print('Truncated to 510 tokens')
|
||||||
|
ref_s = voicepack[len(tokens)]
|
||||||
|
out = forward(model, tokens, ref_s, speed)
|
||||||
|
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
||||||
|
return out, ps
|
||||||
|
|
||||||
|
def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
|
||||||
|
ps = ps or phonemize(text, lang)
|
||||||
|
tokens = tokenize(ps)
|
||||||
|
if not tokens:
|
||||||
|
return None
|
||||||
|
outs = []
|
||||||
|
loop_count = len(tokens)//510 + (1 if len(tokens) % 510 != 0 else 0)
|
||||||
|
for i in range(loop_count):
|
||||||
|
ref_s = voicepack[len(tokens[i*510:(i+1)*510])]
|
||||||
|
out = forward(model, tokens[i*510:(i+1)*510], ref_s, speed)
|
||||||
|
outs.append(out)
|
||||||
|
outs = np.concatenate(outs)
|
||||||
|
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
||||||
|
return outs, ps
|
||||||
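For orientation, here is a minimal sketch of how the helpers in this file fit together outside the gRPC server, assuming a StyleTTS2-style Kokoro checkpoint and a voicepack tensor are available locally; the file names below are placeholders rather than paths shipped by LocalAI, and the backend writes the resulting audio at 22050 Hz (its SAMPLE_RATE constant).

```python
import torch
import soundfile as sf

from models import build_model
from kokoro import generate

# Placeholder paths: point these at a real Kokoro-82M checkpoint and voicepack.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = build_model("kokoro-v0_19.pth", device)
voicepack = torch.load("voices/af_sarah.pt", weights_only=True).to(device)

result = generate(model, "Hello from the Kokoro backend.", voicepack, lang="a")
if result is not None:
    audio, phonemes = result           # generate() returns (waveform, phoneme string)
    sf.write("out.wav", audio, 22050)  # the gRPC backend uses SAMPLE_RATE = 22050
    print(phonemes)
```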
373
backend/python/kokoro/models.py
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||||
|
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/models.py
|
||||||
|
from istftnet import AdaIN1d, Decoder
|
||||||
|
from munch import Munch
|
||||||
|
from pathlib import Path
|
||||||
|
from plbert import load_plbert
|
||||||
|
from torch.nn.utils import weight_norm, spectral_norm
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class LinearNorm(torch.nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||||
|
super(LinearNorm, self).__init__()
|
||||||
|
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
||||||
|
|
||||||
|
torch.nn.init.xavier_uniform_(
|
||||||
|
self.linear_layer.weight,
|
||||||
|
gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear_layer(x)
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
class TextEncoder(nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding = nn.Embedding(n_symbols, channels)
|
||||||
|
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
self.cnn = nn.ModuleList()
|
||||||
|
for _ in range(depth):
|
||||||
|
self.cnn.append(nn.Sequential(
|
||||||
|
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
||||||
|
LayerNorm(channels),
|
||||||
|
actv,
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
))
|
||||||
|
# self.cnn = nn.Sequential(*self.cnn)
|
||||||
|
|
||||||
|
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
||||||
|
|
||||||
|
def forward(self, x, input_lengths, m):
|
||||||
|
x = self.embedding(x) # [B, T, emb]
|
||||||
|
x = x.transpose(1, 2) # [B, emb, T]
|
||||||
|
m = m.to(input_lengths.device).unsqueeze(1)
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
|
for c in self.cnn:
|
||||||
|
x = c(x)
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
|
x = x.transpose(1, 2) # [B, T, chn]
|
||||||
|
|
||||||
|
input_lengths = input_lengths.cpu().numpy()
|
||||||
|
x = nn.utils.rnn.pack_padded_sequence(
|
||||||
|
x, input_lengths, batch_first=True, enforce_sorted=False)
|
||||||
|
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
x, _ = self.lstm(x)
|
||||||
|
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||||
|
x, batch_first=True)
|
||||||
|
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
||||||
|
|
||||||
|
x_pad[:, :, :x.shape[-1]] = x
|
||||||
|
x = x_pad.to(x.device)
|
||||||
|
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def inference(self, x):
|
||||||
|
x = self.embedding(x)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
x = self.cnn(x)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
x, _ = self.lstm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def length_to_mask(self, lengths):
|
||||||
|
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||||
|
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, layer_type):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_type = layer_type
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.layer_type == 'none':
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||||
|
|
||||||
|
class AdainResBlk1d(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
||||||
|
upsample='none', dropout_p=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.actv = actv
|
||||||
|
self.upsample_type = upsample
|
||||||
|
self.upsample = UpSample1d(upsample)
|
||||||
|
self.learned_sc = dim_in != dim_out
|
||||||
|
self._build_weights(dim_in, dim_out, style_dim)
|
||||||
|
self.dropout = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
|
if upsample == 'none':
|
||||||
|
self.pool = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||||
|
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||||
|
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||||
|
self.norm1 = AdaIN1d(style_dim, dim_in)
|
||||||
|
self.norm2 = AdaIN1d(style_dim, dim_out)
|
||||||
|
if self.learned_sc:
|
||||||
|
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
||||||
|
|
||||||
|
def _shortcut(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
if self.learned_sc:
|
||||||
|
x = self.conv1x1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _residual(self, x, s):
|
||||||
|
x = self.norm1(x, s)
|
||||||
|
x = self.actv(x)
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.conv1(self.dropout(x))
|
||||||
|
x = self.norm2(x, s)
|
||||||
|
x = self.actv(x)
|
||||||
|
x = self.conv2(self.dropout(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
out = self._residual(x, s)
|
||||||
|
out = (out + self._shortcut(x)) / np.sqrt(2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, style_dim, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.fc = nn.Linear(style_dim, channels*2)
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
|
||||||
|
h = self.fc(s)
|
||||||
|
h = h.view(h.size(0), h.size(1), 1)
|
||||||
|
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||||
|
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
||||||
|
x = (1 + gamma) * x + beta
|
||||||
|
return x.transpose(1, -1).transpose(-1, -2)
|
||||||
|
|
||||||
|
class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        d = self.text_encoder(texts, style, text_lengths, m)

        batch_size = d.shape[0]
        text_size = d.shape[1]

        # predict duration
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            d, input_lengths, batch_first=True, enforce_sorted=False)

        m = m.to(text_lengths.device).unsqueeze(1)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])

        x_pad[:, :x.shape[1], :] = x
        x = x_pad.to(x.device)

        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

        en = (d.transpose(-1, -2) @ alignment)

        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        x, _ = self.shared(x.transpose(-1, -2))

        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

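# Sketch (not part of the upstream file): what ProsodyPredictor.length_to_mask yields.
# True marks padding positions past each sequence's length; the values are illustrative.
def _example_length_to_mask():
    lengths = torch.tensor([3, 5])
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])
    return mask
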
class DurationEncoder(nn.Module):

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True,
                                      dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad.to(x.device)

        return x.transpose(-1, -2)

    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d

def build_model(path, device):
    config = Path(__file__).parent / 'config.json'
    assert config.exists(), f'Config path incorrect: config.json not found at {config}'
    with open(config, 'r') as r:
        args = recursive_munch(json.load(r))
    assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                      upsample_rates=args.decoder.upsample_rates,
                      upsample_initial_channel=args.decoder.upsample_initial_channel,
                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                      gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
    for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()
    model = Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
        assert key in model, key
        try:
            model[key].load_state_dict(state_dict)
        except:
            state_dict = {k[7:]: v for k, v in state_dict.items()}
            model[key].load_state_dict(state_dict, strict=False)
    return model

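# Usage sketch (not part of the diff): the checkpoint filename is an assumption used
# only for illustration; pass whatever .pth file ships with the Kokoro voice model.
def _example_build_model(checkpoint_path='kokoro-v0_19.pth'):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = build_model(checkpoint_path, device)
    # `model` is a Munch holding bert, bert_encoder, predictor, decoder and
    # text_encoder, each moved to `device` and switched to eval() mode.
    return model
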
backend/python/kokoro/plbert.py (new file, 16 additions)
@@ -0,0 +1,16 @@
# https://huggingface.co/hexgrad/Kokoro-82M/blob/main/plbert.py
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel

class CustomAlbert(AlbertModel):
    def forward(self, *args, **kwargs):
        # Call the original forward method
        outputs = super().forward(*args, **kwargs)
        # Only return the last_hidden_state
        return outputs.last_hidden_state

def load_plbert():
    plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
    albert_base_configuration = AlbertConfig(**plbert_config)
    bert = CustomAlbert(albert_base_configuration)
    return bert

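# Usage sketch (not part of the diff): shapes are illustrative; the ids only need to
# fall inside the 178-entry vocabulary configured above.
def _example_load_plbert():
    import torch
    bert = load_plbert()
    input_ids = torch.zeros(1, 32, dtype=torch.long)        # (batch, tokens)
    attention_mask = torch.ones(1, 32, dtype=torch.long)
    hidden = bert(input_ids=input_ids, attention_mask=attention_mask)
    # CustomAlbert returns only last_hidden_state: (batch, tokens, hidden_size=768)
    return hidden.shape
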
@@ -8,6 +8,4 @@ else
 source $backend_dir/../common/libbackend.sh
 fi
 
-ensureVenv
-
 python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto

@@ -1,6 +1,2 @@
---extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.4.1
 transformers
-accelerate
-torch
-kokoro
-soundfile

@@ -1,7 +1,3 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.7.1+cu118
-torchaudio==2.7.1+cu118
+torch==2.4.1+cu118
 transformers
-accelerate
-kokoro
-soundfile

@@ -1,6 +1,2 @@
-torch==2.7.1
-torchaudio==2.7.1
+torch==2.4.1
 transformers
-accelerate
-kokoro
-soundfile

@@ -1,7 +1,3 @@
---extra-index-url https://download.pytorch.org/whl/rocm6.3
-torch==2.7.1+rocm6.3
-torchaudio==2.7.1+rocm6.3
+--extra-index-url https://download.pytorch.org/whl/rocm6.0
+torch==2.4.1+rocm6.0
 transformers
-accelerate
-kokoro
-soundfile

@@ -1,11 +1,5 @@
 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-intel-extension-for-pytorch==2.8.10+xpu
-torch==2.5.1+cxx11.abi
-oneccl_bind_pt==2.8.0+xpu
-torchaudio==2.5.1+cxx11.abi
-optimum[openvino]
-setuptools
-transformers==4.48.3
-accelerate
-kokoro
-soundfile
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.3.1+cxx11.abi
+oneccl_bind_pt==2.3.100+xpu
+transformers

@@ -1,6 +1,7 @@
 grpcio==1.71.0
 protobuf
-certifi
-packaging==24.1
-pip
-chardet
+phonemizer
+scipy
+munch
+setuptools
+soundfile

@@ -1,87 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the Kokoro pipeline is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Kokoro TTS pipeline loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if the TTS generation works successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(
                    text="Kokoro is an open-weight TTS model with 82 million parameters.",
                    voice="af_heart",
                    dst="test_output.wav"
                )
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
                self.assertTrue(tts_response.success)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()

@@ -1,20 +0,0 @@
.DEFAULT_GOAL := install

.PHONY: install
install:
	bash install.sh
	$(MAKE) protogen

.PHONY: protogen
protogen: backend_pb2_grpc.py backend_pb2.py

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

backend_pb2_grpc.py backend_pb2.py:
	bash protogen.sh

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__

@@ -1,174 +0,0 @@
#!/usr/bin/env python3
"""
gRPC server for RFDETR object detection models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os
import time
import base64
import backend_pb2
import backend_pb2_grpc
import grpc

import requests

import supervision as sv
from inference import get_model
from PIL import Image
from io import BytesIO

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer for the RFDETR backend service.

    This class implements the gRPC methods for object detection using RFDETR models.
    """

    def __init__(self):
        self.model = None
        self.model_name = None

    def Health(self, request, context):
        """
        A gRPC method that returns the health status of the backend service.

        Args:
            request: A HealthMessage object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Reply object that contains the health status of the backend service.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        A gRPC method that loads a RFDETR model into memory.

        Args:
            request: A ModelOptions object that contains the model parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object that contains the result of the LoadModel operation.
        """
        model_name = request.Model
        try:
            # Load the RFDETR model
            self.model = get_model(model_name)
            self.model_name = model_name
            print(f'Loaded RFDETR model: {model_name}')
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Failed to load model: {err}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Detect(self, request, context):
        """
        A gRPC method that performs object detection on an image.

        Args:
            request: A DetectOptions object that contains the image source.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A DetectResponse object that contains the detection results.
        """
        if self.model is None:
            print(f"Model is None")
            return backend_pb2.DetectResponse()
        print(f"Model is not None")
        try:
            print(f"Decoding image")
            # Decode the base64 image
            print(f"Image data: {request.src}")

            image_data = base64.b64decode(request.src)
            image = Image.open(BytesIO(image_data))

            # Perform inference
            predictions = self.model.infer(image, confidence=0.5)[0]

            # Convert to proto format
            proto_detections = []
            for i in range(len(predictions.predictions)):
                pred = predictions.predictions[i]
                print(f"Prediction: {pred}")
                proto_detection = backend_pb2.Detection(
                    x=float(pred.x),
                    y=float(pred.y),
                    width=float(pred.width),
                    height=float(pred.height),
                    confidence=float(pred.confidence),
                    class_name=pred.class_name
                )
                proto_detections.append(proto_detection)

            return backend_pb2.DetectResponse(Detections=proto_detections)
        except Exception as err:
            print(f"Detection error: {err}")
            return backend_pb2.DetectResponse()

    def Status(self, request, context):
        """
        A gRPC method that returns the status of the backend service.

        Args:
            request: A HealthMessage object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A StatusResponse object that contains the status information.
        """
        state = backend_pb2.StatusResponse.READY if self.model is not None else backend_pb2.StatusResponse.UNINITIALIZED
        return backend_pb2.StatusResponse(state=state)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("[RFDETR] Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("[RFDETR] Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the RFDETR gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()
    print(f"[RFDETR] startup: {args}", file=sys.stderr)
    serve(args.addr)

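# Client sketch for the RFDETR backend removed above (not part of the diff). It assumes
# the generated backend_pb2/backend_pb2_grpc stubs and the field names used by the
# server: ModelOptions.Model, DetectOptions.src (base64 image), DetectResponse.Detections.
# The model name below is a placeholder.
import base64

import grpc
import backend_pb2
import backend_pb2_grpc

def detect_objects(image_path, model_name="rfdetr-base", addr="localhost:50051"):
    with grpc.insecure_channel(addr) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        stub.LoadModel(backend_pb2.ModelOptions(Model=model_name))
        with open(image_path, "rb") as f:
            src = base64.b64encode(f.read()).decode("utf-8")
        response = stub.Detect(backend_pb2.DetectOptions(src=src))
        # Each Detection carries x, y, width, height, confidence and class_name.
        return list(response.Detections)
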
@@ -1,19 +0,0 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

installRequirements

@@ -1,7 +0,0 @@
rfdetr
opencv-python
accelerate
peft
inference
torch==2.7.1
optimum-quanto

@@ -1,8 +0,0 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.7.1+cu118
rfdetr
opencv-python
accelerate
inference
peft
optimum-quanto

@@ -1,7 +0,0 @@
torch==2.7.1
rfdetr
opencv-python
accelerate
inference
peft
optimum-quanto

@@ -1,9 +0,0 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.3
torch==2.7.1+rocm6.3
torchvision==0.22.1+rocm6.3
rfdetr
opencv-python
accelerate
inference
peft
optimum-quanto

@@ -1,13 +0,0 @@
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
torchvision==0.18.1+cxx11.abi
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
setuptools
rfdetr
inference
opencv-python
accelerate
peft
optimum-quanto

@@ -1,3 +0,0 @@
grpcio==1.71.0
protobuf
grpcio-tools

@@ -1,9 +0,0 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@

@@ -1,11 +0,0 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests

@@ -22,7 +22,7 @@ import torch.cuda
 
 XPU=os.environ.get("XPU", "0") == "1"
 from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
-from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
 import outetts
 from sentence_transformers import SentenceTransformer

@@ -90,7 +90,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 self.CUDA = torch.cuda.is_available()
 self.OV=False
 self.OuteTTS=False
-self.DiaTTS=False
 self.SentenceTransformer = False
 
 device_map="cpu"

@@ -98,30 +97,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 quantization = None
 autoTokenizer = True
 
-# Parse options from request.Options
-self.options = {}
-options = request.Options
-
-# The options are a list of strings in this form optname:optvalue
-# We are storing all the options in a dict so we can use it later when generating
-# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
-for opt in options:
-    if ":" not in opt:
-        continue
-    key, value = opt.split(":", 1)
-    # if value is a number, convert it to the appropriate type
-    try:
-        if "." in value:
-            value = float(value)
-        else:
-            value = int(value)
-    except ValueError:
-        # Keep as string if conversion fails
-        pass
-    self.options[key] = value
-
-print(f"Parsed options: {self.options}", file=sys.stderr)
-
 if self.CUDA:
 from transformers import BitsAndBytesConfig, AutoModelForCausalLM
 if request.MainGPU:

@@ -227,16 +202,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 autoTokenizer = False
 self.processor = AutoProcessor.from_pretrained(model_name)
 self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
-elif request.Type == "DiaForConditionalGeneration":
-    autoTokenizer = False
-    print("DiaForConditionalGeneration", file=sys.stderr)
-    self.processor = AutoProcessor.from_pretrained(model_name)
-    self.model = DiaForConditionalGeneration.from_pretrained(model_name)
-    if self.CUDA:
-        self.model = self.model.to("cuda")
-        self.processor = self.processor.to("cuda")
-    print("DiaForConditionalGeneration loaded", file=sys.stderr)
-    self.DiaTTS = True
 elif request.Type == "OuteTTS":
 autoTokenizer = False
 options = request.Options

@@ -297,7 +262,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
 self.max_tokens = self.model.config.max_position_embeddings
 else:
-self.max_tokens = self.options.get("max_new_tokens", 512)
+self.max_tokens = 512
 
 if autoTokenizer:
 self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)

@@ -520,15 +485,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 return_tensors="pt",
 )
 
+tokens = 256
 if request.HasField('duration'):
 tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
-guidance = self.options.get("guidance_scale", 3.0)
+guidance = 3.0
 if request.HasField('temperature'):
 guidance = request.temperature
-dosample = self.options.get("do_sample", True)
+dosample = True
 if request.HasField('sample'):
 dosample = request.sample
-audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
+audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
 print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
 sampling_rate = self.model.config.audio_encoder.sampling_rate
 wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())

@@ -540,59 +506,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
 return backend_pb2.Result(success=True)
 
-def CallDiaTTS(self, request, context):
-    """
-    Generates dialogue audio using the Dia model.
-
-    Args:
-        request: A TTSRequest containing text dialogue and generation parameters
-        context: The gRPC context
-
-    Returns:
-        A Result object indicating success or failure
-    """
-    try:
-        print("[DiaTTS] generating dialogue audio", file=sys.stderr)
-
-        # Prepare text input - expect dialogue format like [S1] ... [S2] ...
-        text = [request.text]
-
-        # Process the input
-        inputs = self.processor(text=text, padding=True, return_tensors="pt")
-
-        # Generate audio with parameters from options or defaults
-        generation_params = {
-            **inputs,
-            "max_new_tokens": self.max_tokens,
-            "guidance_scale": self.options.get("guidance_scale", 3.0),
-            "temperature": self.options.get("temperature", 1.8),
-            "top_p": self.options.get("top_p", 0.90),
-            "top_k": self.options.get("top_k", 45)
-        }
-
-        outputs = self.model.generate(**generation_params)
-
-        # Decode and save audio
-        outputs = self.processor.batch_decode(outputs)
-        self.processor.save_audio(outputs, request.dst)
-
-        print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
-        print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
-        print("[DiaTTS] Dialogue generation done", file=sys.stderr)
-
-    except Exception as err:
-        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
-    return backend_pb2.Result(success=True)
-
-def CallOuteTTS(self, request, context):
+def OuteTTS(self, request, context):
     try:
         print("[OuteTTS] generating TTS", file=sys.stderr)
         gen_cfg = outetts.GenerationConfig(
             text="Speech synthesis is the artificial production of human speech.",
-            temperature=self.options.get("temperature", 0.1),
-            repetition_penalty=self.options.get("repetition_penalty", 1.1),
+            temperature=0.1,
+            repetition_penalty=1.1,
             max_length=self.max_tokens,
             speaker=self.speaker,
             # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"

@@ -608,11 +528,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
 def TTS(self, request, context):
     if self.OuteTTS:
-        return self.CallOuteTTS(request, context)
+        return self.OuteTTS(request, context)
 
-    if self.DiaTTS:
-        print("DiaTTS", file=sys.stderr)
-        return self.CallDiaTTS(request, context)
-
     model_name = request.model
     try:

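# Standalone sketch (not part of the diff) of the "optname:optvalue" convention used by
# the option-parsing block shown in the hunks above; numeric values are coerced.
def parse_options(options):
    parsed = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        try:
            value = float(value) if "." in value else int(value)
        except ValueError:
            pass  # keep the raw string when it is not numeric
        parsed[key] = value
    return parsed

# parse_options(["max_new_tokens:3072", "guidance_scale:3.0", "top_p:0.90"])
# -> {'max_new_tokens': 3072, 'guidance_scale': 3.0, 'top_p': 0.9}
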
@@ -1,9 +1,9 @@
|
|||||||
torch==2.7.1
|
torch==2.4.1
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
sentence-transformers==5.0.0
|
sentence-transformers==3.4.1
|
||||||
protobuf==6.31.0
|
protobuf==6.31.0
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||||
torch==2.7.1+cu118
|
torch==2.4.1+cu118
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
sentence-transformers==5.0.0
|
sentence-transformers==4.1.0
|
||||||
protobuf==6.31.0
|
protobuf==6.31.0
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
torch==2.7.1
|
torch==2.4.1
|
||||||
accelerate
|
accelerate
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
sentence-transformers==5.0.0
|
sentence-transformers==4.1.0
|
||||||
protobuf==6.31.0
|
protobuf==6.31.0
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||||
torch==2.7.1+rocm6.3
|
torch==2.4.1+rocm6.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
@@ -7,5 +7,5 @@ numba==0.60.0
|
|||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.0.0
|
sentence-transformers==4.1.0
|
||||||
protobuf==6.31.0
|
protobuf==6.31.0
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||||
intel-extension-for-pytorch==2.3.110+xpu
|
intel-extension-for-pytorch==2.3.110+xpu
|
||||||
torch==2.5.1+cxx11.abi
|
torch==2.3.1+cxx11.abi
|
||||||
oneccl_bind_pt==2.8.0+xpu
|
oneccl_bind_pt==2.3.100+xpu
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
@@ -9,5 +9,5 @@ transformers
|
|||||||
intel-extension-for-transformers
|
intel-extension-for-transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
sentence-transformers==5.0.0
|
sentence-transformers==4.1.0
|
||||||
protobuf==6.31.0
|
protobuf==6.31.0
|
||||||
@@ -2,8 +2,8 @@ package application
 
 import (
	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
 )
 
 type Application struct {

@@ -10,8 +10,8 @@ import (
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/internal"
 
-	coreStartup "github.com/mudler/LocalAI/core/startup"
	"github.com/mudler/LocalAI/pkg/model"
+	pkgStartup "github.com/mudler/LocalAI/pkg/startup"
	"github.com/mudler/LocalAI/pkg/xsysinfo"
	"github.com/rs/zerolog/log"
 )

@@ -55,14 +55,12 @@ func New(opts ...config.AppOption) (*Application, error) {
 		}
 	}
 
-	if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
+	if err := pkgStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
 		log.Error().Err(err).Msg("error installing models")
 	}
 
-	for _, backend := range options.ExternalBackends {
-		if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, backend, "", ""); err != nil {
-			log.Error().Err(err).Msg("error installing external backend")
-		}
+	if err := pkgStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, options.ExternalBackends...); err != nil {
+		log.Error().Err(err).Msg("error installing external backends")
 	}
 
 	configLoaderOpts := options.ToConfigLoaderOptions()

@@ -1,34 +0,0 @@
package backend

import (
	"context"
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
)

func Detection(
	sourceFile string,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	backendConfig config.BackendConfig,
) (*proto.DetectResponse, error) {
	opts := ModelOptions(backendConfig, appConfig)
	detectionModel, err := loader.Load(opts...)
	if err != nil {
		return nil, err
	}
	defer loader.Close()

	if detectionModel == nil {
		return nil, fmt.Errorf("could not load detection model")
	}

	res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
		Src: sourceFile,
	})

	return res, err
}

@@ -7,7 +7,7 @@ import (
	model "github.com/mudler/LocalAI/pkg/model"
 )
 
-func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
+func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
 
 	opts := ModelOptions(backendConfig, appConfig)
 	inferenceModel, err := loader.Load(

@@ -33,7 +33,6 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 			Dst: dst,
 			Src: src,
 			EnableParameters: backendConfig.Diffusers.EnableParameters,
-			RefImages: refImages,
 		})
 		return err
 	}

@@ -8,7 +8,7 @@ import (
	"github.com/mudler/LocalAI/core/config"
 
	"github.com/mudler/LocalAI/core/gallery"
-	"github.com/mudler/LocalAI/core/startup"
+	"github.com/mudler/LocalAI/pkg/startup"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
 )

@@ -23,9 +23,7 @@ type BackendsList struct {
 }
 
 type BackendsInstall struct {
-	BackendArgs string `arg:"" optional:"" name:"backend" help:"Backend configuration URL to load"`
-	Name string `arg:"" optional:"" name:"name" help:"Name of the backend"`
-	Alias string `arg:"" optional:"" name:"alias" help:"Alias of the backend"`
+	BackendArgs []string `arg:"" optional:"" name:"backends" help:"Backend configuration URLs to load"`
 
 	BackendsCMDFlags `embed:""`
 }

@@ -68,25 +66,27 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
 		log.Error().Err(err).Msg("unable to load galleries")
 	}
 
-	progressBar := progressbar.NewOptions(
-		1000,
-		progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)),
-		progressbar.OptionShowBytes(false),
-		progressbar.OptionClearOnFinish(),
-	)
-	progressCallback := func(fileName string, current string, total string, percentage float64) {
-		v := int(percentage * 10)
-		err := progressBar.Set(v)
-		if err != nil {
-			log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar")
-		}
-	}
-
-	err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
-	if err != nil {
-		return err
-	}
+	for _, backendName := range bi.BackendArgs {
+		progressBar := progressbar.NewOptions(
+			1000,
+			progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", backendName)),
+			progressbar.OptionShowBytes(false),
+			progressbar.OptionClearOnFinish(),
+		)
+		progressCallback := func(fileName string, current string, total string, percentage float64) {
+			v := int(percentage * 10)
+			err := progressBar.Set(v)
+			if err != nil {
+				log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar")
+			}
+		}
+
+		err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, backendName)
+		if err != nil {
+			return err
+		}
+	}
 
 	return nil
 }

@@ -9,8 +9,8 @@ import (
	"github.com/mudler/LocalAI/core/config"
 
	"github.com/mudler/LocalAI/core/gallery"
-	"github.com/mudler/LocalAI/core/startup"
	"github.com/mudler/LocalAI/pkg/downloader"
+	"github.com/mudler/LocalAI/pkg/startup"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
 )

@@ -25,6 +25,7 @@ type RunCMD struct {
	ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
	GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
	UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
+	ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
	LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
	LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
	// The alias on this option is there to preserve functionality with the old `--config-file` parameter

@@ -87,6 +88,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
	config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
	config.WithGeneratedContentDir(r.GeneratedContentPath),
	config.WithUploadDir(r.UploadPath),
+	config.WithConfigsDir(r.ConfigPath),
	config.WithDynamicConfigDir(r.LocalaiConfigDir),
	config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval),
	config.WithF16(r.F16),

@@ -72,7 +72,7 @@ func (u *CreateOCIImageCMD) Run(ctx *cliContext.Context) error {
 }
 
 func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
-	if len(u.Args) == 0 {
+	if u.Args == nil || len(u.Args) == 0 {
 		return fmt.Errorf("no GGUF file provided")
 	}
 	// We try to guess only if we don't have a template defined already

@@ -21,7 +21,8 @@ type ApplicationConfig struct {
	Debug bool
	GeneratedContentDir string
 
+	ConfigsDir string
	UploadDir string
 
	DynamicConfigsDir string
	DynamicConfigsDirPollInterval time.Duration

@@ -301,6 +302,12 @@ func WithUploadDir(uploadDir string) AppOption {
 	}
 }
 
+func WithConfigsDir(configsDir string) AppOption {
+	return func(o *ApplicationConfig) {
+		o.ConfigsDir = configsDir
+	}
+}
+
 func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
 	return func(o *ApplicationConfig) {
 		o.DynamicConfigsDir = dynamicConfigsDir

@@ -458,7 +458,6 @@ const (
	FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
	FLAG_VAD BackendConfigUsecases = 0b010000000000
	FLAG_VIDEO BackendConfigUsecases = 0b100000000000
-	FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
 
	// Common Subsets
	FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT

@@ -480,7 +479,6 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
	"FLAG_VAD": FLAG_VAD,
	"FLAG_LLM": FLAG_LLM,
	"FLAG_VIDEO": FLAG_VIDEO,
-	"FLAG_DETECTION": FLAG_DETECTION,
 	}
 }
 

@@ -574,12 +572,6 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
 		}
 	}
 
-	if (u & FLAG_DETECTION) == FLAG_DETECTION {
-		if c.Backend != "rfdetr" {
-			return false
-		}
-	}
-
 	if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
 		if c.Backend != "transformers-musicgen" {
 			return false

@@ -2,8 +2,7 @@ package gallery
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/core/system"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BackendMetadata represents the metadata stored in a JSON file for each installed backend
|
// BackendMetadata represents the metadata stored in a JSON file for each installed backend
|
||||||
@@ -24,7 +23,6 @@ type GalleryBackend struct {
|
|||||||
Metadata `json:",inline" yaml:",inline"`
|
Metadata `json:",inline" yaml:",inline"`
|
||||||
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
|
||||||
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
|
||||||
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
|
||||||
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,11 +33,9 @@ func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.Syste
|
|||||||
|
|
||||||
realBackend := backend.CapabilitiesMap[systemState.Capability(backend.CapabilitiesMap)]
|
realBackend := backend.CapabilitiesMap[systemState.Capability(backend.CapabilitiesMap)]
|
||||||
if realBackend == "" {
|
if realBackend == "" {
|
||||||
log.Debug().Str("backend", backend.Name).Str("reportedCapability", systemState.Capability(backend.CapabilitiesMap)).Msg("No backend found for reported capability")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Str("backend", backend.Name).Str("reportedCapability", systemState.Capability(backend.CapabilitiesMap)).Msg("Found backend for reported capability")
|
|
||||||
return backends.FindByName(realBackend)
|
return backends.FindByName(realBackend)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/system"
|
||||||
"github.com/mudler/LocalAI/pkg/downloader"
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
|
||||||
cp "github.com/otiai10/copy"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -146,28 +145,8 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
|
|||||||
}
|
}
|
||||||
|
|
||||||
uri := downloader.URI(config.URI)
|
uri := downloader.URI(config.URI)
|
||||||
// Check if it is a directory
|
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||||
if uri.LooksLikeDir() {
|
return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
|
||||||
// It is a directory, we just copy it over in the backend folder
|
|
||||||
if err := cp.Copy(config.URI, backendPath); err != nil {
|
|
||||||
return fmt.Errorf("failed copying: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
uri := downloader.URI(config.URI)
|
|
||||||
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
|
||||||
success := false
|
|
||||||
// Try to download from mirrors
|
|
||||||
for _, mirror := range config.Mirrors {
|
|
||||||
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
|
|
||||||
success = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !success {
|
|
||||||
return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create metadata for the backend
|
// Create metadata for the backend
|
||||||
@@ -250,22 +229,16 @@ func ListSystemBackends(basePath string) (map[string]string, error) {
 	for _, backend := range backends {
 		if backend.IsDir() {
 			runFile := filepath.Join(basePath, backend.Name(), runFile)
-			var metadata *BackendMetadata
-
-			// If metadata file does not exist, we just use the directory name
-			// and we do not fill the other metadata (such as potential backend Aliases)
+			// Skip if metadata file don't exist
 			metadataFilePath := filepath.Join(basePath, backend.Name(), metadataFile)
 			if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
-				metadata = &BackendMetadata{
-					Name: backend.Name(),
-				}
-			} else {
-				// Check for alias in metadata
-				metadata, err = readBackendMetadata(filepath.Join(basePath, backend.Name()))
-				if err != nil {
-					return nil, err
-				}
+				continue
+			}
+
+			// Check for alias in metadata
+			metadata, err := readBackendMetadata(filepath.Join(basePath, backend.Name()))
+			if err != nil {
+				return nil, err
 			}
 
 			if metadata == nil {
@@ -7,7 +7,7 @@ import (
 	"runtime"
 
 	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/pkg/system"
+	"github.com/mudler/LocalAI/core/system"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"gopkg.in/yaml.v2"
@@ -95,7 +95,7 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
 
 	if !strings.Contains(name, "@") {
 		for _, m := range models {
-			if strings.EqualFold(strings.ToLower(m.GetName()), strings.ToLower(name)) {
+			if strings.EqualFold(m.GetName(), name) {
 				model = m
 				break
 			}
@@ -103,7 +103,7 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
 
 	} else {
 		for _, m := range models {
-			if strings.EqualFold(strings.ToLower(name), strings.ToLower(fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName()))) {
+			if strings.EqualFold(name, fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName())) {
 				model = m
 				break
 			}
@@ -10,8 +10,8 @@ import (
 	"dario.cat/mergo"
 	"github.com/mudler/LocalAI/core/config"
 	lconfig "github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/system"
 	"github.com/mudler/LocalAI/pkg/downloader"
-	"github.com/mudler/LocalAI/pkg/system"
 	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/rs/zerolog/log"
@@ -10,8 +10,10 @@ import (
 
 	"github.com/dave-gray101/v2keyauth"
 	"github.com/gofiber/websocket/v2"
+	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
+	"github.com/mudler/LocalAI/core/http/endpoints/openai"
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/http/routes"
 
@@ -197,6 +199,11 @@ func API(application *application.Application) (*fiber.App, error) {
 		router.Use(csrf.New())
 	}
 
+	// Load config jsons
+	utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
+	utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
+	utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
+
 	galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader())
 	err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
 	if err != nil {
@@ -34,7 +34,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, backendPath string
 
 // GetOpStatusEndpoint returns the job status
 // @Summary Returns the job status
-// @Success 200 {object} services.GalleryOpStatus "Response"
+// @Success 200 {object} services.BackendOpStatus "Response"
 // @Router /backends/jobs/{uuid} [get]
 func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
@@ -48,7 +48,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
 
 // GetAllStatusEndpoint returns all the jobs status progress
 // @Summary Returns all the jobs status progress
-// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
+// @Success 200 {object} map[string]services.BackendOpStatus "Response"
 // @Router /backends/jobs [get]
 func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
@@ -58,7 +58,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
 
 // ApplyBackendEndpoint installs a new backend to a LocalAI instance
 // @Summary Install backends to LocalAI.
-// @Param request body GalleryBackend true "query params"
+// @Param request body BackendModel true "query params"
 // @Success 200 {object} schema.BackendResponse "Response"
 // @Router /backends/apply [post]
 func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
@@ -1,59 +0,0 @@
-package localai
-
-import (
-	"github.com/gofiber/fiber/v2"
-	"github.com/mudler/LocalAI/core/backend"
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/http/middleware"
-	"github.com/mudler/LocalAI/core/schema"
-	"github.com/mudler/LocalAI/pkg/model"
-	"github.com/mudler/LocalAI/pkg/utils"
-	"github.com/rs/zerolog/log"
-)
-
-// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
-// @Summary Detects objects in the input image.
-// @Param request body schema.DetectionRequest true "query params"
-// @Success 200 {object} schema.DetectionResponse "Response"
-// @Router /v1/detection [post]
-func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
-		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
-		}
-
-		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
-		if !ok || cfg == nil {
-			return fiber.ErrBadRequest
-		}
-
-		log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
-
-		image, err := utils.GetContentURIAsBase64(input.Image)
-		if err != nil {
-			return err
-		}
-
-		res, err := backend.Detection(image, ml, appConfig, *cfg)
-		if err != nil {
-			return err
-		}
-
-		response := schema.DetectionResponse{
-			Detections: make([]schema.Detection, len(res.Detections)),
-		}
-		for i, detection := range res.Detections {
-			response.Detections[i] = schema.Detection{
-				X:         detection.X,
-				Y:         detection.Y,
-				Width:     detection.Width,
-				Height:    detection.Height,
-				ClassName: detection.ClassName,
-			}
-		}
-
-		return c.JSON(response)
-	}
-}
@@ -15,10 +15,9 @@ import (
 )
 
 type ModelGalleryEndpointService struct {
 	galleries        []config.Gallery
-	backendGalleries []config.Gallery
 	modelPath        string
 	galleryApplier   *services.GalleryService
 }
 
 type GalleryModel struct {
@@ -26,12 +25,11 @@ type GalleryModel struct {
 	gallery.GalleryModel
 }
 
-func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
+func CreateModelGalleryEndpointService(galleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
 	return ModelGalleryEndpointService{
 		galleries:        galleries,
-		backendGalleries: backendGalleries,
 		modelPath:        modelPath,
 		galleryApplier:   galleryApplier,
 	}
 }
 
@@ -81,7 +79,6 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
 		ID:                 uuid.String(),
 		GalleryElementName: input.ID,
 		Galleries:          mgs.galleries,
-		BackendGalleries:   mgs.backendGalleries,
 	}
 
 	return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
522	core/http/endpoints/openai/assistant.go	Normal file
@@ -0,0 +1,522 @@
+package openai
+
+import (
+	"fmt"
+	"net/http"
+	"sort"
+	"strconv"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	"github.com/gofiber/fiber/v2"
+	"github.com/microcosm-cc/bluemonday"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/core/services"
+	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/utils"
+	"github.com/rs/zerolog/log"
+)
+
+// ToolType defines a type for tool options
+type ToolType string
+
+const (
+	CodeInterpreter ToolType = "code_interpreter"
+	Retrieval       ToolType = "retrieval"
+	Function        ToolType = "function"
+
+	MaxCharacterInstructions  = 32768
+	MaxCharacterDescription   = 512
+	MaxCharacterName          = 256
+	MaxToolsSize              = 128
+	MaxFileIdSize             = 20
+	MaxCharacterMetadataKey   = 64
+	MaxCharacterMetadataValue = 512
+)
+
+type Tool struct {
+	Type ToolType `json:"type"`
+}
+
+// Assistant represents the structure of an assistant object from the OpenAI API.
+type Assistant struct {
+	ID           string            `json:"id"`                     // The unique identifier of the assistant.
+	Object       string            `json:"object"`                 // Object type, which is "assistant".
+	Created      int64             `json:"created"`                // The time at which the assistant was created.
+	Model        string            `json:"model"`                  // The model ID used by the assistant.
+	Name         string            `json:"name,omitempty"`         // The name of the assistant.
+	Description  string            `json:"description,omitempty"`  // The description of the assistant.
+	Instructions string            `json:"instructions,omitempty"` // The system instructions that the assistant uses.
+	Tools        []Tool            `json:"tools,omitempty"`        // A list of tools enabled on the assistant.
+	FileIDs      []string          `json:"file_ids,omitempty"`     // A list of file IDs attached to this assistant.
+	Metadata     map[string]string `json:"metadata,omitempty"`     // Set of key-value pairs attached to the assistant.
+}
+
+var (
+	Assistants           = []Assistant{} // better to return empty array instead of "null"
+	AssistantsConfigFile = "assistants.json"
+)
+
+type AssistantRequest struct {
+	Model        string            `json:"model"`
+	Name         string            `json:"name,omitempty"`
+	Description  string            `json:"description,omitempty"`
+	Instructions string            `json:"instructions,omitempty"`
+	Tools        []Tool            `json:"tools,omitempty"`
+	FileIDs      []string          `json:"file_ids,omitempty"`
+	Metadata     map[string]string `json:"metadata,omitempty"`
+}
+
+// CreateAssistantEndpoint is the OpenAI Assistant API endpoint https://platform.openai.com/docs/api-reference/assistants/createAssistant
+// @Summary Create an assistant with a model and instructions.
+// @Param request body AssistantRequest true "query params"
+// @Success 200 {object} Assistant "Response"
+// @Router /v1/assistants [post]
+func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		request := new(AssistantRequest)
+		if err := c.BodyParser(request); err != nil {
+			log.Warn().AnErr("Unable to parse AssistantRequest", err)
+			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
+		}
+
+		if !modelExists(cl, ml, request.Model) {
+			log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
+			return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Model %q not found", request.Model)))
+		}
+
+		if request.Tools == nil {
+			request.Tools = []Tool{}
+		}
+
+		if request.FileIDs == nil {
+			request.FileIDs = []string{}
+		}
+
+		if request.Metadata == nil {
+			request.Metadata = make(map[string]string)
+		}
+
+		id := "asst_" + strconv.FormatInt(generateRandomID(), 10)
+
+		assistant := Assistant{
+			ID:           id,
+			Object:       "assistant",
+			Created:      time.Now().Unix(),
+			Model:        request.Model,
+			Name:         request.Name,
+			Description:  request.Description,
+			Instructions: request.Instructions,
+			Tools:        request.Tools,
+			FileIDs:      request.FileIDs,
+			Metadata:     request.Metadata,
+		}
+
+		Assistants = append(Assistants, assistant)
+		utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
+		return c.Status(fiber.StatusOK).JSON(assistant)
+	}
+}
|
|
||||||
|
var currentId int64 = 0
|
||||||
|
|
||||||
|
func generateRandomID() int64 {
|
||||||
|
atomic.AddInt64(¤tId, 1)
|
||||||
|
return currentId
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAssistantsEndpoint is the OpenAI Assistant API endpoint to list assistents https://platform.openai.com/docs/api-reference/assistants/listAssistants
|
||||||
|
// @Summary List available assistents
|
||||||
|
// @Param limit query int false "Limit the number of assistants returned"
|
||||||
|
// @Param order query string false "Order of assistants returned"
|
||||||
|
// @Param after query string false "Return assistants created after the given ID"
|
||||||
|
// @Param before query string false "Return assistants created before the given ID"
|
||||||
|
// @Success 200 {object} []Assistant "Response"
|
||||||
|
// @Router /v1/assistants [get]
|
||||||
|
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
// Because we're altering the existing assistants list we should just duplicate it for now.
|
||||||
|
returnAssistants := Assistants
|
||||||
|
// Parse query parameters
|
||||||
|
limitQuery := c.Query("limit", "20")
|
||||||
|
orderQuery := c.Query("order", "desc")
|
||||||
|
afterQuery := c.Query("after")
|
||||||
|
beforeQuery := c.Query("before")
|
||||||
|
|
||||||
|
// Convert string limit to integer
|
||||||
|
limit, err := strconv.Atoi(limitQuery)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(http.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Invalid limit query value: %s", limitQuery)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort assistants
|
||||||
|
sort.SliceStable(returnAssistants, func(i, j int) bool {
|
||||||
|
if orderQuery == "asc" {
|
||||||
|
return returnAssistants[i].Created < returnAssistants[j].Created
|
||||||
|
}
|
||||||
|
return returnAssistants[i].Created > returnAssistants[j].Created
|
||||||
|
})
|
||||||
|
|
||||||
|
// After and before cursors
|
||||||
|
if afterQuery != "" {
|
||||||
|
returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery)
|
||||||
|
}
|
||||||
|
if beforeQuery != "" {
|
||||||
|
returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply limit
|
||||||
|
if limit < len(returnAssistants) {
|
||||||
|
returnAssistants = returnAssistants[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(returnAssistants)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID
|
||||||
|
// We assume that the assistants are already sorted
|
||||||
|
func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant {
|
||||||
|
idInt, err := strconv.Atoi(id)
|
||||||
|
if err != nil {
|
||||||
|
return assistants // Return original slice if invalid id format is provided
|
||||||
|
}
|
||||||
|
|
||||||
|
var filteredAssistants []Assistant
|
||||||
|
|
||||||
|
for _, assistant := range assistants {
|
||||||
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip if invalid id in assistant
|
||||||
|
}
|
||||||
|
|
||||||
|
if aid < idInt {
|
||||||
|
filteredAssistants = append(filteredAssistants, assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredAssistants
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID
|
||||||
|
// We assume that the assistants are already sorted
|
||||||
|
func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
|
||||||
|
idInt, err := strconv.Atoi(id)
|
||||||
|
if err != nil {
|
||||||
|
return assistants // Return original slice if invalid id format is provided
|
||||||
|
}
|
||||||
|
|
||||||
|
var filteredAssistants []Assistant
|
||||||
|
|
||||||
|
for _, assistant := range assistants {
|
||||||
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip if invalid id in assistant
|
||||||
|
}
|
||||||
|
|
||||||
|
if aid > idInt {
|
||||||
|
filteredAssistants = append(filteredAssistants, assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredAssistants
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
|
||||||
|
found = false
|
||||||
|
models, err := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if model == modelName {
|
||||||
|
found = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAssistantEndpoint is the OpenAI Assistant API endpoint to delete assistents https://platform.openai.com/docs/api-reference/assistants/deleteAssistant
|
||||||
|
// @Summary Delete assistents
|
||||||
|
// @Success 200 {object} schema.DeleteAssistantResponse "Response"
|
||||||
|
// @Router /v1/assistants/{assistant_id} [delete]
|
||||||
|
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantResponse{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: "assistant.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID)
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantResponse{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: "assistant.deleted",
|
||||||
|
Deleted: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAssistantEndpoint is the OpenAI Assistant API endpoint to get assistents https://platform.openai.com/docs/api-reference/assistants/getAssistant
|
||||||
|
// @Summary Get assistent data
|
||||||
|
// @Success 200 {object} Assistant "Response"
|
||||||
|
// @Router /v1/assistants/{assistant_id} [get]
|
||||||
|
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantFile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
AssistantID string `json:"assistant_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
AssistantFiles []AssistantFile
|
||||||
|
AssistantsFileConfigFile = "assistantsFile.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
request := new(schema.AssistantFileRequest)
|
||||||
|
if err := c.BodyParser(request); err != nil {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
if len(assistant.FileIDs) > MaxFileIdSize {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range UploadedFiles {
|
||||||
|
if file.ID == request.FileID {
|
||||||
|
assistant.FileIDs = append(assistant.FileIDs, request.FileID)
|
||||||
|
assistantFile := AssistantFile{
|
||||||
|
ID: file.ID,
|
||||||
|
Object: "assistant.file",
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
AssistantID: assistant.ID,
|
||||||
|
}
|
||||||
|
AssistantFiles = append(AssistantFiles, assistantFile)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find file_id: %s", request.FileID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find %q", assistantID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
type ListAssistantFiles struct {
|
||||||
|
Data []schema.File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
limitQuery := c.Query("limit", "20")
|
||||||
|
order := c.Query("order", "desc")
|
||||||
|
limit, err := strconv.Atoi(limitQuery)
|
||||||
|
if err != nil || limit < 1 || limit > 100 {
|
||||||
|
limit = 20 // Default to 20 if there's an error or the limit is out of bounds
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort files by CreatedAt depending on the order query parameter
|
||||||
|
if order == "asc" {
|
||||||
|
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||||
|
return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt
|
||||||
|
})
|
||||||
|
} else { // default to "desc"
|
||||||
|
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||||
|
return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit the number of files returned
|
||||||
|
var limitedFiles []AssistantFile
|
||||||
|
hasMore := false
|
||||||
|
if len(AssistantFiles) > limit {
|
||||||
|
hasMore = true
|
||||||
|
limitedFiles = AssistantFiles[:limit]
|
||||||
|
} else {
|
||||||
|
limitedFiles = AssistantFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"object": "list",
|
||||||
|
"data": limitedFiles,
|
||||||
|
"first_id": func() string {
|
||||||
|
if len(limitedFiles) > 0 {
|
||||||
|
return limitedFiles[0].ID
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
"last_id": func() string {
|
||||||
|
if len(limitedFiles) > 0 {
|
||||||
|
return limitedFiles[len(limitedFiles)-1].ID
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
"has_more": hasMore,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusOK).JSON(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
request := new(AssistantRequest)
|
||||||
|
if err := c.BodyParser(request); err != nil {
|
||||||
|
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
newAssistant := Assistant{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: assistant.Object,
|
||||||
|
Created: assistant.Created,
|
||||||
|
Model: request.Model,
|
||||||
|
Name: request.Name,
|
||||||
|
Description: request.Description,
|
||||||
|
Instructions: request.Instructions,
|
||||||
|
Tools: request.Tools,
|
||||||
|
FileIDs: request.FileIDs, // todo: should probably verify fileids exist
|
||||||
|
Metadata: request.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old one and replace with new one
|
||||||
|
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||||
|
Assistants = append(Assistants, newAssistant)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(newAssistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
fileId := c.Params("file_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||||
|
}
|
||||||
|
// First remove file from assistant
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
for j, fileId := range assistant.FileIDs {
|
||||||
|
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...)
|
||||||
|
|
||||||
|
// Check if the file exists in the assistantFiles slice
|
||||||
|
for i, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.ID == fileId {
|
||||||
|
// Remove the file from the assistantFiles slice
|
||||||
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID)
|
||||||
|
for i, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.AssistantID == assistantID {
|
||||||
|
|
||||||
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Warn().Msgf("Unable to find assistant: %s", assistantID)
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
fileId := c.Params("file_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.AssistantID == assistantID {
|
||||||
|
if assistantFile.ID == fileId {
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
460	core/http/endpoints/openai/assistant_test.go	Normal file
@@ -0,0 +1,460 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var configsDir string = "/tmp/localai/configs"
|
||||||
|
|
||||||
|
type MockLoader struct {
|
||||||
|
models []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func tearDown() func() {
|
||||||
|
return func() {
|
||||||
|
UploadedFiles = []schema.File{}
|
||||||
|
Assistants = []Assistant{}
|
||||||
|
AssistantFiles = []AssistantFile{}
|
||||||
|
_ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile))
|
||||||
|
_ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssistantEndpoints(t *testing.T) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
cl := &config.BackendConfigLoader{}
|
||||||
|
//configsDir := "/tmp/localai/configs"
|
||||||
|
modelPath := "/tmp/localai/model"
|
||||||
|
var ml = model.NewModelLoader(modelPath, false)
|
||||||
|
|
||||||
|
appConfig := &config.ApplicationConfig{
|
||||||
|
ConfigsDir: configsDir,
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
ModelPath: modelPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(appConfig.ConfigsDir)
|
||||||
|
_ = os.MkdirAll(appConfig.ConfigsDir, 0750)
|
||||||
|
_ = os.MkdirAll(modelPath, 0750)
|
||||||
|
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j"))
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
app.Post("/files", UploadFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
t.Run("CreateAssistantEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, resp, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(Assistants))
|
||||||
|
//t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||||
|
|
||||||
|
assert.Equal(t, ar.Name, resultAssistant.Name)
|
||||||
|
assert.Equal(t, ar.Model, resultAssistant.Model)
|
||||||
|
assert.Equal(t, ar.Tools, resultAssistant.Tools)
|
||||||
|
assert.Equal(t, ar.Description, resultAssistant.Description)
|
||||||
|
assert.Equal(t, ar.Instructions, resultAssistant.Instructions)
|
||||||
|
assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs)
|
||||||
|
assert.Equal(t, ar.Metadata, resultAssistant.Metadata)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ListAssistantsEndpoint", func(t *testing.T) {
|
||||||
|
var ids []string
|
||||||
|
var resultAssistant []Assistant
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: fmt.Sprintf("3.5-turbo-%d", i),
|
||||||
|
Description: fmt.Sprintf("Test Assistant - %d", i),
|
||||||
|
Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i),
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: []string{"fid-1234"},
|
||||||
|
Metadata: map[string]string{"meta": "data"},
|
||||||
|
}
|
||||||
|
|
||||||
|
//var err error
|
||||||
|
ra, _, err := createAssistant(app, *ar)
|
||||||
|
// Because we create the assistants so fast all end up with the same created time.
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
resultAssistant = append(resultAssistant, ra)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
ids = append(ids, resultAssistant[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, ids))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
reqURL string
|
||||||
|
expectedStatus int
|
||||||
|
expectedResult []Assistant
|
||||||
|
expectedStringResult string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Usage - limit only",
|
||||||
|
reqURL: "/assistants?limit=2",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants[:2], // Expecting the first two assistants
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - order asc",
|
||||||
|
reqURL: "/assistants?order=asc",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants, // Expecting all assistants in ascending order
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - order desc",
|
||||||
|
reqURL: "/assistants?order=desc",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - after specific ID",
|
||||||
|
reqURL: "/assistants?after=2",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
// Note this is correct because it's put in descending order already
|
||||||
|
expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - before specific ID",
|
||||||
|
reqURL: "/assistants?before=4",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3.
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Usage - non-integer limit",
|
||||||
|
reqURL: "/assistants?limit=two",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedStringResult: "Invalid limit query value: two",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Usage - non-existing id in after",
|
||||||
|
reqURL: "/assistants?after=100",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
||||||
|
if tt.expectedStatus != fiber.StatusOK {
|
||||||
|
all, _ := io.ReadAll(response.Body)
|
||||||
|
assert.Equal(t, tt.expectedStringResult, string(all))
|
||||||
|
} else {
|
||||||
|
var result []Assistant
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&result)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedResult, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DeleteAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
_, err = app.Test(deleteReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(Assistants))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var getAssistant Assistant
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&getAssistant)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ModifyAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
modifiedAr := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "4.0-turbo",
|
||||||
|
Description: "Modified Test Assistant",
|
||||||
|
Instructions: "You are math teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: CodeInterpreter}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
modifiedArJson, err := json.Marshal(modifiedAr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
|
||||||
|
modifyResponse, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var getAssistant Assistant
|
||||||
|
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
||||||
|
|
||||||
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change
|
||||||
|
assert.Equal(t, modifiedAr.Tools, getAssistant.Tools)
|
||||||
|
assert.Equal(t, modifiedAr.Name, getAssistant.Name)
|
||||||
|
assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions)
|
||||||
|
assert.Equal(t, modifiedAr.Description, getAssistant.Description)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CreateAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("ListAssistantFilesEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("GetAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID))
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var assistantFile AssistantFile
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&assistantFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, af.ID, assistantFile.ID)
|
||||||
|
assert.Equal(t, af.AssistantID, assistantFile.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := schema.AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cleanupAssistantFile(t, app, af.ID, af.AssistantID)()
|
||||||
|
|
||||||
|
assert.Empty(t, AssistantFiles)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (schema.File, Assistant, error) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
assistant, _, err := createAssistant(app, *ar)
|
||||||
|
if err != nil {
|
||||||
|
return schema.File{}, Assistant{}, err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID}))
|
||||||
|
|
||||||
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_, err := CallFilesDeleteEndpoint(t, app, file.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
return file, assistant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAssistantFile(app *fiber.App, afr schema.AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) {
|
||||||
|
afrJson, err := json.Marshal(afr)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files", assistantId)
|
||||||
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var assistantFile AssistantFile
|
||||||
|
all, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
|
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return assistantFile, resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) {
|
||||||
|
assistant, err := json.Marshal(ar)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyString, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultAssistant Assistant
|
||||||
|
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
||||||
|
return resultAssistant, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
||||||
|
return func() {
|
||||||
|
for _, assistant := range ids {
|
||||||
|
target := fmt.Sprintf("/assistants/%s", assistant)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
_, err := app.Test(deleteReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to delete assistant %s: %v", assistant, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() {
|
||||||
|
return func() {
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId)
|
||||||
|
request := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var dafr schema.DeleteAssistantFileResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&dafr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, dafr.Deleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,8 +15,8 @@ import (
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
 
-	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
 
 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
@@ -175,7 +175,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		textContentToReturn = ""
 		id = uuid.New().String()
 		created = int(time.Now().Unix())
 
 		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
 			return fiber.ErrBadRequest
@@ -305,7 +305,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		// If we are using the tokenizer template, we don't need to process the messages
 		// unless we are processing functions
 		if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
-			predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)
+			predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
 
 			log.Debug().Msgf("Prompt (after templating): %s", predInput)
 			if config.Grammar != "" {
@@ -15,9 +15,9 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
-	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
 )
@@ -109,10 +109,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 		predInput := config.PromptStrings[0]
 
 		templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
 			Input:           predInput,
 			SystemPrompt:    config.SystemPrompt,
-			ReasoningEffort: input.ReasoningEffort,
-			Metadata:        input.Metadata,
 		})
 		if err == nil {
 			predInput = templatedInput
@@ -162,10 +160,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 
 		for k, i := range config.PromptStrings {
 			templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
 				SystemPrompt:    config.SystemPrompt,
 				Input:           i,
-				ReasoningEffort: input.ReasoningEffort,
-				Metadata:        input.Metadata,
 			})
 			if err == nil {
 				i = templatedInput
|
|||||||
@@ -12,8 +12,8 @@ import (
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"

-	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"

 	"github.com/rs/zerolog/log"
 )
@@ -47,11 +47,9 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat

 	for _, i := range config.InputStrings {
 		templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
 			Input:           i,
 			Instruction:     input.Instruction,
 			SystemPrompt:    config.SystemPrompt,
-			ReasoningEffort: input.ReasoningEffort,
-			Metadata:        input.Metadata,
 		})
 		if err == nil {
 			i = templatedInput
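Aside (not part of the diff): the PromptTemplateData passed to EvaluateTemplateForPrompt in the hunks above is rendered through Go's text/template machinery. The sketch below only illustrates that general mechanism with a made-up template string and a trimmed-down struct carrying the fields visible in these hunks (SystemPrompt, Instruction, Input); it is not the actual LocalAI Evaluator, whose real templates come from the model configuration.

package main

import (
    "os"
    "text/template"
)

// promptData mirrors only the fields used in the hunks above.
type promptData struct {
    SystemPrompt string
    Instruction  string
    Input        string
}

func main() {
    // Hypothetical instruction-style template, for illustration only.
    const tmpl = "{{.SystemPrompt}}\n\n### Instruction:\n{{.Instruction}}\n\n### Input:\n{{.Input}}\n"
    t := template.Must(template.New("edit").Parse(tmpl))
    // Render the struct through the template and print the resulting prompt.
    _ = t.Execute(os.Stdout, promptData{
        SystemPrompt: "You are a helpful assistant.",
        Instruction:  "Fix the grammar.",
        Input:        "this are a test",
    })
}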
194 core/http/endpoints/openai/files.go (new file)
@@ -0,0 +1,194 @@
package openai

import (
    "errors"
    "fmt"
    "os"
    "path/filepath"
    "sync/atomic"
    "time"

    "github.com/microcosm-cc/bluemonday"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/schema"

    "github.com/gofiber/fiber/v2"
    "github.com/mudler/LocalAI/pkg/utils"
)

var UploadedFiles []schema.File

const UploadedFilesFile = "uploadedFiles.json"

// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        file, err := c.FormFile("file")
        if err != nil {
            return err
        }

        // Check the file size
        if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
            return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
        }

        purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
        if purpose == "" {
            return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
        }

        // Sanitize the filename to prevent directory traversal
        filename := utils.SanitizeFileName(file.Filename)

        savePath := filepath.Join(appConfig.UploadDir, filename)

        // Check if file already exists
        if _, err := os.Stat(savePath); !os.IsNotExist(err) {
            return c.Status(fiber.StatusBadRequest).SendString("File already exists")
        }

        err = c.SaveFile(file, savePath)
        if err != nil {
            return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + bluemonday.StrictPolicy().Sanitize(err.Error()))
        }

        f := schema.File{
            ID:        fmt.Sprintf("file-%d", getNextFileId()),
            Object:    "file",
            Bytes:     int(file.Size),
            CreatedAt: time.Now(),
            Filename:  file.Filename,
            Purpose:   purpose,
        }

        UploadedFiles = append(UploadedFiles, f)
        utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
        return c.Status(fiber.StatusOK).JSON(f)
    }
}

var currentFileId int64 = 0

func getNextFileId() int64 {
    atomic.AddInt64(&currentId, 1)
    return currentId
}

// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
// @Summary List files.
// @Success 200 {object} schema.ListFiles "Response"
// @Router /v1/files [get]
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {

    return func(c *fiber.Ctx) error {
        var listFiles schema.ListFiles

        purpose := c.Query("purpose")
        if purpose == "" {
            listFiles.Data = UploadedFiles
        } else {
            for _, f := range UploadedFiles {
                if purpose == f.Purpose {
                    listFiles.Data = append(listFiles.Data, f)
                }
            }
        }
        listFiles.Object = "list"
        return c.Status(fiber.StatusOK).JSON(listFiles)
    }
}

func getFileFromRequest(c *fiber.Ctx) (*schema.File, error) {
    id := c.Params("file_id")
    if id == "" {
        return nil, fmt.Errorf("file_id parameter is required")
    }

    for _, f := range UploadedFiles {
        if id == f.ID {
            return &f, nil
        }
    }

    return nil, fmt.Errorf("unable to find file id %s", id)
}

// GetFilesEndpoint is the OpenAI API endpoint to get files https://platform.openai.com/docs/api-reference/files/retrieve
// @Summary Returns information about a specific file.
// @Success 200 {object} schema.File "Response"
// @Router /v1/files/{file_id} [get]
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        file, err := getFileFromRequest(c)
        if err != nil {
            return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
        }

        return c.JSON(file)
    }
}

type DeleteStatus struct {
    Id      string
    Object  string
    Deleted bool
}

// DeleteFilesEndpoint is the OpenAI API endpoint to delete files https://platform.openai.com/docs/api-reference/files/delete
// @Summary Delete a file.
// @Success 200 {object} DeleteStatus "Response"
// @Router /v1/files/{file_id} [delete]
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {

    return func(c *fiber.Ctx) error {
        file, err := getFileFromRequest(c)
        if err != nil {
            return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
        }

        err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
        if err != nil {
            // If the file doesn't exist then we should just continue to remove it
            if !errors.Is(err, os.ErrNotExist) {
                return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err)))
            }
        }

        // Remove upload from list
        for i, f := range UploadedFiles {
            if f.ID == file.ID {
                UploadedFiles = append(UploadedFiles[:i], UploadedFiles[i+1:]...)
                break
            }
        }

        utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
        return c.JSON(DeleteStatus{
            Id:      file.ID,
            Object:  "file",
            Deleted: true,
        })
    }
}

// GetFilesContentsEndpoint is the OpenAI API endpoint to get files content https://platform.openai.com/docs/api-reference/files/retrieve-contents
// @Summary Returns information about a specific file.
// @Success 200 {string} binary "file"
// @Router /v1/files/{file_id}/content [get]
// GetFilesContentsEndpoint
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        file, err := getFileFromRequest(c)
        if err != nil {
            return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
        }

        fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
        if err != nil {
            return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
        }

        return c.Send(fileContents)
    }
}
301 core/http/endpoints/openai/files_test.go (new file)
@@ -0,0 +1,301 @@
package openai

import (
    "encoding/json"
    "fmt"
    "io"
    "mime/multipart"
    "net/http"
    "net/http/httptest"
    "os"
    "path/filepath"
    "strings"

    "github.com/rs/zerolog/log"

    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/schema"

    "github.com/gofiber/fiber/v2"
    utils2 "github.com/mudler/LocalAI/pkg/utils"
    "github.com/stretchr/testify/assert"

    "testing"
)

func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
    // Preparing the mocked objects
    loader = &config.BackendConfigLoader{}

    option = &config.ApplicationConfig{
        UploadLimitMB: 10,
        UploadDir:     "test_dir",
    }

    _ = os.RemoveAll(option.UploadDir)

    app = fiber.New(fiber.Config{
        BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
    })

    // Create a Test Server
    app.Post("/files", UploadFilesEndpoint(loader, option))
    app.Get("/files", ListFilesEndpoint(loader, option))
    app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
    app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
    app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))

    return
}

func TestUploadFileExceedSizeLimit(t *testing.T) {
    // Preparing the mocked objects
    loader := &config.BackendConfigLoader{}

    option := &config.ApplicationConfig{
        UploadLimitMB: 10,
        UploadDir:     "test_dir",
    }

    _ = os.RemoveAll(option.UploadDir)

    app := fiber.New(fiber.Config{
        BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
    })

    // Create a Test Server
    app.Post("/files", UploadFilesEndpoint(loader, option))
    app.Get("/files", ListFilesEndpoint(loader, option))
    app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
    app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
    app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))

    t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
        t.Cleanup(tearDown())
        resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
        assert.NoError(t, err)

        assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
        assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
    })
    t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
        t.Cleanup(tearDown())
        resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)

        assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
        assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
    })
    t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
        t.Cleanup(tearDown())
        f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)

        resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
        fmt.Println(f1)
        fmt.Printf("ERror: %v\n", err)
        fmt.Printf("resp: %+v\n", resp)

        assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
        assert.Contains(t, bodyToString(resp, t), "File already exists")
    })
    t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
        t.Cleanup(tearDown())
        file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)

        // Check if file exists in the disk
        testName := strings.Split(t.Name(), "/")[1]
        fileName := testName + "-test.txt"
        filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName(fileName))
        _, err := os.Stat(filePath)

        assert.False(t, os.IsNotExist(err))
        assert.Equal(t, file.Bytes, 5242880)
        assert.NotEmpty(t, file.CreatedAt)
        assert.Equal(t, file.Filename, fileName)
        assert.Equal(t, file.Purpose, "fine-tune")
    })
    t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
        t.Cleanup(tearDown())
        resp, err := CallListFilesEndpoint(t, app, "")
        assert.NoError(t, err)

        assert.Equal(t, 200, resp.StatusCode)

        listFiles := responseToListFile(t, resp)
        if len(listFiles.Data) != len(UploadedFiles) {
            t.Errorf("Expected %v files, got %v files", len(UploadedFiles), len(listFiles.Data))
        }
    })
    t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
        t.Cleanup(tearDown())
        _ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)

        resp, err := CallListFilesEndpoint(t, app, "fine-tune")
        assert.NoError(t, err)

        listFiles := responseToListFile(t, resp)
        if len(listFiles.Data) != 1 {
            t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
        }
    })
    t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
        t.Cleanup(tearDown())
        resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
        assert.NoError(t, err)
        assert.Equal(t, 200, resp.StatusCode)

        listFiles := responseToListFile(t, resp)

        if len(listFiles.Data) != 0 {
            t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
        }
    })
    t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
        t.Cleanup(tearDown())
        req := httptest.NewRequest("GET", "/files", nil)
        resp, _ := app.Test(req)
        assert.Equal(t, 200, resp.StatusCode)

        var listFiles schema.ListFiles
        if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
            t.Errorf("Failed to decode response: %v", err)
            return
        }

        if len(listFiles.Data) != 0 {
            t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
        }
    })
}

func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
    var target string
    if purpose != "" {
        target = fmt.Sprintf("/files?purpose=%s", purpose)
    } else {
        target = "/files"
    }
    req := httptest.NewRequest("GET", target, nil)
    return app.Test(req)
}

func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
    request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
    return app.Test(request)
}

func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
    testName := strings.Split(t.Name(), "/")[1]

    // Create a file that exceeds the limit
    file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)

    // Creating a new HTTP Request
    body, writer := newMultipartFile(file.Name(), tag, purpose)

    req := httptest.NewRequest(http.MethodPost, "/files", body)
    req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
    return app.Test(req)
}

func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) schema.File {
    // Create a file that exceeds the limit
    testName := strings.Split(t.Name(), "/")[1]
    file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)

    // Creating a new HTTP Request
    body, writer := newMultipartFile(file.Name(), tag, purpose)

    req := httptest.NewRequest(http.MethodPost, "/files", body)
    req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
    resp, err := app.Test(req)
    assert.NoError(t, err)
    f := responseToFile(t, resp)

    //id := f.ID
    //t.Cleanup(func() {
    //	_, err := CallFilesDeleteEndpoint(t, app, id)
    //	assert.NoError(t, err)
    //	assert.Empty(t, UploadedFiles)
    //})

    return f

}

func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
    target := fmt.Sprintf("/files/%s", fileId)
    req := httptest.NewRequest(http.MethodDelete, target, nil)
    return app.Test(req)
}

// Helper to create multi-part file
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
    body := new(strings.Builder)
    writer := multipart.NewWriter(body)
    file, _ := os.Open(filePath)
    defer file.Close()
    part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
    io.Copy(part, file)

    if purpose != "" {
        _ = writer.WriteField("purpose", purpose)
    }

    writer.Close()
    return strings.NewReader(body.String()), writer
}

// Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
    err := os.MkdirAll(option.UploadDir, 0750)
    if err != nil {

        t.Fatalf("Error MKDIR: %v", err)
    }

    file, err := os.Create(name)
    assert.NoError(t, err)
    file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File

    t.Cleanup(func() {
        os.Remove(name)
        os.RemoveAll(option.UploadDir)
    })
    return file
}

func bodyToString(resp *http.Response, t *testing.T) string {
    return string(bodyToByteArray(resp, t))
}

func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
    bodyBytes, err := io.ReadAll(resp.Body)
    if err != nil {
        t.Fatal(err)
    }
    return bodyBytes
}

func responseToFile(t *testing.T, resp *http.Response) schema.File {
    var file schema.File
    responseToString := bodyToString(resp, t)

    err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
    if err != nil {
        t.Errorf("Failed to decode response: %s", err)
    }

    return file
}

func responseToListFile(t *testing.T, resp *http.Response) schema.ListFiles {
    var listFiles schema.ListFiles
    responseToString := bodyToString(resp, t)

    err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
    if err != nil {
        log.Error().Err(err).Msg("failed to decode response")
    }

    return listFiles
}
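For context, a minimal client-side sketch of exercising the upload and list routes declared above over plain HTTP, complementing the in-process app.Test calls in files_test.go. The /v1/files paths come from the @Router annotations in files.go; the server address, file name, and file contents are assumptions.

package main

import (
    "bytes"
    "fmt"
    "io"
    "mime/multipart"
    "net/http"
)

func main() {
    // Build a multipart body with a "file" part and a "purpose" field,
    // matching what UploadFilesEndpoint reads via c.FormFile("file") and
    // c.FormValue("purpose", "").
    var body bytes.Buffer
    w := multipart.NewWriter(&body)
    part, _ := w.CreateFormFile("file", "notes.txt")
    part.Write([]byte("hello"))
    w.WriteField("purpose", "fine-tune")
    w.Close()

    base := "http://localhost:8080" // assumed LocalAI server address

    // Upload the file.
    resp, err := http.Post(base+"/v1/files", w.FormDataContentType(), &body)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    uploaded, _ := io.ReadAll(resp.Body)
    fmt.Println("upload:", string(uploaded))

    // List files filtered by purpose, as handled by ListFilesEndpoint.
    list, err := http.Get(base + "/v1/files?purpose=fine-tune")
    if err != nil {
        panic(err)
    }
    defer list.Body.Close()
    files, _ := io.ReadAll(list.Body)
    fmt.Println("list:", string(files))
}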
@@ -79,37 +79,49 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 			return fiber.ErrBadRequest
 		}

-		// Process input images (for img2img/inpainting)
 		src := ""
 		if input.File != "" {
-			src = processImageFile(input.File, appConfig.GeneratedContentDir)
-			if src != "" {
-				defer os.RemoveAll(src)
-			}
-		}
-
-		// Process multiple input images
-		var inputImages []string
-		if len(input.Files) > 0 {
-			for _, file := range input.Files {
-				processedFile := processImageFile(file, appConfig.GeneratedContentDir)
-				if processedFile != "" {
-					inputImages = append(inputImages, processedFile)
-					defer os.RemoveAll(processedFile)
-				}
-			}
-		}
-
-		// Process reference images
-		var refImages []string
-		if len(input.RefImages) > 0 {
-			for _, file := range input.RefImages {
-				processedFile := processImageFile(file, appConfig.GeneratedContentDir)
-				if processedFile != "" {
-					refImages = append(refImages, processedFile)
-					defer os.RemoveAll(processedFile)
-				}
-			}
+
+			fileData := []byte{}
+			var err error
+			// check if input.File is an URL, if so download it and save it
+			// to a temporary file
+			if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
+				out, err := downloadFile(input.File)
+				if err != nil {
+					return fmt.Errorf("failed downloading file:%w", err)
+				}
+				defer os.RemoveAll(out)
+
+				fileData, err = os.ReadFile(out)
+				if err != nil {
+					return fmt.Errorf("failed reading file:%w", err)
+				}
+
+			} else {
+				// base 64 decode the file and write it somewhere
+				// that we will cleanup
+				fileData, err = base64.StdEncoding.DecodeString(input.File)
+				if err != nil {
+					return err
+				}
+			}
+
+			// Create a temporary file
+			outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
+			if err != nil {
+				return err
+			}
+			// write the base64 result
+			writer := bufio.NewWriter(outputFile)
+			_, err = writer.Write(fileData)
+			if err != nil {
+				outputFile.Close()
+				return err
+			}
+			outputFile.Close()
+			src = outputFile.Name()
+			defer os.RemoveAll(src)
 		}

 		log.Debug().Msgf("Parameter Config: %+v", config)
@@ -190,13 +202,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon

 		baseURL := c.BaseURL()

-		// Use the first input image as src if available, otherwise use the original src
-		inputSrc := src
-		if len(inputImages) > 0 {
-			inputSrc = inputImages[0]
-		}
-
-		fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
+		fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
 		if err != nil {
 			return err
 		}
@@ -237,51 +243,3 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 		return c.JSON(resp)
 	}
 }
-
-// processImageFile handles a single image file (URL or base64) and returns the path to the temporary file
-func processImageFile(file string, generatedContentDir string) string {
-	fileData := []byte{}
-	var err error
-
-	// check if file is an URL, if so download it and save it to a temporary file
-	if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
-		out, err := downloadFile(file)
-		if err != nil {
-			log.Error().Err(err).Msgf("Failed downloading file: %s", file)
-			return ""
-		}
-		defer os.RemoveAll(out)
-
-		fileData, err = os.ReadFile(out)
-		if err != nil {
-			log.Error().Err(err).Msgf("Failed reading downloaded file: %s", out)
-			return ""
-		}
-	} else {
-		// base 64 decode the file and write it somewhere that we will cleanup
-		fileData, err = base64.StdEncoding.DecodeString(file)
-		if err != nil {
-			log.Error().Err(err).Msgf("Failed decoding base64 file")
-			return ""
-		}
-	}
-
-	// Create a temporary file
-	outputFile, err := os.CreateTemp(generatedContentDir, "b64")
-	if err != nil {
-		log.Error().Err(err).Msg("Failed creating temporary file")
-		return ""
-	}
-
-	// write the base64 result
-	writer := bufio.NewWriter(outputFile)
-	_, err = writer.Write(fileData)
-	if err != nil {
-		outputFile.Close()
-		log.Error().Err(err).Msg("Failed writing to temporary file")
-		return ""
-	}
-	outputFile.Close()
-
-	return outputFile.Name()
-}
@@ -16,12 +16,12 @@ import (
 	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
-	"github.com/mudler/LocalAI/core/templates"
 	laudio "github.com/mudler/LocalAI/pkg/audio"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/sound"
+	"github.com/mudler/LocalAI/pkg/templates"

 	"google.golang.org/grpc"

@@ -29,8 +29,8 @@ import (
 )

 const (
 	localSampleRate  = 16000
 	remoteSampleRate = 24000
 )

 // A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
@@ -210,9 +210,9 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 				// TODO: Need some way to pass this to the backend
 				Threshold: 0.5,
 				// TODO: This is ignored and the amount of padding is random at present
 				PrefixPaddingMs:   30,
 				SilenceDurationMs: 500,
 				CreateResponse:    func() *bool { t := true; return &t }(),
 			},
 		},
 		InputAudioTranscription: &types.InputAudioTranscription{
@@ -233,7 +233,7 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 	// TODO: The API has no way to configure the VAD model or other models that make up a pipeline to fake any-to-any
 	// So possibly we could have a way to configure a composite model that can be used in situations where any-to-any is expected
 	pipeline := config.Pipeline{
 		VAD:           "silero-vad",
 		Transcription: session.InputAudioTranscription.Model,
 	}

@@ -567,8 +567,8 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
 	trCur := session.InputAudioTranscription

 	if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model {
-		pipeline := config.Pipeline{
+		pipeline := config.Pipeline {
 			VAD:           "silero-vad",
 			Transcription: trUpd.Model,
 		}

@@ -684,7 +684,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 		sendEvent(c, types.InputAudioBufferClearedEvent{
 			ServerEventBase: types.ServerEventBase{
 				EventID: "event_TODO",
 				Type:    types.ServerEventTypeInputAudioBufferCleared,
 			},
 		})

@@ -697,7 +697,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 			sendEvent(c, types.InputAudioBufferSpeechStartedEvent{
 				ServerEventBase: types.ServerEventBase{
 					EventID: "event_TODO",
 					Type:    types.ServerEventTypeInputAudioBufferSpeechStarted,
 				},
 				AudioStartMs: time.Now().Sub(startTime).Milliseconds(),
 			})
@@ -719,7 +719,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 			sendEvent(c, types.InputAudioBufferSpeechStoppedEvent{
 				ServerEventBase: types.ServerEventBase{
 					EventID: "event_TODO",
 					Type:    types.ServerEventTypeInputAudioBufferSpeechStopped,
 				},
 				AudioEndMs: time.Now().Sub(startTime).Milliseconds(),
 			})
@@ -728,9 +728,9 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 			sendEvent(c, types.InputAudioBufferCommittedEvent{
 				ServerEventBase: types.ServerEventBase{
 					EventID: "event_TODO",
 					Type:    types.ServerEventTypeInputAudioBufferCommitted,
 				},
 				ItemID:         generateItemID(),
 				PreviousItemID: "TODO",
 			})

@@ -833,9 +833,9 @@ func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig,

 func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADSegment, error) {
 	soundIntBuffer := &audio.IntBuffer{
 		Format:         &audio.Format{SampleRate: localSampleRate, NumChannels: 1},
 		SourceBitDepth: 16,
 		Data:           sound.ConvertInt16ToInt(adata),
 	}

 	float32Data := soundIntBuffer.AsFloat32Buffer().Data
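Aside (not part of the diff): runVAD above wraps raw int16 PCM in a go-audio IntBuffer and converts it to float32 before handing the samples to the VAD backend. The sketch below shows just that conversion with the github.com/go-audio/audio package; the int16-to-int widening is done inline here because sound.ConvertInt16ToInt is a LocalAI-internal helper, and the sample values are made up.

package main

import (
    "fmt"

    "github.com/go-audio/audio"
)

func main() {
    // Pretend 16 kHz mono int16 PCM samples.
    pcm := []int16{0, 0, 4096, 8192, 4096, 0}

    // Widen to int, since go-audio's IntBuffer stores samples as int.
    data := make([]int, len(pcm))
    for i, s := range pcm {
        data[i] = int(s)
    }

    buf := &audio.IntBuffer{
        Format:         &audio.Format{SampleRate: 16000, NumChannels: 1},
        SourceBitDepth: 16,
        Data:           data,
    }

    // Same conversion runVAD performs before invoking the VAD model.
    float32Data := buf.AsFloat32Buffer().Data
    fmt.Println(len(float32Data), float32Data)
}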
@@ -11,9 +11,9 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
 	"github.com/mudler/LocalAI/pkg/utils"

 	"github.com/gofiber/fiber/v2"
@@ -23,7 +23,7 @@ func RegisterLocalAIRoutes(router *fiber.App,

 	// LocalAI API endpoints
 	if !appConfig.DisableGalleryEndpoint {
-		modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.ModelPath, galleryService)
+		modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
 		router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
 		router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())

@@ -41,11 +41,6 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
 	}

-	router.Post("/v1/detection",
-		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
-		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
-		localai.DetectionEndpoint(cl, ml, appConfig))
-
 	router.Post("/tts",
 		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
 		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
Some files were not shown because too many files have changed in this diff.