mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-03 11:13:31 -05:00
Compare commits
86 Commits
fix/ci-503
...
v3.10.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
923ebbb344 | ||
|
|
ea51567b89 | ||
|
|
552c62a19c | ||
|
|
c0b21a921b | ||
|
|
b10045adc2 | ||
|
|
61b5e3b629 | ||
|
|
e35d7cb3b3 | ||
|
|
0fa0ac4797 | ||
|
|
be7ed85838 | ||
|
|
c12b310028 | ||
|
|
0447d5564d | ||
|
|
22c0eb5421 | ||
|
|
a0a00fb937 | ||
|
|
6dd44742ea | ||
|
|
00c72e7d3e | ||
|
|
d01c335cf6 | ||
|
|
5687df4535 | ||
|
|
f5fade97e6 | ||
|
|
b88ae31e4e | ||
|
|
f6daaa7c35 | ||
|
|
c491c6ca90 | ||
|
|
34e054f607 | ||
|
|
e886bb291a | ||
|
|
4bf2f8bbd8 | ||
|
|
d3525b7509 | ||
|
|
c8aa821e0e | ||
|
|
b3191927ae | ||
|
|
54c5a2d9ea | ||
|
|
0279591fec | ||
|
|
8845186955 | ||
|
|
ab8ed24358 | ||
|
|
a021df5a88 | ||
|
|
5f403b1631 | ||
|
|
897ad1729e | ||
|
|
16a18a2e55 | ||
|
|
3387bfaee0 | ||
|
|
1cd33047b4 | ||
|
|
1de045311a | ||
|
|
5fe9bf9f84 | ||
|
|
d4fd0c0609 | ||
|
|
d16722ee13 | ||
|
|
1f10ab39a9 | ||
|
|
4d36e393d1 | ||
|
|
cb8616c7d1 | ||
|
|
ff31d50488 | ||
|
|
1a50717e33 | ||
|
|
49d6305509 | ||
|
|
d20a113aef | ||
|
|
cbaa793520 | ||
|
|
6fe3fc880f | ||
|
|
752e641c48 | ||
|
|
44d78b4d15 | ||
|
|
64d0a96ba3 | ||
|
|
b19afc9e64 | ||
|
|
d6e698876b | ||
|
|
8962205546 | ||
|
|
eddc460118 | ||
|
|
a6ff354c86 | ||
|
|
3a2be4df48 | ||
|
|
4e1f448e86 | ||
|
|
3e0168360a | ||
|
|
ea4157887b | ||
|
|
699c50be47 | ||
|
|
94eecc43a3 | ||
|
|
7e35ec6c4f | ||
|
|
7891c33cb1 | ||
|
|
271cc79709 | ||
|
|
3d12d5e70d | ||
|
|
bc180c2638 | ||
|
|
2de30440fe | ||
|
|
673a80a578 | ||
|
|
2554e9fabe | ||
|
|
5bfc3eebf8 | ||
|
|
ab893fe302 | ||
|
|
c88074a19e | ||
|
|
5ca8f0aea0 | ||
|
|
84234e531f | ||
|
|
4cbf9abfef | ||
|
|
fdc2c0737c | ||
|
|
f4b0a304d7 | ||
|
|
d16ec7aa9e | ||
|
|
d699b7ccdc | ||
|
|
a4d224dd1b | ||
|
|
917c7aa9f3 | ||
|
|
5aa66842dd | ||
|
|
f5dee90962 |
205
.github/workflows/backend.yml
vendored
205
.github/workflows/backend.yml
vendored
@@ -105,6 +105,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "9"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-qwen-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "9"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-pocket-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -124,7 +150,7 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp"
|
||||
@@ -340,6 +366,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-qwen-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-pocket-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -405,6 +457,32 @@ jobs:
|
||||
backend: "vibevoice"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-qwen-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-pocket-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -641,13 +719,39 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-qwen-tts'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
skip-drivers: 'false'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-pocket-tts'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
skip-drivers: 'false'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
skip-drivers: 'false'
|
||||
backend: "faster-whisper"
|
||||
@@ -660,7 +764,7 @@ jobs:
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-coqui'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
skip-drivers: 'false'
|
||||
backend: "coqui"
|
||||
@@ -772,6 +876,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-qwen-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-pocket-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -825,6 +955,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-qwen-tts'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-pocket-tts'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -885,7 +1041,7 @@ jobs:
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp"
|
||||
@@ -911,7 +1067,7 @@ jobs:
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp"
|
||||
@@ -1252,19 +1408,6 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'true'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-neutts'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "neutts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1278,6 +1421,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-qwen-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-pocket-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
backend-jobs-darwin:
|
||||
uses: ./.github/workflows/backend_build_darwin.yml
|
||||
strategy:
|
||||
|
||||
40
.github/workflows/test-extra.yml
vendored
40
.github/workflows/test-extra.yml
vendored
@@ -265,4 +265,42 @@ jobs:
|
||||
- name: Test moonshine
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
||||
tests-pocket-tts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential ffmpeg
|
||||
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Test pocket-tts
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
|
||||
tests-qwen-tts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential ffmpeg
|
||||
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Test qwen-tts
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-tts
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-tts test
|
||||
22
Dockerfile
22
Dockerfile
@@ -10,7 +10,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl wget espeak-ng libgomp1 \
|
||||
ffmpeg libopenblas0 libopenblas-dev && \
|
||||
ffmpeg libopenblas0 libopenblas-dev sox && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -42,22 +42,22 @@ RUN <<EOT bash
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils mesa-vulkan-drivers
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.328.1/linux/vulkansdk-linux-x86_64-1.4.328.1.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.328.1 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.328.1 && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/share/* /usr/share/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
@@ -106,7 +106,7 @@ RUN <<EOT bash
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
|
||||
34
Makefile
34
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts backends/qwen-tts
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -9,7 +9,7 @@ LAUNCHER_BINARY_NAME=local-ai-launcher
|
||||
|
||||
CUDA_MAJOR_VERSION?=13
|
||||
CUDA_MINOR_VERSION?=0
|
||||
UBUNTU_VERSION?=2204
|
||||
UBUNTU_VERSION?=2404
|
||||
UBUNTU_CODENAME?=noble
|
||||
|
||||
GORELEASER?=
|
||||
@@ -316,6 +316,8 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/vllm
|
||||
$(MAKE) -C backend/python/vibevoice
|
||||
$(MAKE) -C backend/python/moonshine
|
||||
$(MAKE) -C backend/python/pocket-tts
|
||||
$(MAKE) -C backend/python/qwen-tts
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -324,6 +326,8 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/vllm test
|
||||
$(MAKE) -C backend/python/vibevoice test
|
||||
$(MAKE) -C backend/python/moonshine test
|
||||
$(MAKE) -C backend/python/pocket-tts test
|
||||
$(MAKE) -C backend/python/qwen-tts test
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
DOCKER_AIO_IMAGE?=local-ai-aio
|
||||
@@ -447,17 +451,17 @@ BACKEND_FASTER_WHISPER = faster-whisper|python|.|false|true
|
||||
BACKEND_COQUI = coqui|python|.|false|true
|
||||
BACKEND_BARK = bark|python|.|false|true
|
||||
BACKEND_EXLLAMA2 = exllama2|python|.|false|true
|
||||
|
||||
# Python backends with ./backend context
|
||||
BACKEND_RFDETR = rfdetr|python|./backend|false|true
|
||||
BACKEND_KITTEN_TTS = kitten-tts|python|./backend|false|true
|
||||
BACKEND_NEUTTS = neutts|python|./backend|false|true
|
||||
BACKEND_KOKORO = kokoro|python|./backend|false|true
|
||||
BACKEND_VLLM = vllm|python|./backend|false|true
|
||||
BACKEND_DIFFUSERS = diffusers|python|./backend|--progress=plain|true
|
||||
BACKEND_CHATTERBOX = chatterbox|python|./backend|false|true
|
||||
BACKEND_VIBEVOICE = vibevoice|python|./backend|--progress=plain|true
|
||||
BACKEND_MOONSHINE = moonshine|python|./backend|false|true
|
||||
BACKEND_RFDETR = rfdetr|python|.|false|true
|
||||
BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true
|
||||
BACKEND_NEUTTS = neutts|python|.|false|true
|
||||
BACKEND_KOKORO = kokoro|python|.|false|true
|
||||
BACKEND_VLLM = vllm|python|.|false|true
|
||||
BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true
|
||||
BACKEND_CHATTERBOX = chatterbox|python|.|false|true
|
||||
BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
|
||||
BACKEND_MOONSHINE = moonshine|python|.|false|true
|
||||
BACKEND_POCKET_TTS = pocket-tts|python|.|false|true
|
||||
BACKEND_QWEN_TTS = qwen-tts|python|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -503,12 +507,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-exllama2 docker-build-moonshine
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-exllama2 docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts
|
||||
|
||||
########################################################
|
||||
### END Backends
|
||||
|
||||
22
README.md
22
README.md
@@ -111,6 +111,8 @@
|
||||
|
||||
## 💻 Quickstart
|
||||
|
||||
> ⚠️ **Note:** The `install.sh` script is currently experiencing issues due to the heavy changes currently undergoing in LocalAI and may produce broken or misconfigured installations. Please use Docker installation (see below) or manual binary installation until [issue #8032](https://github.com/mudler/LocalAI/issues/8032) is resolved.
|
||||
|
||||
Run the installer script:
|
||||
|
||||
```bash
|
||||
@@ -128,7 +130,7 @@ For more installation options, see [Installer Options](https://localai.io/instal
|
||||
|
||||
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
|
||||
|
||||
Or run with docker:
|
||||
### Containers (Docker, podman, ...)
|
||||
|
||||
> **💡 Docker Run vs Docker Start**
|
||||
>
|
||||
@@ -137,13 +139,13 @@ Or run with docker:
|
||||
>
|
||||
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
||||
|
||||
### CPU only image:
|
||||
#### CPU only image:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
|
||||
```
|
||||
|
||||
### NVIDIA GPU Images:
|
||||
#### NVIDIA GPU Images:
|
||||
|
||||
```bash
|
||||
# CUDA 13.0
|
||||
@@ -160,25 +162,25 @@ docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nv
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
|
||||
```
|
||||
|
||||
### AMD GPU Images (ROCm):
|
||||
#### AMD GPU Images (ROCm):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
|
||||
```
|
||||
|
||||
### Intel GPU Images (oneAPI):
|
||||
#### Intel GPU Images (oneAPI):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
||||
```
|
||||
|
||||
### Vulkan GPU Images:
|
||||
#### Vulkan GPU Images:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
|
||||
```
|
||||
|
||||
### AIO Images (pre-downloaded models):
|
||||
#### AIO Images (pre-downloaded models):
|
||||
|
||||
```bash
|
||||
# CPU version
|
||||
@@ -295,6 +297,8 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
|
||||
| **silero-vad** | Voice Activity Detection | CPU |
|
||||
| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
|
||||
| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
|
||||
### Image & Video Generation
|
||||
| Backend | Description | Acceleration Support |
|
||||
@@ -316,8 +320,8 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
|
||||
|-------------------|-------------------|------------------|
|
||||
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts, vibevoice | AMD Graphics |
|
||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark, vibevoice | Intel Arc, Intel iGPUs |
|
||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts, vibevoice, pocket-tts, qwen-tts | AMD Graphics |
|
||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark, vibevoice, pocket-tts, qwen-tts | Intel Arc, Intel iGPUs |
|
||||
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
|
||||
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
||||
| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (AGX Orin, etc.) |
|
||||
|
||||
@@ -47,22 +47,22 @@ RUN <<EOT bash
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.328.1/linux/vulkansdk-linux-x86_64-1.4.328.1.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.328.1 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.328.1 && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/share/* /usr/share/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
@@ -94,7 +94,11 @@ RUN <<EOT bash
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
@@ -106,7 +110,7 @@ RUN <<EOT bash
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
|
||||
@@ -104,22 +104,22 @@ RUN <<EOT bash
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.328.1/linux/vulkansdk-linux-x86_64-1.4.328.1.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.328.1 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.328.1 && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/share/* /usr/share/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
@@ -148,11 +148,14 @@ RUN <<EOT bash
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
echo https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
@@ -164,7 +167,7 @@ RUN <<EOT bash
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
|
||||
@@ -61,22 +61,22 @@ RUN <<EOT bash
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.328.1/linux/vulkansdk-linux-x86_64-1.4.328.1.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.328.1.tar.xz && \
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.328.1 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.328.1 && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.328.1/x86_64/share/* /usr/share/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
@@ -108,7 +108,11 @@ RUN <<EOT bash
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
@@ -120,7 +124,7 @@ RUN <<EOT bash
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
|
||||
@@ -32,6 +32,8 @@ service Backend {
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -410,3 +412,8 @@ message Detection {
|
||||
message DetectResponse {
|
||||
repeated Detection Detections = 1;
|
||||
}
|
||||
|
||||
message ModelMetadataResponse {
|
||||
bool supports_thinking = 1;
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=ae9f8df77882716b1702df2bed8919499e64cc28
|
||||
LLAMA_VERSION?=a5eaa1d6a3732bc0f460b02b61c95680bba5a012
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <regex>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <signal.h>
|
||||
#include <thread>
|
||||
|
||||
@@ -82,8 +83,8 @@ static void start_llama_server(server_context& ctx_server) {
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
// LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||
// common_chat_templates_source(ctx_server.impl->chat_templates.get()),
|
||||
// common_chat_format_example(ctx_server.impl->chat_templates.get(), ctx_server.impl->params_base.use_jinja).c_str(), ctx_server.impl->params_base.default_template_kwargs);
|
||||
// common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get()),
|
||||
// common_chat_format_example(ctx_server.impl->chat_params.tmpls.get(), ctx_server.impl->params_base.use_jinja).c_str(), ctx_server.impl->params_base.default_template_kwargs);
|
||||
|
||||
// Keep the chat templates initialized in load_model() so they can be used when UseTokenizerTemplate is enabled
|
||||
// Templates will only be used conditionally in Predict/PredictStream when UseTokenizerTemplate is true and Messages are provided
|
||||
@@ -390,8 +391,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// Initialize fit_params options (can be overridden by options)
|
||||
// fit_params: whether to auto-adjust params to fit device memory (default: true as in llama.cpp)
|
||||
params.fit_params = true;
|
||||
// fit_params_target: target margin per device in bytes (default: 1GB)
|
||||
params.fit_params_target = 1024 * 1024 * 1024;
|
||||
// fit_params_target: target margin per device in bytes (default: 1GB per device)
|
||||
// Initialize as vector with default value for all devices
|
||||
params.fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024 * 1024);
|
||||
// fit_params_min_ctx: minimum context size for fit (default: 4096)
|
||||
params.fit_params_min_ctx = 4096;
|
||||
|
||||
@@ -468,10 +470,28 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
} else if (!strcmp(optname, "fit_params_target") || !strcmp(optname, "fit_target")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
// Value is in MiB, convert to bytes
|
||||
params.fit_params_target = static_cast<size_t>(std::stoi(optval_str)) * 1024 * 1024;
|
||||
// Value is in MiB, can be comma-separated list for multiple devices
|
||||
// Single value is broadcast across all devices
|
||||
std::string arg_next = optval_str;
|
||||
const std::regex regex{ R"([,/]+)" };
|
||||
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
|
||||
std::vector<std::string> split_arg{ it, {} };
|
||||
if (split_arg.size() >= llama_max_devices()) {
|
||||
// Too many values provided
|
||||
continue;
|
||||
}
|
||||
if (split_arg.size() == 1) {
|
||||
// Single value: broadcast to all devices
|
||||
size_t value_mib = std::stoul(split_arg[0]);
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), value_mib * 1024 * 1024);
|
||||
} else {
|
||||
// Multiple values: set per device
|
||||
for (size_t i = 0; i < split_arg.size() && i < params.fit_params_target.size(); i++) {
|
||||
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024 * 1024;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (1GB)
|
||||
// If conversion fails, keep default value (1GB per device)
|
||||
}
|
||||
}
|
||||
} else if (!strcmp(optname, "fit_params_min_ctx") || !strcmp(optname, "fit_ctx")) {
|
||||
@@ -686,13 +706,13 @@ private:
|
||||
public:
|
||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) {
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
// Implement Health RPC
|
||||
reply->set_message("OK");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) {
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
// Implement LoadModel RPC
|
||||
common_params params;
|
||||
params_parse(ctx_server, request, params);
|
||||
@@ -709,11 +729,72 @@ public:
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
LOG_INF("\n");
|
||||
|
||||
// Capture error messages during model loading
|
||||
struct error_capture {
|
||||
std::string captured_error;
|
||||
std::mutex error_mutex;
|
||||
ggml_log_callback original_callback;
|
||||
void* original_user_data;
|
||||
} error_capture_data;
|
||||
|
||||
// Get original log callback
|
||||
llama_log_get(&error_capture_data.original_callback, &error_capture_data.original_user_data);
|
||||
|
||||
// Set custom callback to capture errors
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * user_data) {
|
||||
auto* capture = static_cast<error_capture*>(user_data);
|
||||
|
||||
// Capture error messages
|
||||
if (level == GGML_LOG_LEVEL_ERROR) {
|
||||
std::lock_guard<std::mutex> lock(capture->error_mutex);
|
||||
// Append error message, removing trailing newlines
|
||||
std::string msg(text);
|
||||
while (!msg.empty() && (msg.back() == '\n' || msg.back() == '\r')) {
|
||||
msg.pop_back();
|
||||
}
|
||||
if (!msg.empty()) {
|
||||
if (!capture->captured_error.empty()) {
|
||||
capture->captured_error.append("; ");
|
||||
}
|
||||
capture->captured_error.append(msg);
|
||||
}
|
||||
}
|
||||
|
||||
// Also call original callback to preserve logging
|
||||
if (capture->original_callback) {
|
||||
capture->original_callback(level, text, capture->original_user_data);
|
||||
}
|
||||
}, &error_capture_data);
|
||||
|
||||
// load the model
|
||||
if (!ctx_server.load_model(params)) {
|
||||
result->set_message("Failed loading model");
|
||||
bool load_success = ctx_server.load_model(params);
|
||||
|
||||
// Restore original log callback
|
||||
llama_log_set(error_capture_data.original_callback, error_capture_data.original_user_data);
|
||||
|
||||
if (!load_success) {
|
||||
std::string error_msg = "Failed to load model: " + params.model.path;
|
||||
if (!params.mmproj.path.empty()) {
|
||||
error_msg += " (with mmproj: " + params.mmproj.path + ")";
|
||||
}
|
||||
if (params.has_speculative() && !params.speculative.model.path.empty()) {
|
||||
error_msg += " (with draft model: " + params.speculative.model.path + ")";
|
||||
}
|
||||
|
||||
// Add captured error details if available
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(error_capture_data.error_mutex);
|
||||
if (!error_capture_data.captured_error.empty()) {
|
||||
error_msg += ". Error: " + error_capture_data.captured_error;
|
||||
} else {
|
||||
error_msg += ". Model file may not exist or be invalid.";
|
||||
}
|
||||
}
|
||||
|
||||
result->set_message(error_msg);
|
||||
result->set_success(false);
|
||||
return Status::CANCELLED;
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
|
||||
}
|
||||
|
||||
// Process grammar triggers now that vocab is available
|
||||
@@ -801,7 +882,7 @@ public:
|
||||
std::string prompt_str;
|
||||
std::vector<raw_buffer> files; // Declare files early so it's accessible in both branches
|
||||
// Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided
|
||||
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) {
|
||||
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_params.tmpls != nullptr) {
|
||||
// Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
|
||||
json body_json;
|
||||
json messages_json = json::array();
|
||||
@@ -1180,12 +1261,7 @@ public:
|
||||
// Use the same approach as server.cpp: call oaicompat_chat_params_parse
|
||||
// This handles all template application, grammar merging, etc. automatically
|
||||
// Files extracted from multimodal content in messages will be added to the files vector
|
||||
// Create parser options with current chat_templates to ensure tmpls is not null
|
||||
oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt;
|
||||
parser_opt.tmpls = ctx_server.impl->chat_templates.get(); // Ensure tmpls is set to current chat_templates
|
||||
// Update allow_image and allow_audio based on current mctx state
|
||||
parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false;
|
||||
parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false;
|
||||
// chat_params already contains tmpls, allow_image, and allow_audio set during model loading
|
||||
|
||||
// Debug: Log tools before template processing
|
||||
if (body_json.contains("tools")) {
|
||||
@@ -1231,7 +1307,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
|
||||
json parsed_data = oaicompat_chat_params_parse(body_json, ctx_server.impl->chat_params, files);
|
||||
|
||||
// Debug: Log tools after template processing
|
||||
if (parsed_data.contains("tools")) {
|
||||
@@ -1284,7 +1360,7 @@ public:
|
||||
|
||||
// If not using chat templates, extract files from image_data/audio_data fields
|
||||
// (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
|
||||
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) {
|
||||
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_params.tmpls == nullptr) {
|
||||
const auto &images_data = data.find("image_data");
|
||||
if (images_data != data.end() && images_data->is_array())
|
||||
{
|
||||
@@ -1492,7 +1568,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -1512,7 +1588,7 @@ public:
|
||||
std::string prompt_str;
|
||||
std::vector<raw_buffer> files; // Declare files early so it's accessible in both branches
|
||||
// Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided
|
||||
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) {
|
||||
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_params.tmpls != nullptr) {
|
||||
// Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
|
||||
json body_json;
|
||||
json messages_json = json::array();
|
||||
@@ -1916,12 +1992,7 @@ public:
|
||||
// Use the same approach as server.cpp: call oaicompat_chat_params_parse
|
||||
// This handles all template application, grammar merging, etc. automatically
|
||||
// Files extracted from multimodal content in messages will be added to the files vector
|
||||
// Create parser options with current chat_templates to ensure tmpls is not null
|
||||
oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt;
|
||||
parser_opt.tmpls = ctx_server.impl->chat_templates.get(); // Ensure tmpls is set to current chat_templates
|
||||
// Update allow_image and allow_audio based on current mctx state
|
||||
parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false;
|
||||
parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false;
|
||||
// chat_params already contains tmpls, allow_image, and allow_audio set during model loading
|
||||
|
||||
// Debug: Log tools before template processing
|
||||
if (body_json.contains("tools")) {
|
||||
@@ -1967,7 +2038,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
|
||||
json parsed_data = oaicompat_chat_params_parse(body_json, ctx_server.impl->chat_params, files);
|
||||
|
||||
// Debug: Log tools after template processing
|
||||
if (parsed_data.contains("tools")) {
|
||||
@@ -2020,7 +2091,7 @@ public:
|
||||
|
||||
// If not using chat templates, extract files from image_data/audio_data fields
|
||||
// (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
|
||||
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) {
|
||||
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_params.tmpls == nullptr) {
|
||||
const auto &images_data = data.find("image_data");
|
||||
if (images_data != data.end() && images_data->is_array())
|
||||
{
|
||||
@@ -2163,7 +2234,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2258,7 +2329,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
|
||||
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) override {
|
||||
if (!params_base.embedding || params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
|
||||
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
|
||||
}
|
||||
@@ -2344,7 +2415,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2367,7 +2438,7 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) {
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
@@ -2405,6 +2476,47 @@ public:
|
||||
response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total);
|
||||
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status ModelMetadata(ServerContext* /*context*/, const backend::ModelOptions* /*request*/, backend::ModelMetadataResponse* response) override {
|
||||
// Check if model is loaded
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
|
||||
// Check if chat templates are initialized
|
||||
if (ctx_server.impl->chat_params.tmpls == nullptr) {
|
||||
// If templates are not initialized, we can't detect thinking support
|
||||
// Return false as default
|
||||
response->set_supports_thinking(false);
|
||||
response->set_rendered_template("");
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Detect thinking support using llama.cpp's function
|
||||
bool supports_thinking = common_chat_templates_support_enable_thinking(ctx_server.impl->chat_params.tmpls.get());
|
||||
response->set_supports_thinking(supports_thinking);
|
||||
|
||||
// Render the template with enable_thinking=true so Go code can detect thinking tokens
|
||||
// This allows reusing existing detection functions in Go
|
||||
std::string rendered_template = "";
|
||||
if (params_base.use_jinja) {
|
||||
// Render the template with enable_thinking=true to see what the actual prompt looks like
|
||||
common_chat_templates_inputs dummy_inputs;
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = "test";
|
||||
dummy_inputs.messages = {msg};
|
||||
dummy_inputs.enable_thinking = true;
|
||||
dummy_inputs.use_jinja = params_base.use_jinja;
|
||||
|
||||
const auto rendered = common_chat_templates_apply(ctx_server.impl->chat_params.tmpls.get(), dummy_inputs);
|
||||
rendered_template = rendered.prompt;
|
||||
}
|
||||
|
||||
response->set_rendered_template(rendered_template);
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=9be0b91927dfa4007d053df72dea7302990226bb
|
||||
STABLEDIFFUSION_GGML_VERSION?=5e4579c11d0678f9765463582d024e58270faa9c
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=679bdb53dbcbfb3e42685f50c7ff367949fd4d48
|
||||
WHISPER_CPP_VERSION?=7aa8818647303b567c3a21fe4220b2681988e220
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -428,6 +428,50 @@
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice"
|
||||
icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
|
||||
- &qwen-tts
|
||||
urls:
|
||||
- https://github.com/QwenLM/Qwen3-TTS
|
||||
description: |
|
||||
Qwen3-TTS is a high-quality text-to-speech model supporting custom voice, voice design, and voice cloning.
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
license: apache-2.0
|
||||
name: "qwen-tts"
|
||||
alias: "qwen-tts"
|
||||
capabilities:
|
||||
nvidia: "cuda12-qwen-tts"
|
||||
intel: "intel-qwen-tts"
|
||||
amd: "rocm-qwen-tts"
|
||||
nvidia-l4t: "nvidia-l4t-qwen-tts"
|
||||
default: "cpu-qwen-tts"
|
||||
nvidia-cuda-13: "cuda13-qwen-tts"
|
||||
nvidia-cuda-12: "cuda12-qwen-tts"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts"
|
||||
icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
|
||||
- &pocket-tts
|
||||
urls:
|
||||
- https://github.com/kyutai-labs/pocket-tts
|
||||
description: |
|
||||
Pocket TTS is a lightweight text-to-speech model designed to run efficiently on CPUs.
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
license: mit
|
||||
name: "pocket-tts"
|
||||
alias: "pocket-tts"
|
||||
capabilities:
|
||||
nvidia: "cuda12-pocket-tts"
|
||||
intel: "intel-pocket-tts"
|
||||
amd: "rocm-pocket-tts"
|
||||
nvidia-l4t: "nvidia-l4t-pocket-tts"
|
||||
default: "cpu-pocket-tts"
|
||||
nvidia-cuda-13: "cuda13-pocket-tts"
|
||||
nvidia-cuda-12: "cuda12-pocket-tts"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts"
|
||||
icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
|
||||
- &piper
|
||||
name: "piper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
|
||||
@@ -515,18 +559,14 @@
|
||||
default: "cpu-neutts"
|
||||
nvidia: "cuda12-neutts"
|
||||
amd: "rocm-neutts"
|
||||
nvidia-l4t: "nvidia-l4t-neutts"
|
||||
nvidia-cuda-12: "cuda12-neutts"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-neutts"
|
||||
- !!merge <<: *neutts
|
||||
name: "neutts-development"
|
||||
capabilities:
|
||||
default: "cpu-neutts-development"
|
||||
nvidia: "cuda12-neutts-development"
|
||||
amd: "rocm-neutts-development"
|
||||
nvidia-l4t: "nvidia-l4t-neutts-development"
|
||||
nvidia-cuda-12: "cuda12-neutts-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-neutts-development"
|
||||
- !!merge <<: *llamacpp
|
||||
name: "llama-cpp-development"
|
||||
capabilities:
|
||||
@@ -556,11 +596,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-neutts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-neutts
|
||||
- !!merge <<: *neutts
|
||||
name: "nvidia-l4t-arm64-neutts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-neutts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-neutts
|
||||
- !!merge <<: *neutts
|
||||
name: "cpu-neutts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-neutts"
|
||||
@@ -576,11 +611,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-neutts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-neutts
|
||||
- !!merge <<: *neutts
|
||||
name: "nvidia-l4t-arm64-neutts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-neutts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-neutts
|
||||
- !!merge <<: *mlx
|
||||
name: "mlx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
|
||||
@@ -1605,3 +1635,169 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice
|
||||
## qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "qwen-tts-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-qwen-tts-development"
|
||||
intel: "intel-qwen-tts-development"
|
||||
amd: "rocm-qwen-tts-development"
|
||||
nvidia-l4t: "nvidia-l4t-qwen-tts-development"
|
||||
default: "cpu-qwen-tts-development"
|
||||
nvidia-cuda-13: "cuda13-qwen-tts-development"
|
||||
nvidia-cuda-12: "cuda12-qwen-tts-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts-development"
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cpu-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cpu-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda12-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda12-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda13-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda13-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "intel-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "intel-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "rocm-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "rocm-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "nvidia-l4t-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "nvidia-l4t-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-qwen-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-qwen-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts
|
||||
## pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "pocket-tts-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-pocket-tts-development"
|
||||
intel: "intel-pocket-tts-development"
|
||||
amd: "rocm-pocket-tts-development"
|
||||
nvidia-l4t: "nvidia-l4t-pocket-tts-development"
|
||||
default: "cpu-pocket-tts-development"
|
||||
nvidia-cuda-13: "cuda13-pocket-tts-development"
|
||||
nvidia-cuda-12: "cuda12-pocket-tts-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts-development"
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cpu-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cpu-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda12-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda12-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda13-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda13-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "intel-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "intel-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "rocm-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "rocm-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "nvidia-l4t-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "nvidia-l4t-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-pocket-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts
|
||||
- !!merge <<: *pocket-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-pocket-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.8.10+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchaudio==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers
|
||||
|
||||
@@ -15,14 +15,11 @@ fi
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
# This is here because the jetson-ai-lab.io PyPI mirror's root PyPI endpoint (pypi.jetson-ai-lab.io/root/pypi/)
|
||||
# returns 503 errors when uv tries to fall back to it for packages not found in the specific subdirectory.
|
||||
# We need uv to continue falling through to the official PyPI index when it encounters these errors.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
|
||||
installRequirements
|
||||
|
||||
5
backend/python/chatterbox/requirements-install.txt
Normal file
5
backend/python/chatterbox/requirements-install.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
# Build dependencies needed for packages installed from source (e.g., git dependencies)
|
||||
# When using --no-build-isolation, these must be installed in the venv first
|
||||
wheel
|
||||
setuptools
|
||||
packaging
|
||||
@@ -1,7 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchaudio==2.3.1+cxx11.abi
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchaudio
|
||||
transformers
|
||||
numpy>=1.24.0,<1.26.0
|
||||
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||
|
||||
@@ -398,7 +398,7 @@ function runProtogen() {
|
||||
# NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index.
|
||||
# you may want to add the following line to a requirements-intel.txt if you use one:
|
||||
#
|
||||
# --index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
# --index-url https://download.pytorch.org/whl/xpu
|
||||
#
|
||||
# If you need to add extra flags into the pip install command you can do so by setting the variable EXTRA_PIP_INSTALL_FLAGS
|
||||
# before calling installRequirements. For example:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.8.10+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch==2.8.0
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchaudio==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch==2.8.0+xpu
|
||||
torchaudio==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
|
||||
@@ -41,6 +41,10 @@ from optimum.quanto import freeze, qfloat8, quantize
|
||||
from transformers import T5EncoderModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Import LTX-2 specific utilities
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
|
||||
from diffusers import LTX2VideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL = os.environ.get("COMPEL", "0") == "1"
|
||||
XPU = os.environ.get("XPU", "0") == "1"
|
||||
@@ -290,6 +294,104 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
pipe.enable_model_cpu_offload()
|
||||
return pipe
|
||||
|
||||
# LTX2ImageToVideoPipeline - needs img2vid flag, CPU offload, and special handling
|
||||
if pipeline_type == "LTX2ImageToVideoPipeline":
|
||||
self.img2vid = True
|
||||
self.ltx2_pipeline = True
|
||||
|
||||
# Check if loading from single file (GGUF)
|
||||
if fromSingleFile and LTX2VideoTransformer3DModel is not None:
|
||||
_, single_file_ext = os.path.splitext(modelFile)
|
||||
if single_file_ext == ".gguf":
|
||||
# Load transformer from single GGUF file with quantization
|
||||
transformer_kwargs = {}
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=torchType)
|
||||
transformer_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
transformer = LTX2VideoTransformer3DModel.from_single_file(
|
||||
modelFile,
|
||||
config=request.Model, # Use request.Model as the config/model_id
|
||||
subfolder="transformer",
|
||||
**transformer_kwargs,
|
||||
)
|
||||
|
||||
# Load pipeline with custom transformer
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2ImageToVideoPipeline",
|
||||
model_id=request.Model,
|
||||
transformer=transformer,
|
||||
torch_dtype=torchType,
|
||||
)
|
||||
else:
|
||||
# Single file but not GGUF - use standard single file loading
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2ImageToVideoPipeline",
|
||||
model_id=modelFile,
|
||||
from_single_file=True,
|
||||
torch_dtype=torchType,
|
||||
)
|
||||
else:
|
||||
# Standard loading from pretrained
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2ImageToVideoPipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
pipe.enable_model_cpu_offload()
|
||||
return pipe
|
||||
|
||||
# LTX2Pipeline - text-to-video pipeline, needs txt2vid flag, CPU offload, and special handling
|
||||
if pipeline_type == "LTX2Pipeline":
|
||||
self.txt2vid = True
|
||||
self.ltx2_pipeline = True
|
||||
|
||||
# Check if loading from single file (GGUF)
|
||||
if fromSingleFile and LTX2VideoTransformer3DModel is not None:
|
||||
_, single_file_ext = os.path.splitext(modelFile)
|
||||
if single_file_ext == ".gguf":
|
||||
# Load transformer from single GGUF file with quantization
|
||||
transformer_kwargs = {}
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=torchType)
|
||||
transformer_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
transformer = LTX2VideoTransformer3DModel.from_single_file(
|
||||
modelFile,
|
||||
config=request.Model, # Use request.Model as the config/model_id
|
||||
subfolder="transformer",
|
||||
**transformer_kwargs,
|
||||
)
|
||||
|
||||
# Load pipeline with custom transformer
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2Pipeline",
|
||||
model_id=request.Model,
|
||||
transformer=transformer,
|
||||
torch_dtype=torchType,
|
||||
)
|
||||
else:
|
||||
# Single file but not GGUF - use standard single file loading
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2Pipeline",
|
||||
model_id=modelFile,
|
||||
from_single_file=True,
|
||||
torch_dtype=torchType,
|
||||
)
|
||||
else:
|
||||
# Standard loading from pretrained
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="LTX2Pipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
pipe.enable_model_cpu_offload()
|
||||
return pipe
|
||||
|
||||
# ================================================================
|
||||
# Dynamic pipeline loading - the default path for most pipelines
|
||||
# Uses the dynamic loader to instantiate any pipeline by class name
|
||||
@@ -404,6 +506,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
self.img2vid = False
|
||||
self.txt2vid = False
|
||||
self.ltx2_pipeline = False
|
||||
|
||||
print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
|
||||
|
||||
# Load pipeline using dynamic loader
|
||||
# Special cases that require custom initialization are handled first
|
||||
@@ -414,6 +519,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
torchType=torchType,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
|
||||
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
self.clip_skip = request.CLIPSkip
|
||||
@@ -651,14 +758,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
prompt = request.prompt
|
||||
if not prompt:
|
||||
print(f"GenerateVideo: No prompt provided for video generation.", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message="No prompt provided for video generation")
|
||||
|
||||
# Debug: Print raw request values
|
||||
print(f"GenerateVideo: Raw request values - num_frames: {request.num_frames}, fps: {request.fps}, cfg_scale: {request.cfg_scale}, step: {request.step}", file=sys.stderr)
|
||||
|
||||
# Set default values from request or use defaults
|
||||
num_frames = request.num_frames if request.num_frames > 0 else 81
|
||||
fps = request.fps if request.fps > 0 else 16
|
||||
cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
|
||||
num_inference_steps = request.step if request.step > 0 else 40
|
||||
|
||||
print(f"GenerateVideo: Using values - num_frames: {num_frames}, fps: {fps}, cfg_scale: {cfg_scale}, num_inference_steps: {num_inference_steps}", file=sys.stderr)
|
||||
|
||||
# Prepare generation parameters
|
||||
kwargs = {
|
||||
"prompt": prompt,
|
||||
@@ -684,9 +797,86 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
kwargs["end_image"] = load_image(request.end_image)
|
||||
|
||||
print(f"Generating video with {kwargs=}", file=sys.stderr)
|
||||
print(f"GenerateVideo: Pipeline type: {self.PipelineType}, ltx2_pipeline flag: {self.ltx2_pipeline}", file=sys.stderr)
|
||||
|
||||
# Generate video frames based on pipeline type
|
||||
if self.PipelineType == "WanPipeline":
|
||||
if self.ltx2_pipeline or self.PipelineType in ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]:
|
||||
# LTX-2 generation with audio (supports both text-to-video and image-to-video)
|
||||
# Determine if this is text-to-video (no image) or image-to-video (has image)
|
||||
has_image = bool(request.start_image)
|
||||
|
||||
# Remove image-related parameters that might have been added earlier
|
||||
kwargs.pop("start_image", None)
|
||||
kwargs.pop("end_image", None)
|
||||
|
||||
# LTX2ImageToVideoPipeline uses 'image' parameter for image-to-video
|
||||
# LTX2Pipeline (text-to-video) doesn't need an image parameter
|
||||
if has_image:
|
||||
# Image-to-video: use 'image' parameter
|
||||
if self.PipelineType == "LTX2ImageToVideoPipeline":
|
||||
image = load_image(request.start_image)
|
||||
kwargs["image"] = image
|
||||
print(f"LTX-2: Using image-to-video mode with image", file=sys.stderr)
|
||||
else:
|
||||
# If pipeline type is LTX2Pipeline but we have an image, we can't do image-to-video
|
||||
return backend_pb2.Result(success=False, message="LTX2Pipeline does not support image-to-video. Use LTX2ImageToVideoPipeline for image-to-video generation.")
|
||||
else:
|
||||
# Text-to-video: no image parameter needed
|
||||
# Ensure no image-related kwargs are present
|
||||
kwargs.pop("image", None)
|
||||
print(f"LTX-2: Using text-to-video mode (no image)", file=sys.stderr)
|
||||
|
||||
# LTX-2 uses 'frame_rate' instead of 'fps'
|
||||
frame_rate = float(fps)
|
||||
kwargs["frame_rate"] = frame_rate
|
||||
|
||||
# LTX-2 requires output_type="np" and return_dict=False
|
||||
kwargs["output_type"] = "np"
|
||||
kwargs["return_dict"] = False
|
||||
|
||||
# Generate video and audio
|
||||
print(f"LTX-2: Generating with kwargs: {kwargs}", file=sys.stderr)
|
||||
try:
|
||||
video, audio = self.pipe(**kwargs)
|
||||
print(f"LTX-2: Generated video shape: {video.shape}, audio shape: {audio.shape}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"LTX-2: Error during pipe() call: {e}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"Error generating video with LTX-2 pipeline: {e}")
|
||||
|
||||
# Convert video to uint8 format
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
print(f"LTX-2: Converting video, shape after conversion: {video.shape}", file=sys.stderr)
|
||||
print(f"LTX-2: Audio sample rate: {self.pipe.vocoder.config.output_sampling_rate}", file=sys.stderr)
|
||||
print(f"LTX-2: Output path: {request.dst}", file=sys.stderr)
|
||||
|
||||
# Use LTX-2's encode_video function which handles audio
|
||||
try:
|
||||
ltx2_encode_video(
|
||||
video[0],
|
||||
fps=frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate,
|
||||
output_path=request.dst,
|
||||
)
|
||||
# Verify file was created and has content
|
||||
import os
|
||||
if os.path.exists(request.dst):
|
||||
file_size = os.path.getsize(request.dst)
|
||||
print(f"LTX-2: Video file created successfully, size: {file_size} bytes", file=sys.stderr)
|
||||
if file_size == 0:
|
||||
return backend_pb2.Result(success=False, message=f"Video file was created but is empty (0 bytes). Check LTX-2 encode_video function.")
|
||||
else:
|
||||
return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}")
|
||||
except Exception as e:
|
||||
print(f"LTX-2: Error encoding video: {e}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"Error encoding video: {e}")
|
||||
|
||||
return backend_pb2.Result(message="Video generated successfully", success=True)
|
||||
elif self.PipelineType == "WanPipeline":
|
||||
# WAN2.2 text-to-video generation
|
||||
output = self.pipe(**kwargs)
|
||||
frames = output.frames[0] # WAN2.2 returns frames in this format
|
||||
@@ -725,11 +915,23 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
output = self.pipe(**kwargs)
|
||||
frames = output.frames[0]
|
||||
else:
|
||||
print(f"GenerateVideo: Pipeline {self.PipelineType} does not match any known video pipeline handler", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
|
||||
|
||||
# Export video
|
||||
# Export video (for non-LTX-2 pipelines)
|
||||
print(f"GenerateVideo: Exporting video to {request.dst} with fps={fps}", file=sys.stderr)
|
||||
export_to_video(frames, request.dst, fps=fps)
|
||||
|
||||
# Verify file was created
|
||||
import os
|
||||
if os.path.exists(request.dst):
|
||||
file_size = os.path.getsize(request.dst)
|
||||
print(f"GenerateVideo: Video file created, size: {file_size} bytes", file=sys.stderr)
|
||||
if file_size == 0:
|
||||
return backend_pb2.Result(success=False, message=f"Video file was created but is empty (0 bytes)")
|
||||
else:
|
||||
return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}")
|
||||
|
||||
return backend_pb2.Result(message="Video generated successfully", success=True)
|
||||
|
||||
except Exception as err:
|
||||
|
||||
@@ -16,6 +16,10 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
# Use python 3.12 for l4t
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
@@ -23,11 +27,4 @@ if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
# This is here because the jetson-ai-lab.io PyPI mirror's root PyPI endpoint (pypi.jetson-ai-lab.io/root/pypi/)
|
||||
# returns 503 errors when uv tries to fall back to it for packages not found in the specific subdirectory.
|
||||
# We need uv to continue falling through to the official PyPI index when it encounters these errors.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
torchvision==0.20.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchvision
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
git+https://github.com/huggingface/diffusers
|
||||
|
||||
@@ -3,3 +3,4 @@ grpcio==1.76.0
|
||||
pillow
|
||||
protobuf
|
||||
certifi
|
||||
av
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
optimum[openvino]
|
||||
faster-whisper
|
||||
@@ -16,11 +16,8 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
# This is here because the jetson-ai-lab.io PyPI mirror's root PyPI endpoint (pypi.jetson-ai-lab.io/root/pypi/)
|
||||
# returns 503 errors when uv tries to fall back to it for packages not found in the specific subdirectory.
|
||||
# We need uv to continue falling through to the official PyPI index when it encounters these errors.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-first-match"
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.8.10+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
torchaudio==2.5.1+cxx11.abi
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
|
||||
@@ -16,13 +16,6 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
# This is here because the jetson-ai-lab.io PyPI mirror's root PyPI endpoint (pypi.jetson-ai-lab.io/root/pypi/)
|
||||
# returns 503 errors when uv tries to fall back to it for packages not found in the specific subdirectory.
|
||||
# We need uv to continue falling through to the official PyPI index when it encounters these errors.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "xl4t" ]; then
|
||||
export CMAKE_ARGS="-DGGML_CUDA=on"
|
||||
fi
|
||||
@@ -33,6 +26,12 @@ fi
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
|
||||
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
|
||||
git clone https://github.com/neuphonic/neutts-air neutts-air
|
||||
|
||||
cp -rfv neutts-air/neuttsair ./
|
||||
|
||||
23
backend/python/pocket-tts/Makefile
Normal file
23
backend/python/pocket-tts/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
.PHONY: pocket-tts
|
||||
pocket-tts:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: pocket-tts
|
||||
@echo "Running pocket-tts..."
|
||||
bash run.sh
|
||||
@echo "pocket-tts run."
|
||||
|
||||
.PHONY: test
|
||||
test: pocket-tts
|
||||
@echo "Testing pocket-tts..."
|
||||
bash test.sh
|
||||
@echo "pocket-tts tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
255
backend/python/pocket-tts/backend.py
Normal file
255
backend/python/pocket-tts/backend.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Pocket TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import traceback
|
||||
import scipy.io.wavfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import torch
|
||||
from pocket_tts import TTSModel
|
||||
|
||||
import grpc
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def is_int(s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
# Get device
|
||||
if torch.cuda.is_available():
|
||||
print("CUDA is available", file=sys.stderr)
|
||||
device = "cuda"
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
if mps_available:
|
||||
device = "mps"
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||
|
||||
# Normalize potential 'mpx' typo to 'mps'
|
||||
if device == "mpx":
|
||||
print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
|
||||
device = "mps"
|
||||
|
||||
# Validate mps availability if requested
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
|
||||
device = "cpu"
|
||||
|
||||
self.device = device
|
||||
|
||||
options = request.Options
|
||||
|
||||
# empty dict
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when
|
||||
# generating the audio
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1) # Split only on first colon
|
||||
# if value is a number, convert it to the appropriate type
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
self.options[key] = value
|
||||
|
||||
# Default voice for caching
|
||||
self.default_voice_url = self.options.get("default_voice", None)
|
||||
self._voice_cache = {}
|
||||
|
||||
try:
|
||||
print("Loading Pocket TTS model", file=sys.stderr)
|
||||
self.tts_model = TTSModel.load_model()
|
||||
print(f"Model loaded successfully. Sample rate: {self.tts_model.sample_rate}", file=sys.stderr)
|
||||
|
||||
# Pre-load default voice if specified
|
||||
if self.default_voice_url:
|
||||
try:
|
||||
print(f"Pre-loading default voice: {self.default_voice_url}", file=sys.stderr)
|
||||
voice_state = self.tts_model.get_state_for_audio_prompt(self.default_voice_url)
|
||||
self._voice_cache[self.default_voice_url] = voice_state
|
||||
print("Default voice loaded successfully", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to pre-load default voice: {e}", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _get_voice_state(self, voice_input):
|
||||
"""
|
||||
Get voice state from cache or load it.
|
||||
voice_input can be:
|
||||
- HuggingFace URL (e.g., hf://kyutai/tts-voices/alba-mackenna/casual.wav)
|
||||
- Local file path
|
||||
- None (use default)
|
||||
"""
|
||||
# Use default if no voice specified
|
||||
if not voice_input:
|
||||
voice_input = self.default_voice_url
|
||||
|
||||
if not voice_input:
|
||||
return None
|
||||
|
||||
# Check cache first
|
||||
if voice_input in self._voice_cache:
|
||||
return self._voice_cache[voice_input]
|
||||
|
||||
# Load voice state
|
||||
try:
|
||||
print(f"Loading voice from: {voice_input}", file=sys.stderr)
|
||||
voice_state = self.tts_model.get_state_for_audio_prompt(voice_input)
|
||||
self._voice_cache[voice_input] = voice_state
|
||||
return voice_state
|
||||
except Exception as e:
|
||||
print(f"Error loading voice from {voice_input}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
# Determine voice input
|
||||
# Priority: request.voice > AudioPath (from ModelOptions) > default
|
||||
voice_input = None
|
||||
|
||||
if request.voice:
|
||||
voice_input = request.voice
|
||||
elif hasattr(request, 'AudioPath') and request.AudioPath:
|
||||
# Use AudioPath as voice file
|
||||
if os.path.isabs(request.AudioPath):
|
||||
voice_input = request.AudioPath
|
||||
elif hasattr(request, 'ModelFile') and request.ModelFile:
|
||||
model_file_base = os.path.dirname(request.ModelFile)
|
||||
voice_input = os.path.join(model_file_base, request.AudioPath)
|
||||
elif hasattr(request, 'ModelPath') and request.ModelPath:
|
||||
voice_input = os.path.join(request.ModelPath, request.AudioPath)
|
||||
else:
|
||||
voice_input = request.AudioPath
|
||||
|
||||
# Get voice state
|
||||
voice_state = self._get_voice_state(voice_input)
|
||||
if voice_state is None:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"Voice not found or failed to load: {voice_input}. Please provide a valid voice URL or file path."
|
||||
)
|
||||
|
||||
# Prepare text
|
||||
text = request.text.strip()
|
||||
|
||||
if not text:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="Text is empty"
|
||||
)
|
||||
|
||||
print(f"Generating audio for text: {text[:50]}...", file=sys.stderr)
|
||||
|
||||
# Generate audio
|
||||
audio = self.tts_model.generate_audio(voice_state, text)
|
||||
|
||||
# Audio is a 1D torch tensor containing PCM data
|
||||
if audio is None or audio.numel() == 0:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="No audio generated"
|
||||
)
|
||||
|
||||
# Save audio to file
|
||||
output_path = request.dst
|
||||
if not output_path:
|
||||
output_path = "/tmp/pocket-tts-output.wav"
|
||||
|
||||
# Ensure output directory exists
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Convert torch tensor to numpy and save
|
||||
audio_numpy = audio.numpy()
|
||||
scipy.io.wavfile.write(output_path, self.tts_model.sample_rate, audio_numpy)
|
||||
print(f"Saved audio to {output_path}", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in TTS: {err}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
30
backend/python/pocket-tts/install.sh
Executable file
30
backend/python/pocket-tts/install.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
# Use python 3.12 for l4t
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
11
backend/python/pocket-tts/protogen.sh
Executable file
11
backend/python/pocket-tts/protogen.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
|
||||
4
backend/python/pocket-tts/requirements-cpu.txt
Normal file
4
backend/python/pocket-tts/requirements-cpu.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-cublas12.txt
Normal file
4
backend/python/pocket-tts/requirements-cublas12.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-cublas13.txt
Normal file
4
backend/python/pocket-tts/requirements-cublas13.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-hipblas.txt
Normal file
4
backend/python/pocket-tts/requirements-hipblas.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
pocket-tts
|
||||
scipy
|
||||
torch==2.7.1+rocm6.3
|
||||
4
backend/python/pocket-tts/requirements-intel.txt
Normal file
4
backend/python/pocket-tts/requirements-intel.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-l4t12.txt
Normal file
4
backend/python/pocket-tts/requirements-l4t12.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-l4t13.txt
Normal file
4
backend/python/pocket-tts/requirements-l4t13.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
pocket-tts
|
||||
scipy
|
||||
torch
|
||||
4
backend/python/pocket-tts/requirements-mps.txt
Normal file
4
backend/python/pocket-tts/requirements-mps.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pocket-tts
|
||||
scipy
|
||||
torch==2.7.1
|
||||
torchvision==0.22.1
|
||||
4
backend/python/pocket-tts/requirements.txt
Normal file
4
backend/python/pocket-tts/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
9
backend/python/pocket-tts/run.sh
Executable file
9
backend/python/pocket-tts/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
# Launch the pocket-tts backend gRPC server, forwarding all CLI args.

# Quote the script path so directories containing spaces don't break
# the common-library lookup (unquoted $0 undergoes word splitting).
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    # Bundled layout: the helper library ships next to the backend.
    source "$backend_dir/common/libbackend.sh"
else
    # Source-tree layout: the helper library lives one level up.
    source "$backend_dir/../common/libbackend.sh"
fi

# "$@" (quoted) preserves argument boundaries; the original bare $@
# re-split any argument containing whitespace.
startBackend "$@"
|
||||
141
backend/python/pocket-tts/test.py
Normal file
141
backend/python/pocket-tts/test.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """
    End-to-end tests for the pocket-tts gRPC backend: the server is
    spawned as a subprocess and exercised through the generated stub.
    """

    def setUp(self):
        """Spawn the backend server and give it time to come up."""
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(30)

    def tearDown(self) -> None:
        """Terminate the backend server subprocess and reap it."""
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """The Health endpoint must answer b'OK' once the server is up."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                reply = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(reply.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """LoadModel must succeed and report the expected message."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                reply = stub.LoadModel(backend_pb2.ModelOptions())
                print(reply)
                self.assertTrue(reply.success)
                self.assertEqual(reply.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts_with_hf_voice(self):
        """TTS with an explicit HuggingFace voice must produce a non-empty wav."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                reply = stub.LoadModel(backend_pb2.ModelOptions())
                self.assertTrue(reply.success)

                # Reserve a temporary destination for the generated audio.
                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as handle:
                    wav_path = handle.name

                tts_reply = stub.TTS(backend_pb2.TTSRequest(
                    text="Hello world, this is a test.",
                    dst=wav_path,
                    voice="azelma"
                ))
                self.assertTrue(tts_reply.success)

                # The synthesized file must exist and contain data.
                self.assertTrue(os.path.exists(wav_path))
                self.assertGreater(os.path.getsize(wav_path), 0)

                os.unlink(wav_path)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()

    def test_tts_with_default_voice(self):
        """TTS without a voice must fall back to the default_voice option."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Configure the default voice via LoadModel options.
                reply = stub.LoadModel(backend_pb2.ModelOptions(
                    Options=["default_voice:azelma"]
                ))
                self.assertTrue(reply.success)

                # Reserve a temporary destination for the generated audio.
                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as handle:
                    wav_path = handle.name

                # No voice set: the backend should use the configured default.
                tts_reply = stub.TTS(backend_pb2.TTSRequest(
                    text="Hello world, this is a test.",
                    dst=wav_path
                ))
                self.assertTrue(tts_reply.success)

                # The synthesized file must exist and contain data.
                self.assertTrue(os.path.exists(wav_path))
                self.assertGreater(os.path.getsize(wav_path), 0)

                os.unlink(wav_path)
        except Exception as err:
            print(err)
            self.fail("TTS service with default voice failed")
        finally:
            self.tearDown()
|
||||
11
backend/python/pocket-tts/test.sh
Executable file
11
backend/python/pocket-tts/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Run the pocket-tts backend unit tests via the shared helper library.
set -e

# Quote the script path so directories containing spaces don't break
# the common-library lookup (unquoted $0 undergoes word splitting).
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    # Bundled layout: the helper library ships next to the backend.
    source "$backend_dir/common/libbackend.sh"
else
    # Source-tree layout: the helper library lives one level up.
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
23
backend/python/qwen-tts/Makefile
Normal file
23
backend/python/qwen-tts/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
# Build, run, and test targets for the qwen-tts backend.
.PHONY: qwen-tts run test protogen-clean clean

# Install the backend's virtualenv and requirements.
qwen-tts:
	bash install.sh

# Start the backend (installing first if needed).
run: qwen-tts
	@echo "Running qwen-tts..."
	bash run.sh
	@echo "qwen-tts run."

# Run the unit tests (installing first if needed).
test: qwen-tts
	@echo "Testing qwen-tts..."
	bash test.sh
	@echo "qwen-tts tested."

# Remove the generated gRPC bindings.
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

# Remove generated bindings, the virtualenv, and Python caches.
clean: protogen-clean
	rm -rf venv __pycache__
|
||||
475
backend/python/qwen-tts/backend.py
Normal file
475
backend/python/qwen-tts/backend.py
Normal file
@@ -0,0 +1,475 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Qwen3-TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
import grpc
|
||||
|
||||
def is_float(s):
    """Return True when the string parses as a float, False otherwise."""
    try:
        float(s)
    except ValueError:
        return False
    return True
|
||||
|
||||
def is_int(s):
    """Return True when the string parses as a base-10 int, False otherwise."""
    try:
        int(s)
    except ValueError:
        return False
    return True
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer implements the LocalAI gRPC backend contract for
    Qwen3-TTS.  Three generation modes are supported:

      - CustomVoice: synthesize with one of the model's built-in speakers
      - VoiceDesign: synthesize a voice described by an "instruct" prompt
      - VoiceClone:  clone the voice found in a reference audio file
    """

    def Health(self, request, context):
        # Liveness probe used by LocalAI to verify the backend is up.
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Select a device, parse LoadModel options, and load the model.

        Returns a backend_pb2.Result with success=True on a clean load.
        """
        # --- Device selection ------------------------------------------
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
            mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
            if mps_available:
                device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # Normalize potential 'mpx' typo to 'mps'
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
            device = "mps"

        # Validate mps availability if requested
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
            device = "cpu"

        self.device = device
        self._torch_device = torch.device(device)

        options = request.Options

        # The options are a list of "optname:optvalue" strings; store them
        # in a dict for later use when generating audio.
        self.options = {}
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)  # Split only on first colon
            # BUGFIX: test is_int() BEFORE is_float().  is_float("512") is
            # True, so the previous float-first ordering silently coerced
            # every integer option (e.g. max_new_tokens) to a float.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        # Get model path from request
        model_path = request.Model
        if not model_path:
            model_path = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

        # Determine model type from model path or options
        self.model_type = self.options.get("model_type", None)
        if not self.model_type:
            if "CustomVoice" in model_path:
                self.model_type = "CustomVoice"
            elif "VoiceDesign" in model_path:
                self.model_type = "VoiceDesign"
            elif "Base" in model_path or "0.6B" in model_path or "1.7B" in model_path:
                self.model_type = "Base"  # VoiceClone model
            else:
                # Default to CustomVoice
                self.model_type = "CustomVoice"

        # Cache for voice clone prompts (keyed by ref_audio:ref_text).
        self._voice_clone_cache = {}

        # Store AudioPath, ModelFile, and ModelPath from the LoadModel
        # request; these are used later in TTS for VoiceClone mode.
        self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None
        self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None
        self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None

        # Decide dtype & attention implementation per device.
        if self.device == "mps":
            load_dtype = torch.float32  # MPS requires float32
            device_map = None
            attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:  # cpu
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"

        print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}", file=sys.stderr)
        print(f"Loading model from: {model_path}", file=sys.stderr)

        # Common load parameters for all devices.
        load_kwargs = {
            "dtype": load_dtype,
            "attn_implementation": attn_impl_primary,
            "trust_remote_code": True,  # Required for qwen-tts models
        }

        try:
            if self.device == "mps":
                load_kwargs["device_map"] = None  # load then move
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
                self.model.to("mps")
            elif self.device == "cuda":
                load_kwargs["device_map"] = device_map
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
            else:  # cpu
                load_kwargs["device_map"] = device_map
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
        except Exception as e:
            error_msg = str(e)
            print(f"[ERROR] Loading model: {type(e).__name__}: {error_msg}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

            # Check if it's a missing feature extractor/tokenizer error,
            # which usually indicates a truncated/corrupted download.
            if "speech_tokenizer" in error_msg or "preprocessor_config.json" in error_msg or "feature extractor" in error_msg.lower():
                print("\n[ERROR] Model files appear to be incomplete. This usually means:", file=sys.stderr)
                print(" 1. The model download was interrupted or incomplete", file=sys.stderr)
                print(" 2. The model cache is corrupted", file=sys.stderr)
                print("\nTo fix this, try:", file=sys.stderr)
                print(f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*", file=sys.stderr)
                print(" Then re-run to trigger a fresh download.", file=sys.stderr)
                print("\nAlternatively, try using a different model variant:", file=sys.stderr)
                print(" - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr)
                print(" - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr)
                print(" - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr)

            # flash_attention_2 frequently fails to import/build; fall
            # back to SDPA once before giving up.
            if attn_impl_primary == 'flash_attention_2':
                print("\nTrying to use SDPA instead of flash_attention_2...", file=sys.stderr)
                load_kwargs["attn_implementation"] = 'sdpa'
                try:
                    if self.device == "mps":
                        load_kwargs["device_map"] = None
                        self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
                        self.model.to("mps")
                    else:
                        load_kwargs["device_map"] = (self.device if self.device in ("cuda", "cpu") else None)
                        self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
                except Exception as e2:
                    print(f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    raise e2
            else:
                raise e

        print(f"Model loaded successfully: {model_path}", file=sys.stderr)

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def _detect_mode(self, request):
        """Detect which mode to use based on request parameters."""
        # Priority: VoiceClone > VoiceDesign > CustomVoice

        # model_type explicitly set.  NOTE(review): LoadModel defaults
        # model_type to "CustomVoice", so the AudioPath/instruct fallbacks
        # below are only reachable when model_type is "Base" — confirm
        # this is intended.
        if self.model_type == "CustomVoice":
            return "CustomVoice"
        if self.model_type == "VoiceClone":
            return "VoiceClone"
        if self.model_type == "VoiceDesign":
            return "VoiceDesign"

        # VoiceClone: AudioPath is provided (from LoadModel, stored in self.audio_path)
        if self.audio_path:
            return "VoiceClone"

        # VoiceDesign: instruct option is provided
        if "instruct" in self.options and self.options["instruct"]:
            return "VoiceDesign"

        # Default to CustomVoice
        return "CustomVoice"

    def _get_ref_audio_path(self, request):
        """Resolve the reference-audio path stored from LoadModel.

        Tries, in order: absolute path, relative to ModelFile's directory,
        relative to ModelPath, and finally the raw value (may be a URL or
        base64 payload).  Returns None when no AudioPath was configured.
        """
        if not self.audio_path:
            return None

        # If absolute path, use as-is
        if os.path.isabs(self.audio_path):
            return self.audio_path

        # Try relative to ModelFile
        if self.model_file:
            model_file_base = os.path.dirname(self.model_file)
            ref_path = os.path.join(model_file_base, self.audio_path)
            if os.path.exists(ref_path):
                return ref_path

        # Try relative to ModelPath
        if self.model_path:
            ref_path = os.path.join(self.model_path, self.audio_path)
            if os.path.exists(ref_path):
                return ref_path

        # Return as-is (might be URL or base64)
        return self.audio_path

    def _get_voice_clone_prompt(self, request, ref_audio, ref_text):
        """Get or create a voice clone prompt, with caching.

        Returns None when prompt creation fails (error already logged).
        """
        cache_key = f"{ref_audio}:{ref_text}"

        if cache_key not in self._voice_clone_cache:
            print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr)
            try:
                prompt_items = self.model.create_voice_clone_prompt(
                    ref_audio=ref_audio,
                    ref_text=ref_text,
                    x_vector_only_mode=self.options.get("x_vector_only_mode", False),
                )
                self._voice_clone_cache[cache_key] = prompt_items
            except Exception as e:
                print(f"Error creating voice clone prompt: {e}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                return None

        return self._voice_clone_cache[cache_key]

    def TTS(self, request, context):
        """Synthesize request.text into a wav file at request.dst."""
        try:
            # Check if dst is provided
            if not request.dst:
                return backend_pb2.Result(
                    success=False,
                    message="dst (output path) is required"
                )

            # Prepare text
            text = request.text.strip()
            if not text:
                return backend_pb2.Result(
                    success=False,
                    message="Text is empty"
                )

            # Get language (auto-detect if not provided)
            language = request.language if hasattr(request, 'language') and request.language else None
            if not language or language == "":
                language = "Auto"  # Auto-detect language

            # Detect mode
            mode = self._detect_mode(request)
            print(f"Detected mode: {mode}", file=sys.stderr)

            # Get generation parameters from options
            max_new_tokens = self.options.get("max_new_tokens", None)
            top_p = self.options.get("top_p", None)
            temperature = self.options.get("temperature", None)
            do_sample = self.options.get("do_sample", None)

            # Prepare generation kwargs (only include explicitly-set values).
            generation_kwargs = {}
            if max_new_tokens is not None:
                generation_kwargs["max_new_tokens"] = max_new_tokens
            if top_p is not None:
                generation_kwargs["top_p"] = top_p
            if temperature is not None:
                generation_kwargs["temperature"] = temperature
            if do_sample is not None:
                generation_kwargs["do_sample"] = do_sample

            instruct = self.options.get("instruct", "")
            if instruct is not None and instruct != "":
                generation_kwargs["instruct"] = instruct

            # Generate audio based on mode
            if mode == "VoiceClone":
                # VoiceClone mode
                ref_audio = self._get_ref_audio_path(request)
                if not ref_audio:
                    return backend_pb2.Result(
                        success=False,
                        message="AudioPath is required for VoiceClone mode"
                    )

                ref_text = self.options.get("ref_text", None)
                if not ref_text:
                    # Try to get from request if available
                    if hasattr(request, 'ref_text') and request.ref_text:
                        ref_text = request.ref_text
                    else:
                        # x_vector_only_mode doesn't require ref_text
                        if not self.options.get("x_vector_only_mode", False):
                            return backend_pb2.Result(
                                success=False,
                                message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)"
                            )

                # Check if we should use cached prompt
                use_cached_prompt = self.options.get("use_cached_prompt", True)
                voice_clone_prompt = None

                if use_cached_prompt:
                    voice_clone_prompt = self._get_voice_clone_prompt(request, ref_audio, ref_text)
                    if voice_clone_prompt is None:
                        return backend_pb2.Result(
                            success=False,
                            message="Failed to create voice clone prompt"
                        )

                if voice_clone_prompt:
                    # Use cached prompt
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        voice_clone_prompt=voice_clone_prompt,
                        **generation_kwargs
                    )
                else:
                    # Create prompt on-the-fly
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        ref_audio=ref_audio,
                        ref_text=ref_text,
                        x_vector_only_mode=self.options.get("x_vector_only_mode", False),
                        **generation_kwargs
                    )

            elif mode == "VoiceDesign":
                # VoiceDesign mode
                if not instruct:
                    return backend_pb2.Result(
                        success=False,
                        message="instruct option is required for VoiceDesign mode"
                    )

                # BUGFIX: 'instruct' was also stored in generation_kwargs
                # above; passing it both explicitly and via ** raised
                # "got multiple values for keyword argument 'instruct'",
                # so VoiceDesign mode could never succeed.  Strip the
                # duplicate before expanding.
                design_kwargs = {k: v for k, v in generation_kwargs.items() if k != "instruct"}
                wavs, sr = self.model.generate_voice_design(
                    text=text,
                    language=language,
                    instruct=instruct,
                    **design_kwargs
                )

            else:
                # CustomVoice mode (default)
                speaker = request.voice if request.voice else None
                if not speaker:
                    # Try to get from options
                    speaker = self.options.get("speaker", None)
                    if not speaker:
                        # Use default speaker
                        speaker = "Vivian"
                        print(f"No speaker specified, using default: {speaker}", file=sys.stderr)

                # Validate speaker if model supports it
                if hasattr(self.model, 'get_supported_speakers'):
                    try:
                        supported_speakers = self.model.get_supported_speakers()
                        if speaker not in supported_speakers:
                            print(f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}", file=sys.stderr)
                            # Try to find a close match (case-insensitive)
                            speaker_lower = speaker.lower()
                            for sup_speaker in supported_speakers:
                                if sup_speaker.lower() == speaker_lower:
                                    speaker = sup_speaker
                                    print(f"Using matched speaker: {speaker}", file=sys.stderr)
                                    break
                    except Exception as e:
                        print(f"Warning: Could not get supported speakers: {e}", file=sys.stderr)

                wavs, sr = self.model.generate_custom_voice(
                    text=text,
                    language=language,
                    speaker=speaker,
                    **generation_kwargs
                )

            # Save output
            if wavs is not None and len(wavs) > 0:
                # wavs is a list, take first element
                audio_data = wavs[0] if isinstance(wavs, list) else wavs
                sf.write(request.dst, audio_data, sr)
                print(f"Saved output to {request.dst}", file=sys.stderr)
            else:
                return backend_pb2.Result(
                    success=False,
                    message="No audio output generated"
                )

        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
    """Start the gRPC server on `address` and block until terminated.

    Shutdown happens on SIGINT/SIGTERM (via the installed handlers) or
    on KeyboardInterrupt.
    """
    # Allow payloads of up to 50MB in either direction.
    megabyte = 1024 * 1024
    channel_options = [
        ('grpc.max_message_length', 50 * megabyte),
        ('grpc.max_send_message_length', 50 * megabyte),
        ('grpc.max_receive_message_length', 50 * megabyte),
    ]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=channel_options,
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Graceful shutdown on SIGINT/SIGTERM.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Park the main thread; the gRPC worker threads do the actual work.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the bind address and serve forever.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    parsed = cli.parse_args()

    serve(parsed.addr)
|
||||
13
backend/python/qwen-tts/install.sh
Executable file
13
backend/python/qwen-tts/install.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
# Install the qwen-tts backend's Python requirements.
set -e

# flash-attn and friends need the already-installed torch at build time,
# so disable build isolation for pip installs.
EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"

# Quote the script path so directories containing spaces don't break
# the common-library lookup (unquoted $0 undergoes word splitting).
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    # Bundled layout: the helper library ships next to the backend.
    source "$backend_dir/common/libbackend.sh"
else
    # Source-tree layout: the helper library lives one level up.
    source "$backend_dir/../common/libbackend.sh"
fi

installRequirements
|
||||
5
backend/python/qwen-tts/requirements-cpu.txt
Normal file
5
backend/python/qwen-tts/requirements-cpu.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
1
backend/python/qwen-tts/requirements-cublas12-after.txt
Normal file
1
backend/python/qwen-tts/requirements-cublas12-after.txt
Normal file
@@ -0,0 +1 @@
|
||||
flash-attn
|
||||
5
backend/python/qwen-tts/requirements-cublas12.txt
Normal file
5
backend/python/qwen-tts/requirements-cublas12.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
5
backend/python/qwen-tts/requirements-cublas13.txt
Normal file
5
backend/python/qwen-tts/requirements-cublas13.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
5
backend/python/qwen-tts/requirements-hipblas.txt
Normal file
5
backend/python/qwen-tts/requirements-hipblas.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
torchaudio==2.7.1+rocm6.3
|
||||
qwen-tts
|
||||
sox
|
||||
1
backend/python/qwen-tts/requirements-intel-after.txt
Normal file
1
backend/python/qwen-tts/requirements-intel-after.txt
Normal file
@@ -0,0 +1 @@
|
||||
flash-attn
|
||||
5
backend/python/qwen-tts/requirements-intel.txt
Normal file
5
backend/python/qwen-tts/requirements-intel.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
5
backend/python/qwen-tts/requirements-l4t12.txt
Normal file
5
backend/python/qwen-tts/requirements-l4t12.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
5
backend/python/qwen-tts/requirements-l4t13.txt
Normal file
5
backend/python/qwen-tts/requirements-l4t13.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
qwen-tts
|
||||
sox
|
||||
4
backend/python/qwen-tts/requirements-mps.txt
Normal file
4
backend/python/qwen-tts/requirements-mps.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
qwen-tts
|
||||
sox
|
||||
6
backend/python/qwen-tts/requirements.txt
Normal file
6
backend/python/qwen-tts/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
setuptools
|
||||
9
backend/python/qwen-tts/run.sh
Executable file
9
backend/python/qwen-tts/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
# Launch the qwen-tts backend gRPC server, forwarding all CLI args.

# Quote the script path so directories containing spaces don't break
# the common-library lookup (unquoted $0 undergoes word splitting).
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    # Bundled layout: the helper library ships next to the backend.
    source "$backend_dir/common/libbackend.sh"
else
    # Source-tree layout: the helper library lives one level up.
    source "$backend_dir/../common/libbackend.sh"
fi

# "$@" (quoted) preserves argument boundaries; the original bare $@
# re-split any argument containing whitespace.
startBackend "$@"
|
||||
98
backend/python/qwen-tts/test.py
Normal file
98
backend/python/qwen-tts/test.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """
    Drives the qwen-tts gRPC backend through a full load + synthesis
    round trip, capturing the server's output for diagnostics.
    """

    def setUp(self):
        """Spawn the backend server with captured stdout/stderr."""
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
        time.sleep(5)

    def tearDown(self) -> None:
        """Stop the server subprocess and dump anything it printed."""
        self.service.terminate()
        try:
            stdout, stderr = self.service.communicate(timeout=5)
        except subprocess.TimeoutExpired:
            # The server didn't exit in time; force-kill and collect output.
            self.service.kill()
            stdout, stderr = self.service.communicate()
        if stdout:
            print("=== REMAINING STDOUT ===")
            print(stdout)
        if stderr:
            print("=== REMAINING STDERR ===")
            print(stderr)

    def test_tts(self):
        """Load a small CustomVoice model and synthesize one utterance."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Allow up to 10 minutes for model download on first run
                load_reply = stub.LoadModel(
                    backend_pb2.ModelOptions(Model="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"),
                    timeout=600.0
                )
                self.assertTrue(load_reply.success)

                # Reserve a temporary destination for the generated audio.
                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as handle:
                    wav_path = handle.name

                # Allow up to 2 minutes for TTS generation
                tts_reply = stub.TTS(backend_pb2.TTSRequest(
                    text="Hello, this is a test of the qwen-tts backend.",
                    voice="Vivian",
                    dst=wav_path
                ), timeout=120.0)
                self.assertIsNotNone(tts_reply)
                self.assertTrue(tts_reply.success)

                # The synthesized file must exist and contain data.
                self.assertTrue(os.path.exists(wav_path))
                self.assertGreater(os.path.getsize(wav_path), 0)

                os.unlink(wav_path)
        except Exception as err:
            print(f"Exception: {err}", file=sys.stderr)
            # Give threads a moment to flush any remaining output
            time.sleep(1)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
|
||||
11
backend/python/qwen-tts/test.sh
Executable file
11
backend/python/qwen-tts/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -1,9 +1,7 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
transformers
|
||||
accelerate
|
||||
torch==2.3.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
torch
|
||||
rerankers[transformers]
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchvision==0.18.1+cxx11.abi
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchvision
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
rfdetr
|
||||
|
||||
@@ -6,4 +6,4 @@ transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
@@ -6,4 +6,4 @@ transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
@@ -6,4 +6,4 @@ transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
@@ -8,4 +8,4 @@ bitsandbytes
|
||||
outetts
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
@@ -1,13 +1,10 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
intel-extension-for-transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
@@ -1,5 +1,5 @@
|
||||
grpcio==1.76.0
|
||||
protobuf==6.33.2
|
||||
protobuf==6.33.4
|
||||
certifi
|
||||
setuptools
|
||||
scipy==1.15.1
|
||||
|
||||
@@ -23,11 +23,8 @@ if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
# This is here because the jetson-ai-lab.io PyPI mirror's root PyPI endpoint (pypi.jetson-ai-lab.io/root/pypi/)
|
||||
# returns 503 errors when uv tries to fall back to it for packages not found in the specific subdirectory.
|
||||
# We need uv to continue falling through to the official PyPI index when it encounters these errors.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-first-match"
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.5.1+cxx11.abi
|
||||
torchvision==0.20.1+cxx11.abi
|
||||
oneccl_bind_pt==2.8.0+xpu
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchvision
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
git+https://github.com/huggingface/diffusers
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.7.10+xpu
|
||||
accelerate
|
||||
torch==2.7.0+xpu
|
||||
torch
|
||||
transformers
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
bitsandbytes
|
||||
oneccl_bind_pt==2.7.0+xpu
|
||||
bitsandbytes
|
||||
@@ -61,6 +61,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Detect thinking support after model load (only if not already detected)
|
||||
// This needs to happen after LoadModel succeeds so the backend can render templates
|
||||
if (c.ReasoningConfig.DisableReasoning == nil && c.ReasoningConfig.DisableReasoningTagPrefill == nil) && c.TemplateConfig.UseTokenizerTemplate {
|
||||
modelOpts := grpcModelOpts(*c, o.SystemState.Model.ModelsPath)
|
||||
config.DetectThinkingSupportFromBackend(ctx, c, inferenceModel, modelOpts)
|
||||
// Update the config in the loader so it persists for future requests
|
||||
cl.UpdateModelConfig(c.Name, func(cfg *config.ModelConfig) {
|
||||
cfg.ReasoningConfig.DisableReasoning = c.ReasoningConfig.DisableReasoning
|
||||
cfg.ReasoningConfig.DisableReasoningTagPrefill = c.ReasoningConfig.DisableReasoningTagPrefill
|
||||
})
|
||||
}
|
||||
|
||||
var protoMessages []*proto.Message
|
||||
// if we are using the tokenizer template, we need to convert the messages to proto messages
|
||||
// unless the prompt has already been tokenized (non-chat endpoints + functions)
|
||||
|
||||
@@ -56,7 +56,7 @@ type RunCMD struct {
|
||||
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
||||
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
||||
DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"`
|
||||
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
|
||||
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/image/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
|
||||
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
|
||||
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
|
||||
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
|
||||
@@ -83,6 +83,7 @@ type RunCMD struct {
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
@@ -249,6 +250,15 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithLRUEvictionRetryInterval(dur))
|
||||
}
|
||||
|
||||
// Handle Open Responses store TTL
|
||||
if r.OpenResponsesStoreTTL != "" && r.OpenResponsesStoreTTL != "0" {
|
||||
dur, err := time.ParseDuration(r.OpenResponsesStoreTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Open Responses store TTL: %w", err)
|
||||
}
|
||||
opts = append(opts, config.WithOpenResponsesStoreTTL(dur))
|
||||
}
|
||||
|
||||
// split ":" to get backend name and the uri
|
||||
for _, v := range r.ExternalGRPCBackends {
|
||||
backend := v[:strings.IndexByte(v, ':')]
|
||||
|
||||
@@ -86,6 +86,8 @@ type ApplicationConfig struct {
|
||||
|
||||
AgentJobRetentionDays int // Default: 30 days
|
||||
|
||||
OpenResponsesStoreTTL time.Duration // TTL for Open Responses store (0 = no expiration)
|
||||
|
||||
PathWithoutAuth []string
|
||||
}
|
||||
|
||||
@@ -467,6 +469,12 @@ func WithAgentJobRetentionDays(days int) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenResponsesStoreTTL(ttl time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.OpenResponsesStoreTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
func WithEnforcedPredownloadScans(enforced bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.EnforcePredownloadScans = enforced
|
||||
@@ -594,6 +602,12 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
} else {
|
||||
lruEvictionRetryInterval = "1s" // default
|
||||
}
|
||||
var openResponsesStoreTTL string
|
||||
if o.OpenResponsesStoreTTL > 0 {
|
||||
openResponsesStoreTTL = o.OpenResponsesStoreTTL.String()
|
||||
} else {
|
||||
openResponsesStoreTTL = "0" // default: no expiration
|
||||
}
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
@@ -628,6 +642,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||
ApiKeys: &apiKeys,
|
||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -769,6 +784,14 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.AgentJobRetentionDays != nil {
|
||||
o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
|
||||
}
|
||||
if settings.OpenResponsesStoreTTL != nil {
|
||||
if *settings.OpenResponsesStoreTTL == "0" || *settings.OpenResponsesStoreTTL == "" {
|
||||
o.OpenResponsesStoreTTL = 0 // No expiration
|
||||
} else if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil {
|
||||
o.OpenResponsesStoreTTL = dur
|
||||
}
|
||||
// This setting doesn't require restart, can be updated dynamically
|
||||
}
|
||||
// Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
|
||||
|
||||
return requireRestart
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
"github.com/gpustack/gguf-parser-go/util/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -62,16 +68,25 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
cfg.NGPULayers = &defaultHigh
|
||||
}
|
||||
|
||||
xlog.Debug("guessDefaultsFromFile: NGPULayers set", "NGPULayers", cfg.NGPULayers)
|
||||
xlog.Debug("[gguf] guessDefaultsFromFile: NGPULayers set", "NGPULayers", cfg.NGPULayers, "modelName", f.Metadata().Name)
|
||||
|
||||
// identify from well known templates first, otherwise use the raw jinja template
|
||||
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
|
||||
if found {
|
||||
// fill jinja template
|
||||
cfg.modelTemplate = chatTemplate.ValueString()
|
||||
}
|
||||
|
||||
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
||||
|
||||
// template estimations
|
||||
if cfg.HasTemplate() {
|
||||
// nothing to guess here
|
||||
xlog.Debug("guessDefaultsFromFile: template already set", "name", cfg.Name)
|
||||
xlog.Debug("[gguf] guessDefaultsFromFile: template already set", "name", cfg.Name, "modelName", f.Metadata().Name)
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Debug("Model file loaded", "file", cfg.ModelFileName(), "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture)
|
||||
xlog.Debug("[gguf] Model file loaded", "file", cfg.ModelFileName(), "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture)
|
||||
|
||||
// guess the name
|
||||
if cfg.Name == "" {
|
||||
@@ -83,4 +98,49 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
|
||||
cfg.Options = append(cfg.Options, "use_jinja:true")
|
||||
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
|
||||
|
||||
}
|
||||
|
||||
// DetectThinkingSupportFromBackend calls the ModelMetadata gRPC method to detect
|
||||
// if the model supports thinking mode and if the template ends with a thinking start token.
|
||||
// This should be called after the model is loaded.
|
||||
// The results are stored in cfg.SupportsThinking and cfg.ThinkingForcedOpen.
|
||||
func DetectThinkingSupportFromBackend(ctx context.Context, cfg *ModelConfig, backendClient grpc.Backend, modelOptions *pb.ModelOptions) {
|
||||
if backendClient == nil {
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: backend client is nil, skipping detection")
|
||||
return
|
||||
}
|
||||
|
||||
if modelOptions == nil {
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: model options is nil, skipping detection")
|
||||
return
|
||||
}
|
||||
|
||||
// Only detect for llama-cpp backend when using tokenizer templates
|
||||
if cfg.Backend != "llama-cpp" || !cfg.TemplateConfig.UseTokenizerTemplate {
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: skipping detection", "backend", cfg.Backend, "useTokenizerTemplate", cfg.TemplateConfig.UseTokenizerTemplate)
|
||||
return
|
||||
}
|
||||
|
||||
metadata, err := backendClient.ModelMetadata(ctx, modelOptions)
|
||||
if err != nil {
|
||||
xlog.Warn("[gguf] DetectThinkingSupportFromBackend: failed to get model metadata", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
cfg.ReasoningConfig.DisableReasoning = ptr.To(!metadata.SupportsThinking)
|
||||
|
||||
// Use the rendered template to detect if thinking token is at the end
|
||||
// This reuses the existing DetectThinkingStartToken function
|
||||
if metadata.RenderedTemplate != "" {
|
||||
thinkingStartToken := reasoning.DetectThinkingStartToken(metadata.RenderedTemplate, &cfg.ReasoningConfig)
|
||||
thinkingForcedOpen := thinkingStartToken != ""
|
||||
cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(!thinkingForcedOpen)
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", thinkingForcedOpen, "thinking_start_token", thinkingStartToken)
|
||||
} else {
|
||||
cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(true)
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
"github.com/mudler/cogito"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -30,6 +31,7 @@ type TTSConfig struct {
|
||||
// @Description ModelConfig represents a model configuration
|
||||
type ModelConfig struct {
|
||||
modelConfigFile string `yaml:"-" json:"-"`
|
||||
modelTemplate string `yaml:"-" json:"-"`
|
||||
schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
|
||||
Name string `yaml:"name,omitempty" json:"name,omitempty"`
|
||||
|
||||
@@ -51,6 +53,7 @@ type ModelConfig struct {
|
||||
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
|
||||
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
@@ -521,6 +524,11 @@ func (c *ModelConfig) GetModelConfigFile() string {
|
||||
return c.modelConfigFile
|
||||
}
|
||||
|
||||
// GetModelTemplate returns the model's chat template if available
|
||||
func (c *ModelConfig) GetModelTemplate() string {
|
||||
return c.modelTemplate
|
||||
}
|
||||
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
|
||||
@@ -246,6 +246,17 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
|
||||
delete(bcl.configs, m)
|
||||
}
|
||||
|
||||
// UpdateModelConfig updates an existing model config in the loader.
|
||||
// This is useful for updating runtime-detected properties like thinking support.
|
||||
func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
if cfg, exists := bcl.configs[m]; exists {
|
||||
updater(&cfg)
|
||||
bcl.configs[m] = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// Preload prepare models if they are not local but url or huggingface repositories
|
||||
func (bcl *ModelConfigLoader) Preload(modelPath string) error {
|
||||
bcl.Lock()
|
||||
|
||||
@@ -60,4 +60,7 @@ type RuntimeSettings struct {
|
||||
|
||||
// Agent settings
|
||||
AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"`
|
||||
|
||||
// Open Responses settings
|
||||
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
|
||||
}
|
||||
|
||||
@@ -63,6 +63,25 @@ func (m *GalleryBackend) IsMeta() bool {
|
||||
return len(m.CapabilitiesMap) > 0 && m.URI == ""
|
||||
}
|
||||
|
||||
// IsCompatibleWith checks if the backend is compatible with the current system capability.
|
||||
// For meta backends, it checks if any of the capabilities in the map match the system capability.
|
||||
// For concrete backends, it delegates to SystemState.IsBackendCompatible.
|
||||
func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool {
|
||||
if systemState == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Meta backends are compatible if the system capability matches one of the keys
|
||||
if m.IsMeta() {
|
||||
capability := systemState.Capability(m.CapabilitiesMap)
|
||||
_, exists := m.CapabilitiesMap[capability]
|
||||
return exists
|
||||
}
|
||||
|
||||
// For concrete backends, delegate to the system package
|
||||
return systemState.IsBackendCompatible(m.Name, m.URI)
|
||||
}
|
||||
|
||||
func (m *GalleryBackend) SetInstalled(installed bool) {
|
||||
m.Installed = installed
|
||||
}
|
||||
|
||||
@@ -172,6 +172,252 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(nilMetaBackend.IsMeta()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should check IsCompatibleWith correctly for meta backends", func() {
|
||||
metaBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "meta-backend",
|
||||
},
|
||||
CapabilitiesMap: map[string]string{
|
||||
"nvidia": "nvidia-backend",
|
||||
"amd": "amd-backend",
|
||||
"default": "default-backend",
|
||||
},
|
||||
}
|
||||
|
||||
// Test with nil state - should be compatible
|
||||
Expect(metaBackend.IsCompatibleWith(nil)).To(BeTrue())
|
||||
|
||||
// Test with NVIDIA system - should be compatible (has nvidia key)
|
||||
nvidiaState := &system.SystemState{GPUVendor: "nvidia", VRAM: 8 * 1024 * 1024 * 1024}
|
||||
Expect(metaBackend.IsCompatibleWith(nvidiaState)).To(BeTrue())
|
||||
|
||||
// Test with default (no GPU) - should be compatible (has default key)
|
||||
defaultState := &system.SystemState{}
|
||||
Expect(metaBackend.IsCompatibleWith(defaultState)).To(BeTrue())
|
||||
})
|
||||
|
||||
Describe("IsCompatibleWith for concrete backends", func() {
|
||||
Context("CPU backends", func() {
|
||||
It("should be compatible on all systems", func() {
|
||||
cpuBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "cpu-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp",
|
||||
}
|
||||
Expect(cpuBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
|
||||
Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Darwin/Metal backends", func() {
|
||||
When("running on darwin", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS != "darwin" {
|
||||
Skip("Skipping darwin-specific tests on non-darwin system")
|
||||
}
|
||||
})
|
||||
|
||||
It("should be compatible for MLX backend", func() {
|
||||
mlxBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "mlx",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx",
|
||||
}
|
||||
Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should be compatible for metal-llama-cpp backend", func() {
|
||||
metalBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "metal-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp",
|
||||
}
|
||||
Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("running on non-darwin", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("Skipping non-darwin-specific tests on darwin system")
|
||||
}
|
||||
})
|
||||
|
||||
It("should NOT be compatible for MLX backend", func() {
|
||||
mlxBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "mlx",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx",
|
||||
}
|
||||
Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should NOT be compatible for metal-llama-cpp backend", func() {
|
||||
metalBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "metal-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp",
|
||||
}
|
||||
Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("NVIDIA/CUDA backends", func() {
|
||||
When("running on non-darwin", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("Skipping CUDA tests on darwin system")
|
||||
}
|
||||
})
|
||||
|
||||
It("should NOT be compatible without nvidia GPU", func() {
|
||||
cudaBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "cuda12-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp",
|
||||
}
|
||||
Expect(cudaBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
|
||||
Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should be compatible with nvidia GPU", func() {
|
||||
cudaBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "cuda12-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp",
|
||||
}
|
||||
Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should be compatible with cuda13 backend on nvidia GPU", func() {
|
||||
cuda13Backend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "cuda13-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp",
|
||||
}
|
||||
Expect(cuda13Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("AMD/ROCm backends", func() {
|
||||
When("running on non-darwin", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("Skipping AMD/ROCm tests on darwin system")
|
||||
}
|
||||
})
|
||||
|
||||
It("should NOT be compatible without AMD GPU", func() {
|
||||
rocmBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "rocm-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp",
|
||||
}
|
||||
Expect(rocmBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
|
||||
Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should be compatible with AMD GPU", func() {
|
||||
rocmBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "rocm-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp",
|
||||
}
|
||||
Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should be compatible with hipblas backend on AMD GPU", func() {
|
||||
hipBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "hip-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-hip-llama-cpp",
|
||||
}
|
||||
Expect(hipBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Intel/SYCL backends", func() {
|
||||
When("running on non-darwin", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("Skipping Intel/SYCL tests on darwin system")
|
||||
}
|
||||
})
|
||||
|
||||
It("should NOT be compatible without Intel GPU", func() {
|
||||
intelBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "intel-sycl-f16-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp",
|
||||
}
|
||||
Expect(intelBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
|
||||
Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should be compatible with Intel GPU", func() {
|
||||
intelBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "intel-sycl-f16-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp",
|
||||
}
|
||||
Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should be compatible with intel-sycl-f32 backend on Intel GPU", func() {
|
||||
intelF32Backend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "intel-sycl-f32-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp",
|
||||
}
|
||||
Expect(intelF32Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should be compatible with intel-transformers backend on Intel GPU", func() {
|
||||
intelTransformersBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "intel-transformers",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-intel-transformers",
|
||||
}
|
||||
Expect(intelTransformersBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Vulkan backends", func() {
|
||||
It("should be compatible on CPU-only systems", func() {
|
||||
// Vulkan backends don't have a specific GPU vendor requirement in the current logic
|
||||
// They are compatible if no other GPU-specific pattern matches
|
||||
vulkanBackend := &GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "vulkan-llama-cpp",
|
||||
},
|
||||
URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp",
|
||||
}
|
||||
// Vulkan doesn't have vendor-specific filtering in current implementation
|
||||
Expect(vulkanBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
It("should find best backend from meta based on system capabilities", func() {
|
||||
|
||||
metaBackend := &GalleryBackend{
|
||||
|
||||
@@ -226,6 +226,16 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst
|
||||
|
||||
// List available backends
|
||||
func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
|
||||
return availableBackendsWithFilter(galleries, systemState, true)
|
||||
}
|
||||
|
||||
// AvailableBackendsUnfiltered returns all available backends without filtering by system capability.
|
||||
func AvailableBackendsUnfiltered(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
|
||||
return availableBackendsWithFilter(galleries, systemState, false)
|
||||
}
|
||||
|
||||
// availableBackendsWithFilter is a helper function that lists available backends with optional filtering.
|
||||
func availableBackendsWithFilter(galleries []config.Gallery, systemState *system.SystemState, filterByCapability bool) (GalleryElements[*GalleryBackend], error) {
|
||||
var backends []*GalleryBackend
|
||||
|
||||
systemBackends, err := ListSystemBackends(systemState)
|
||||
@@ -241,7 +251,17 @@ func AvailableBackends(galleries []config.Gallery, systemState *system.SystemSta
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
backends = append(backends, galleryBackends...)
|
||||
|
||||
// Filter backends by system capability if requested
|
||||
if filterByCapability {
|
||||
for _, backend := range galleryBackends {
|
||||
if backend.IsCompatibleWith(systemState) {
|
||||
backends = append(backends, backend)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
backends = append(backends, galleryBackends...)
|
||||
}
|
||||
}
|
||||
|
||||
return backends, nil
|
||||
|
||||
@@ -108,7 +108,15 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
err := next(c)
|
||||
xlog.Info("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
|
||||
|
||||
// Fix for #7989: Reduce log verbosity of Web UI polling
|
||||
// If the path is /api/operations and the request was successful (200),
|
||||
// we log it at DEBUG level (hidden by default) instead of INFO.
|
||||
if req.URL.Path == "/api/operations" && res.Status == 200 {
|
||||
xlog.Debug("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
|
||||
} else {
|
||||
xlog.Info("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
|
||||
}
|
||||
return err
|
||||
}
|
||||
})
|
||||
@@ -185,6 +193,8 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
|
||||
}
|
||||
e.Use(middleware.CORSWithConfig(corsConfig))
|
||||
} else {
|
||||
e.Use(middleware.CORS())
|
||||
}
|
||||
|
||||
// CSRF middleware
|
||||
@@ -205,6 +215,8 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application)
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
|
||||
537
core/http/endpoints/anthropic/messages.go
Normal file
537
core/http/endpoints/anthropic/messages.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// MessagesEndpoint is the Anthropic Messages API endpoint
// https://docs.anthropic.com/claude/reference/messages_post
//
// It validates the request, converts the Anthropic message/tool format into
// the internal OpenAI-style representation, templates the prompt, and then
// dispatches to either the streaming or the non-streaming handler.
//
// @Summary Generate a message response for the given messages and model.
// @Param request body schema.AnthropicRequest true "query params"
// @Success 200 {object} schema.AnthropicResponse "Response"
// @Router /v1/messages [post]
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Unique ID reused for the message ID ("msg_<id>") and tool-use IDs.
		id := uuid.New().String()

		// The request was parsed and stashed in the context by the middleware.
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest)
		if !ok || input.Model == "" {
			return sendAnthropicError(c, 400, "invalid_request_error", "model is required")
		}

		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return sendAnthropicError(c, 400, "invalid_request_error", "model configuration not found")
		}

		// max_tokens is mandatory in the Anthropic Messages API.
		if input.MaxTokens <= 0 {
			return sendAnthropicError(c, 400, "invalid_request_error", "max_tokens is required and must be greater than 0")
		}

		xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg)

		// Convert Anthropic messages to OpenAI format for internal processing
		openAIMessages := convertAnthropicToOpenAIMessages(input)

		// Convert Anthropic tools to internal Functions format
		// (this may also mutate cfg based on tool_choice).
		funcs, shouldUseFn := convertAnthropicTools(input, cfg)

		// Create an OpenAI-compatible request for internal processing
		openAIReq := &schema.OpenAIRequest{
			PredictionOptions: schema.PredictionOptions{
				BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
				Temperature:       input.Temperature,
				TopK:              input.TopK,
				TopP:              input.TopP,
				Maxtokens:         &input.MaxTokens,
			},
			Messages: openAIMessages,
			Stream:   input.Stream,
			Context:  input.Context,
			Cancel:   input.Cancel,
		}

		// Set stop sequences
		if len(input.StopSequences) > 0 {
			openAIReq.Stop = input.StopSequences
		}

		// Merge per-request sampling settings into the model config, which is
		// what the backend actually reads during inference.
		if input.Temperature != nil {
			cfg.Temperature = input.Temperature
		}
		if input.TopK != nil {
			cfg.TopK = input.TopK
		}
		if input.TopP != nil {
			cfg.TopP = input.TopP
		}
		cfg.Maxtokens = &input.MaxTokens
		if len(input.StopSequences) > 0 {
			cfg.StopWords = append(cfg.StopWords, input.StopSequences...)
		}

		// Template the prompt with tools if available
		predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
		xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)

		if input.Stream {
			return handleAnthropicStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn)
		}

		return handleAnthropicNonStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn)
	}
}
|
||||
|
||||
// handleAnthropicNonStream serves a non-streaming Messages request: it runs a
// single inference over the templated prompt, parses any tool calls out of the
// result, and returns a complete schema.AnthropicResponse as JSON.
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error {
	// Collect the base64 data-URI images gathered during message conversion.
	images := []string{}
	for _, m := range openAIReq.Messages {
		images = append(images, m.StringImages...)
	}

	predFunc, err := backend.ModelInference(
		input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, nil, nil, nil, "", "", nil, nil, nil)
	if err != nil {
		xlog.Error("Anthropic model inference failed", "error", err)
		return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
	}

	prediction, err := predFunc()
	if err != nil {
		xlog.Error("Anthropic prediction failed", "error", err)
		return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
	}

	result := backend.Finetune(*cfg, predInput, prediction.Response)

	// Check if the result contains tool calls
	toolCalls := functions.ParseFunctionCall(result, cfg.FunctionsConfig)

	var contentBlocks []schema.AnthropicContentBlock
	var stopReason string

	if shouldUseFn && len(toolCalls) > 0 {
		// Model wants to use tools
		stopReason = "tool_use"
		for _, tc := range toolCalls {
			// Parse arguments as JSON; fall back to wrapping the raw string
			// so malformed arguments are still surfaced to the client.
			var inputArgs map[string]interface{}
			if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
				xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
				inputArgs = map[string]interface{}{"raw": tc.Arguments}
			}

			contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
				Type:  "tool_use",
				ID:    fmt.Sprintf("toolu_%s_%d", id, len(contentBlocks)),
				Name:  tc.Name,
				Input: inputArgs,
			})
		}

		// Add any text content before the tool calls
		textContent := functions.ParseTextContent(result, cfg.FunctionsConfig)
		if textContent != "" {
			// Prepend text block
			contentBlocks = append([]schema.AnthropicContentBlock{{Type: "text", Text: textContent}}, contentBlocks...)
		}
	} else {
		// Normal text response
		stopReason = "end_turn"
		contentBlocks = []schema.AnthropicContentBlock{
			{Type: "text", Text: result},
		}
	}

	resp := &schema.AnthropicResponse{
		ID:         fmt.Sprintf("msg_%s", id),
		Type:       "message",
		Role:       "assistant",
		Model:      input.Model,
		StopReason: &stopReason,
		Content:    contentBlocks,
		Usage: schema.AnthropicUsage{
			InputTokens:  prediction.Usage.Prompt,
			OutputTokens: prediction.Usage.Completion,
		},
	}

	// Best-effort debug dump of the full response; marshal errors are ignored
	// here because the same struct is serialized again by c.JSON below.
	if respData, err := json.Marshal(resp); err == nil {
		xlog.Debug("Anthropic Response", "response", string(respData))
	}

	return c.JSON(200, resp)
}
|
||||
|
||||
// handleAnthropicStream serves a streaming Messages request over SSE, emitting
// the Anthropic event sequence: message_start, content_block_start,
// content_block_delta..., content_block_stop, message_delta, message_stop.
// When tools are enabled, tool calls are detected incrementally from the
// accumulated model output and emitted as tool_use content blocks.
//
// NOTE(review): once streaming has started the SSE headers are already sent,
// so the sendAnthropicError calls below emit a JSON body mid-stream — confirm
// whether clients handle this, or whether an SSE "error" event is preferable.
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error {
	c.Response().Header().Set("Content-Type", "text/event-stream")
	c.Response().Header().Set("Cache-Control", "no-cache")
	c.Response().Header().Set("Connection", "keep-alive")

	// Create OpenAI messages for inference
	openAIMessages := openAIReq.Messages

	// Collect the base64 data-URI images gathered during message conversion.
	images := []string{}
	for _, m := range openAIMessages {
		images = append(images, m.StringImages...)
	}

	// Send message_start event
	messageStart := schema.AnthropicStreamEvent{
		Type: "message_start",
		Message: &schema.AnthropicStreamMessage{
			ID:      fmt.Sprintf("msg_%s", id),
			Type:    "message",
			Role:    "assistant",
			Content: []schema.AnthropicContentBlock{},
			Model:   input.Model,
			Usage:   schema.AnthropicUsage{InputTokens: 0, OutputTokens: 0},
		},
	}
	sendAnthropicSSE(c, messageStart)

	// Track accumulated content for tool call detection.
	// currentBlockIndex is the index of the next content block to emit;
	// inToolCall flips to true once the leading text block has been closed.
	accumulatedContent := ""
	currentBlockIndex := 0
	inToolCall := false
	toolCallsEmitted := 0

	// Send initial content_block_start event (the leading text block).
	contentBlockStart := schema.AnthropicStreamEvent{
		Type:         "content_block_start",
		Index:        currentBlockIndex,
		ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
	}
	sendAnthropicSSE(c, contentBlockStart)

	// Stream content deltas
	tokenCallback := func(token string, usage backend.TokenUsage) bool {
		accumulatedContent += token

		// If we're using functions, try to detect tool calls incrementally
		// by re-parsing the full accumulated output on every token.
		if shouldUseFn {
			cleanedResult := functions.CleanupLLMResult(accumulatedContent, cfg.FunctionsConfig)

			// Try parsing for tool calls
			toolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)

			// If we detected new tool calls and haven't emitted them yet
			if len(toolCalls) > toolCallsEmitted {
				// Stop the current text block if we were in one
				if !inToolCall && currentBlockIndex == 0 {
					sendAnthropicSSE(c, schema.AnthropicStreamEvent{
						Type:  "content_block_stop",
						Index: currentBlockIndex,
					})
					currentBlockIndex++
					inToolCall = true
				}

				// Emit new tool calls; each one is a complete
				// start / input_json_delta / stop block triple.
				for i := toolCallsEmitted; i < len(toolCalls); i++ {
					tc := toolCalls[i]

					// Send content_block_start for tool_use
					sendAnthropicSSE(c, schema.AnthropicStreamEvent{
						Type:  "content_block_start",
						Index: currentBlockIndex,
						ContentBlock: &schema.AnthropicContentBlock{
							Type: "tool_use",
							ID:   fmt.Sprintf("toolu_%s_%d", id, i),
							Name: tc.Name,
						},
					})

					// Send input_json_delta with the arguments
					sendAnthropicSSE(c, schema.AnthropicStreamEvent{
						Type:  "content_block_delta",
						Index: currentBlockIndex,
						Delta: &schema.AnthropicStreamDelta{
							Type:        "input_json_delta",
							PartialJSON: tc.Arguments,
						},
					})

					// Send content_block_stop
					sendAnthropicSSE(c, schema.AnthropicStreamEvent{
						Type:  "content_block_stop",
						Index: currentBlockIndex,
					})

					currentBlockIndex++
				}
				toolCallsEmitted = len(toolCalls)
				return true
			}
		}

		// Send regular text delta if not in tool call mode
		if !inToolCall {
			delta := schema.AnthropicStreamEvent{
				Type:  "content_block_delta",
				Index: 0,
				Delta: &schema.AnthropicStreamDelta{
					Type: "text_delta",
					Text: token,
				},
			}
			sendAnthropicSSE(c, delta)
		}
		return true
	}

	predFunc, err := backend.ModelInference(
		input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, nil, nil, tokenCallback, "", "", nil, nil, nil)
	if err != nil {
		xlog.Error("Anthropic stream model inference failed", "error", err)
		return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
	}

	// Drive the inference to completion; tokens flow through tokenCallback.
	prediction, err := predFunc()
	if err != nil {
		xlog.Error("Anthropic stream prediction failed", "error", err)
		return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
	}

	// Send content_block_stop event for last block if we didn't close it yet
	// (tool-call blocks close themselves inside the callback).
	if !inToolCall {
		contentBlockStop := schema.AnthropicStreamEvent{
			Type:  "content_block_stop",
			Index: 0,
		}
		sendAnthropicSSE(c, contentBlockStop)
	}

	// Determine stop reason
	stopReason := "end_turn"
	if toolCallsEmitted > 0 {
		stopReason = "tool_use"
	}

	// Send message_delta event with stop_reason
	messageDelta := schema.AnthropicStreamEvent{
		Type: "message_delta",
		Delta: &schema.AnthropicStreamDelta{
			StopReason: &stopReason,
		},
		Usage: &schema.AnthropicUsage{
			OutputTokens: prediction.Usage.Completion,
		},
	}
	sendAnthropicSSE(c, messageDelta)

	// Send message_stop event
	messageStop := schema.AnthropicStreamEvent{
		Type: "message_stop",
	}
	sendAnthropicSSE(c, messageStop)

	return nil
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal SSE event", "error", err)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data))
|
||||
c.Response().Flush()
|
||||
}
|
||||
|
||||
func sendAnthropicError(c echo.Context, statusCode int, errorType, message string) error {
|
||||
resp := schema.AnthropicErrorResponse{
|
||||
Type: "error",
|
||||
Error: schema.AnthropicError{
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
return c.JSON(statusCode, resp)
|
||||
}
|
||||
|
||||
func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.Message {
|
||||
var messages []schema.Message
|
||||
|
||||
// Add system message if present
|
||||
if input.System != "" {
|
||||
messages = append(messages, schema.Message{
|
||||
Role: "system",
|
||||
StringContent: input.System,
|
||||
Content: input.System,
|
||||
})
|
||||
}
|
||||
|
||||
// Convert Anthropic messages to OpenAI format
|
||||
for _, msg := range input.Messages {
|
||||
openAIMsg := schema.Message{
|
||||
Role: msg.Role,
|
||||
}
|
||||
|
||||
// Handle content (can be string or array of content blocks)
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
openAIMsg.StringContent = content
|
||||
openAIMsg.Content = content
|
||||
case []interface{}:
|
||||
// Handle array of content blocks
|
||||
var textContent string
|
||||
var stringImages []string
|
||||
var toolCalls []schema.ToolCall
|
||||
toolCallIndex := 0
|
||||
|
||||
for _, block := range content {
|
||||
if blockMap, ok := block.(map[string]interface{}); ok {
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent += text
|
||||
}
|
||||
case "image":
|
||||
// Handle image content
|
||||
if source, ok := blockMap["source"].(map[string]interface{}); ok {
|
||||
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
mediaType, _ := source["media_type"].(string)
|
||||
// Format as data URI
|
||||
dataURI := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
|
||||
stringImages = append(stringImages, dataURI)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
// Convert tool_use to ToolCall format
|
||||
toolID, _ := blockMap["id"].(string)
|
||||
toolName, _ := blockMap["name"].(string)
|
||||
toolInput := blockMap["input"]
|
||||
|
||||
// Serialize input to JSON string
|
||||
inputJSON, err := json.Marshal(toolInput)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to marshal tool input", "error", err)
|
||||
inputJSON = []byte("{}")
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, schema.ToolCall{
|
||||
Index: toolCallIndex,
|
||||
ID: toolID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: toolName,
|
||||
Arguments: string(inputJSON),
|
||||
},
|
||||
})
|
||||
toolCallIndex++
|
||||
case "tool_result":
|
||||
// Convert tool_result to a message with role "tool"
|
||||
// This is handled by creating a separate message after this block
|
||||
// For now, we'll add it as text content
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
isError := false
|
||||
if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr != nil {
|
||||
isError = *isErrorPtr
|
||||
}
|
||||
|
||||
var resultText string
|
||||
if resultContent, ok := blockMap["content"]; ok {
|
||||
switch rc := resultContent.(type) {
|
||||
case string:
|
||||
resultText = rc
|
||||
case []interface{}:
|
||||
// Array of content blocks
|
||||
for _, cb := range rc {
|
||||
if cbMap, ok := cb.(map[string]interface{}); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultText += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool result as a tool role message
|
||||
// We need to handle this differently - create a new message
|
||||
if msg.Role == "user" {
|
||||
// Store tool result info for creating separate message
|
||||
prefix := ""
|
||||
if isError {
|
||||
prefix = "Error: "
|
||||
}
|
||||
textContent += fmt.Sprintf("\n[Tool Result for %s]: %s%s", toolUseID, prefix, resultText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
openAIMsg.StringContent = textContent
|
||||
openAIMsg.Content = textContent
|
||||
openAIMsg.StringImages = stringImages
|
||||
|
||||
// Add tool calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
openAIMsg.ToolCalls = toolCalls
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, openAIMsg)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// convertAnthropicTools converts Anthropic tools to internal Functions format
|
||||
func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConfig) (functions.Functions, bool) {
|
||||
if len(input.Tools) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var funcs functions.Functions
|
||||
for _, tool := range input.Tools {
|
||||
f := functions.Function{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: tool.InputSchema,
|
||||
}
|
||||
funcs = append(funcs, f)
|
||||
}
|
||||
|
||||
// Handle tool_choice
|
||||
if input.ToolChoice != nil {
|
||||
switch tc := input.ToolChoice.(type) {
|
||||
case string:
|
||||
// "auto", "any", or "none"
|
||||
if tc == "any" {
|
||||
// Force the model to use one of the tools
|
||||
cfg.SetFunctionCallString("required")
|
||||
} else if tc == "none" {
|
||||
// Don't use tools
|
||||
return nil, false
|
||||
}
|
||||
// "auto" is the default - let model decide
|
||||
case map[string]interface{}:
|
||||
// Specific tool selection: {"type": "tool", "name": "tool_name"}
|
||||
if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
|
||||
if name, ok := tc["name"].(string); ok {
|
||||
// Force specific tool
|
||||
cfg.SetFunctionCallString(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openresponses"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -84,6 +85,16 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
})
|
||||
}
|
||||
}
|
||||
if settings.OpenResponsesStoreTTL != nil {
|
||||
if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" {
|
||||
if _, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
|
||||
Success: false,
|
||||
Error: "Invalid open_responses_store_ttl format: " + err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save to file
|
||||
if appConfig.DynamicConfigsDir == "" {
|
||||
@@ -144,6 +155,22 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
xlog.Info("Updated LRU eviction retry settings", "maxRetries", maxRetries, "retryInterval", retryInterval)
|
||||
}
|
||||
|
||||
// Update Open Responses store TTL dynamically
|
||||
if settings.OpenResponsesStoreTTL != nil {
|
||||
ttl := time.Duration(0)
|
||||
if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" {
|
||||
if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil {
|
||||
ttl = dur
|
||||
} else {
|
||||
xlog.Warn("Invalid Open Responses store TTL format", "ttl", *settings.OpenResponsesStoreTTL, "error", err)
|
||||
}
|
||||
}
|
||||
// Import the store package
|
||||
store := openresponses.GetGlobalStore()
|
||||
store.SetTTL(ttl)
|
||||
xlog.Info("Updated Open Responses store TTL", "ttl", ttl)
|
||||
}
|
||||
|
||||
// Check if agent job retention changed
|
||||
agentJobChanged := settings.AgentJobRetentionDays != nil
|
||||
|
||||
|
||||
@@ -167,6 +167,16 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
xlog.Debug("VideoEndpoint: Calling VideoGeneration",
|
||||
"num_frames", input.NumFrames,
|
||||
"fps", input.FPS,
|
||||
"cfg_scale", input.CFGScale,
|
||||
"step", input.Step,
|
||||
"seed", input.Seed,
|
||||
"width", width,
|
||||
"height", height,
|
||||
"negative_prompt", input.NegativePrompt)
|
||||
|
||||
fn, err := backend.VideoGeneration(
|
||||
height,
|
||||
width,
|
||||
|
||||
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -34,11 +36,64 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
// Track accumulated content for reasoning extraction
|
||||
accumulatedContent := ""
|
||||
lastEmittedReasoning := ""
|
||||
lastEmittedCleanedContent := ""
|
||||
|
||||
_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedContent += s
|
||||
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
// Calculate new reasoning delta (what we haven't emitted yet)
|
||||
var reasoningDelta *string
|
||||
if currentReasoning != lastEmittedReasoning {
|
||||
// Extract only the new part
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
newReasoning := currentReasoning[len(lastEmittedReasoning):]
|
||||
reasoningDelta = &newReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != "" {
|
||||
// If reasoning changed in a non-append way, emit the full current reasoning
|
||||
reasoningDelta = ¤tReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate content delta from cleaned content
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
// If cleaned content changed but not in a simple append, extract delta from cleaned content
|
||||
// This handles cases where thinking tags are removed mid-stream
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
// Content changed in non-append way, use the new cleaned content
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
// Only emit content if there's actual content (not just thinking tags)
|
||||
// If deltaContent is empty, we still emit the response but with empty content
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
@@ -49,11 +104,20 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
// Only include content if there's actual content (not just thinking tags)
|
||||
if deltaContent != "" {
|
||||
delta.Content = &deltaContent
|
||||
}
|
||||
if reasoningDelta != nil && *reasoningDelta != "" {
|
||||
delta.Reasoning = reasoningDelta
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0, FinishReason: nil}},
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: usage,
|
||||
}
|
||||
@@ -65,6 +129,15 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
@@ -176,6 +249,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Prepend thinking token if needed, then extract reasoning before processing tool calls
|
||||
reasoning, result := reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
@@ -208,11 +284,20 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
var deltaReasoning *string
|
||||
if reasoning != "" {
|
||||
deltaReasoning = &reasoning
|
||||
}
|
||||
delta := &schema.Message{Content: &result}
|
||||
if deltaReasoning != nil {
|
||||
delta.Reasoning = deltaReasoning
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0, FinishReason: nil}},
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: usage,
|
||||
}
|
||||
@@ -551,12 +636,29 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
// no streaming mode
|
||||
default:
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate() // TODO: this should be the parsed jinja template. But for now this is the best we can do.
|
||||
} else {
|
||||
template = predInput
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template)
|
||||
|
||||
tokenCallback := func(s string, c *[]schema.Choice) {
|
||||
// Prepend thinking token if needed, then extract reasoning from the response
|
||||
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
if !shouldUseFn {
|
||||
// no function is called, just reply and use stop as finish reason
|
||||
stopReason := FinishReasonStop
|
||||
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
message := &schema.Message{Role: "assistant", Content: &s}
|
||||
if reasoning != "" {
|
||||
message.Reasoning = &reasoning
|
||||
}
|
||||
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -575,9 +677,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &result}
|
||||
if reasoning != "" {
|
||||
message.Reasoning = &reasoning
|
||||
}
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: &schema.Message{Role: "assistant", Content: &result}})
|
||||
Message: message})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
@@ -586,6 +692,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if reasoning != "" {
|
||||
toolChoice.Message.Reasoning = &reasoning
|
||||
}
|
||||
|
||||
for _, ss := range results {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
@@ -606,16 +715,20 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
} else {
|
||||
// otherwise we return more choices directly (deprecated)
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if reasoning != "" {
|
||||
message.Reasoning = &reasoning
|
||||
}
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
},
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
body := make([]byte, 0)
|
||||
if c.Request().Body != nil {
|
||||
c.Request().Body.Read(body)
|
||||
}
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
}
|
||||
// Build VideoRequest using shared mapper
|
||||
vr := MapOpenAIToVideo(input, raw)
|
||||
// Place VideoRequest into context so localai.VideoEndpoint can consume it
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Delegate to existing localai handler
|
||||
return localai.VideoEndpoint(cl, ml, appConfig)(c)
|
||||
}
|
||||
}
|
||||
|
||||
// VideoEndpoint godoc
|
||||
// @Summary Generate a video from an OpenAI-compatible request
|
||||
// @Description Accepts an OpenAI-style request and delegates to the LocalAI video generator
|
||||
// @Tags openai
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.OpenAIRequest true "OpenAI-style request"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Failure 400 {object} map[string]interface{}
|
||||
// @Router /v1/videos [post]
|
||||
|
||||
func MapOpenAIToVideo(input *schema.OpenAIRequest, raw map[string]interface{}) *schema.VideoRequest {
|
||||
vr := &schema.VideoRequest{}
|
||||
if input == nil {
|
||||
return vr
|
||||
}
|
||||
|
||||
if input.Model != "" {
|
||||
vr.Model = input.Model
|
||||
}
|
||||
|
||||
// Prompt mapping
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
vr.Prompt = p
|
||||
case []interface{}:
|
||||
if len(p) > 0 {
|
||||
if s, ok := p[0].(string); ok {
|
||||
vr.Prompt = s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Size
|
||||
size := input.Size
|
||||
if size == "" && raw != nil {
|
||||
if v, ok := raw["size"].(string); ok {
|
||||
size = v
|
||||
}
|
||||
}
|
||||
if size != "" {
|
||||
parts := strings.SplitN(size, "x", 2)
|
||||
if len(parts) == 2 {
|
||||
if wi, err := strconv.Atoi(parts[0]); err == nil {
|
||||
vr.Width = int32(wi)
|
||||
}
|
||||
if hi, err := strconv.Atoi(parts[1]); err == nil {
|
||||
vr.Height = int32(hi)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// seconds -> num frames
|
||||
secondsStr := ""
|
||||
if raw != nil {
|
||||
if v, ok := raw["seconds"].(string); ok {
|
||||
secondsStr = v
|
||||
} else if v, ok := raw["seconds"].(float64); ok {
|
||||
secondsStr = fmt.Sprintf("%v", int(v))
|
||||
}
|
||||
}
|
||||
fps := int32(30)
|
||||
if raw != nil {
|
||||
if rawFPS, ok := raw["fps"]; ok {
|
||||
switch rf := rawFPS.(type) {
|
||||
case float64:
|
||||
fps = int32(rf)
|
||||
case string:
|
||||
if fi, err := strconv.Atoi(rf); err == nil {
|
||||
fps = int32(fi)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if secondsStr != "" {
|
||||
if secF, err := strconv.Atoi(secondsStr); err == nil {
|
||||
vr.FPS = fps
|
||||
vr.NumFrames = int32(secF) * fps
|
||||
}
|
||||
}
|
||||
|
||||
// input_reference
|
||||
if raw != nil {
|
||||
if v, ok := raw["input_reference"].(string); ok {
|
||||
vr.StartImage = v
|
||||
}
|
||||
}
|
||||
|
||||
// response format
|
||||
if input.ResponseFormat != nil {
|
||||
if rf, ok := input.ResponseFormat.(string); ok {
|
||||
vr.ResponseFormat = rf
|
||||
}
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
vr.Step = int32(input.Step)
|
||||
}
|
||||
|
||||
return vr
|
||||
}
|
||||
3669
core/http/endpoints/openresponses/responses.go
Normal file
3669
core/http/endpoints/openresponses/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
453
core/http/endpoints/openresponses/store.go
Normal file
453
core/http/endpoints/openresponses/store.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package openresponses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ResponseStore provides thread-safe storage for Open Responses API responses
|
||||
type ResponseStore struct {
|
||||
mu sync.RWMutex
|
||||
responses map[string]*StoredResponse
|
||||
ttl time.Duration // Time-to-live for stored responses (0 = no expiration)
|
||||
cleanupCtx context.Context
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// StreamedEvent represents a buffered SSE event for streaming resume
|
||||
type StreamedEvent struct {
|
||||
SequenceNumber int `json:"sequence_number"`
|
||||
EventType string `json:"event_type"`
|
||||
Data []byte `json:"data"` // JSON-serialized event
|
||||
}
|
||||
|
||||
// StoredResponse contains a complete response with its input request and output items
|
||||
type StoredResponse struct {
|
||||
Request *schema.OpenResponsesRequest
|
||||
Response *schema.ORResponseResource
|
||||
Items map[string]*schema.ORItemField // item_id -> item mapping for quick lookup
|
||||
StoredAt time.Time
|
||||
ExpiresAt *time.Time // nil if no expiration
|
||||
|
||||
// Background execution support
|
||||
CancelFunc context.CancelFunc // For cancellation of background tasks
|
||||
StreamEvents []StreamedEvent // Buffered events for streaming resume
|
||||
StreamEnabled bool // Was created with stream=true
|
||||
IsBackground bool // Was created with background=true
|
||||
EventsChan chan struct{} // Signals new events for live subscribers
|
||||
mu sync.RWMutex // Protect concurrent access to this response
|
||||
}
|
||||
|
||||
var (
|
||||
globalStore *ResponseStore
|
||||
storeOnce sync.Once
|
||||
)
|
||||
|
||||
// GetGlobalStore returns the singleton response store instance
|
||||
func GetGlobalStore() *ResponseStore {
|
||||
storeOnce.Do(func() {
|
||||
globalStore = NewResponseStore(0) // Default: no TTL, will be updated from appConfig
|
||||
})
|
||||
return globalStore
|
||||
}
|
||||
|
||||
// SetTTL updates the TTL for the store
|
||||
// This will affect all new responses stored after this call
|
||||
func (s *ResponseStore) SetTTL(ttl time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Stop existing cleanup loop if running
|
||||
if s.cleanupCancel != nil {
|
||||
s.cleanupCancel()
|
||||
s.cleanupCancel = nil
|
||||
s.cleanupCtx = nil
|
||||
}
|
||||
|
||||
s.ttl = ttl
|
||||
|
||||
// If TTL > 0, start cleanup loop
|
||||
if ttl > 0 {
|
||||
s.cleanupCtx, s.cleanupCancel = context.WithCancel(context.Background())
|
||||
go s.cleanupLoop(s.cleanupCtx)
|
||||
}
|
||||
|
||||
xlog.Debug("Updated Open Responses store TTL", "ttl", ttl, "cleanup_running", ttl > 0)
|
||||
}
|
||||
|
||||
// NewResponseStore creates a new response store with optional TTL
|
||||
// If ttl is 0, responses are stored indefinitely
|
||||
func NewResponseStore(ttl time.Duration) *ResponseStore {
|
||||
store := &ResponseStore{
|
||||
responses: make(map[string]*StoredResponse),
|
||||
ttl: ttl,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine if TTL is set
|
||||
if ttl > 0 {
|
||||
store.cleanupCtx, store.cleanupCancel = context.WithCancel(context.Background())
|
||||
go store.cleanupLoop(store.cleanupCtx)
|
||||
}
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Store stores a response with its request and items
|
||||
func (s *ResponseStore) Store(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Build item index for quick lookup
|
||||
items := make(map[string]*schema.ORItemField)
|
||||
for i := range response.Output {
|
||||
item := &response.Output[i]
|
||||
if item.ID != "" {
|
||||
items[item.ID] = item
|
||||
}
|
||||
}
|
||||
|
||||
stored := &StoredResponse{
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
}
|
||||
|
||||
// Set expiration if TTL is configured
|
||||
if s.ttl > 0 {
|
||||
expiresAt := time.Now().Add(s.ttl)
|
||||
stored.ExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
s.responses[responseID] = stored
|
||||
xlog.Debug("Stored Open Responses response", "response_id", responseID, "items_count", len(items))
|
||||
}
|
||||
|
||||
// Get retrieves a stored response by ID
|
||||
func (s *ResponseStore) Get(responseID string) (*StoredResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
stored, exists := s.responses[responseID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if stored.ExpiresAt != nil && time.Now().After(*stored.ExpiresAt) {
|
||||
// Expired, but we'll return it anyway and let caller handle cleanup
|
||||
return nil, fmt.Errorf("response expired: %s", responseID)
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
// GetItem retrieves a specific item from a stored response
|
||||
func (s *ResponseStore) GetItem(responseID, itemID string) (*schema.ORItemField, error) {
|
||||
stored, err := s.Get(responseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
item, exists := stored.Items[itemID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("item not found: %s in response %s", itemID, responseID)
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// FindItem searches for an item across all stored responses
|
||||
// Returns the item and the response ID it was found in
|
||||
func (s *ResponseStore) FindItem(itemID string) (*schema.ORItemField, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
for responseID, stored := range s.responses {
|
||||
// Skip expired responses
|
||||
if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
if item, exists := stored.Items[itemID]; exists {
|
||||
return item, responseID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, "", fmt.Errorf("item not found in any stored response: %s", itemID)
|
||||
}
|
||||
|
||||
// Delete removes a response from storage
|
||||
func (s *ResponseStore) Delete(responseID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.responses, responseID)
|
||||
xlog.Debug("Deleted Open Responses response", "response_id", responseID)
|
||||
}
|
||||
|
||||
// Cleanup removes expired responses
|
||||
func (s *ResponseStore) Cleanup() int {
|
||||
if s.ttl == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
count := 0
|
||||
for id, stored := range s.responses {
|
||||
if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) {
|
||||
delete(s.responses, id)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
xlog.Debug("Cleaned up expired Open Responses", "count", count)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired responses
|
||||
func (s *ResponseStore) cleanupLoop(ctx context.Context) {
|
||||
if s.ttl == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(s.ttl / 2) // Cleanup at half TTL interval
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
xlog.Debug("Stopped Open Responses store cleanup loop")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.Cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of stored responses
|
||||
func (s *ResponseStore) Count() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.responses)
|
||||
}
|
||||
|
||||
// StoreBackground stores a background response with cancel function and optional streaming support
|
||||
func (s *ResponseStore) StoreBackground(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource, cancelFunc context.CancelFunc, streamEnabled bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Build item index for quick lookup
|
||||
items := make(map[string]*schema.ORItemField)
|
||||
for i := range response.Output {
|
||||
item := &response.Output[i]
|
||||
if item.ID != "" {
|
||||
items[item.ID] = item
|
||||
}
|
||||
}
|
||||
|
||||
stored := &StoredResponse{
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
CancelFunc: cancelFunc,
|
||||
StreamEvents: []StreamedEvent{},
|
||||
StreamEnabled: streamEnabled,
|
||||
IsBackground: true,
|
||||
EventsChan: make(chan struct{}, 100), // Buffered channel for event notifications
|
||||
}
|
||||
|
||||
// Set expiration if TTL is configured
|
||||
if s.ttl > 0 {
|
||||
expiresAt := time.Now().Add(s.ttl)
|
||||
stored.ExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
s.responses[responseID] = stored
|
||||
xlog.Debug("Stored background Open Responses response", "response_id", responseID, "stream_enabled", streamEnabled)
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of a stored response
|
||||
func (s *ResponseStore) UpdateStatus(responseID string, status string, completedAt *int64) error {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
stored.mu.Lock()
|
||||
defer stored.mu.Unlock()
|
||||
|
||||
stored.Response.Status = status
|
||||
stored.Response.CompletedAt = completedAt
|
||||
|
||||
xlog.Debug("Updated response status", "response_id", responseID, "status", status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateResponse updates the entire response object for a stored response
|
||||
func (s *ResponseStore) UpdateResponse(responseID string, response *schema.ORResponseResource) error {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
stored.mu.Lock()
|
||||
defer stored.mu.Unlock()
|
||||
|
||||
// Rebuild item index
|
||||
items := make(map[string]*schema.ORItemField)
|
||||
for i := range response.Output {
|
||||
item := &response.Output[i]
|
||||
if item.ID != "" {
|
||||
items[item.ID] = item
|
||||
}
|
||||
}
|
||||
|
||||
stored.Response = response
|
||||
stored.Items = items
|
||||
|
||||
xlog.Debug("Updated response", "response_id", responseID, "status", response.Status, "items_count", len(items))
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppendEvent appends a streaming event to the buffer for resume support
|
||||
func (s *ResponseStore) AppendEvent(responseID string, event *schema.ORStreamEvent) error {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
// Serialize the event
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event: %w", err)
|
||||
}
|
||||
|
||||
stored.mu.Lock()
|
||||
stored.StreamEvents = append(stored.StreamEvents, StreamedEvent{
|
||||
SequenceNumber: event.SequenceNumber,
|
||||
EventType: event.Type,
|
||||
Data: data,
|
||||
})
|
||||
stored.mu.Unlock()
|
||||
|
||||
// Notify any subscribers of new event
|
||||
select {
|
||||
case stored.EventsChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full, subscribers will catch up
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEventsAfter returns all events with sequence number greater than startingAfter
|
||||
func (s *ResponseStore) GetEventsAfter(responseID string, startingAfter int) ([]StreamedEvent, error) {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
stored.mu.RLock()
|
||||
defer stored.mu.RUnlock()
|
||||
|
||||
var result []StreamedEvent
|
||||
for _, event := range stored.StreamEvents {
|
||||
if event.SequenceNumber > startingAfter {
|
||||
result = append(result, event)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Cancel cancels a background response if it's still in progress
|
||||
func (s *ResponseStore) Cancel(responseID string) (*schema.ORResponseResource, error) {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
stored.mu.Lock()
|
||||
defer stored.mu.Unlock()
|
||||
|
||||
// If already in a terminal state, just return the response (idempotent)
|
||||
status := stored.Response.Status
|
||||
if status == schema.ORStatusCompleted || status == schema.ORStatusFailed ||
|
||||
status == schema.ORStatusIncomplete || status == schema.ORStatusCancelled {
|
||||
xlog.Debug("Response already in terminal state", "response_id", responseID, "status", status)
|
||||
return stored.Response, nil
|
||||
}
|
||||
|
||||
// Cancel the context if available
|
||||
if stored.CancelFunc != nil {
|
||||
stored.CancelFunc()
|
||||
xlog.Debug("Cancelled background response", "response_id", responseID)
|
||||
}
|
||||
|
||||
// Update status to cancelled
|
||||
now := time.Now().Unix()
|
||||
stored.Response.Status = schema.ORStatusCancelled
|
||||
stored.Response.CompletedAt = &now
|
||||
|
||||
return stored.Response, nil
|
||||
}
|
||||
|
||||
// GetEventsChan returns the events notification channel for a response
|
||||
func (s *ResponseStore) GetEventsChan(responseID string) (chan struct{}, error) {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
return stored.EventsChan, nil
|
||||
}
|
||||
|
||||
// IsStreamEnabled checks if a response was created with streaming enabled
|
||||
func (s *ResponseStore) IsStreamEnabled(responseID string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, fmt.Errorf("response not found: %s", responseID)
|
||||
}
|
||||
|
||||
stored.mu.RLock()
|
||||
defer stored.mu.RUnlock()
|
||||
|
||||
return stored.StreamEnabled, nil
|
||||
}
|
||||
13
core/http/endpoints/openresponses/store_suite_test.go
Normal file
13
core/http/endpoints/openresponses/store_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package openresponses
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "ResponseStore Suite")
|
||||
}
|
||||
626
core/http/endpoints/openresponses/store_test.go
Normal file
626
core/http/endpoints/openresponses/store_test.go
Normal file
@@ -0,0 +1,626 @@
|
||||
package openresponses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ResponseStore", func() {
|
||||
var store *ResponseStore
|
||||
|
||||
BeforeEach(func() {
|
||||
store = NewResponseStore(0) // No TTL for most tests
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up
|
||||
})
|
||||
|
||||
Describe("Store and Get", func() {
|
||||
It("should store and retrieve a response", func() {
|
||||
responseID := "resp_test123"
|
||||
request := &schema.OpenResponsesRequest{
|
||||
Model: "test-model",
|
||||
Input: "Hello",
|
||||
}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
Status: "completed",
|
||||
Model: "test-model",
|
||||
Output: []schema.ORItemField{
|
||||
{
|
||||
Type: "message",
|
||||
ID: "msg_123",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{{
|
||||
Type: "output_text",
|
||||
Text: "Hello, world!",
|
||||
Annotations: []schema.ORAnnotation{},
|
||||
Logprobs: []schema.ORLogProb{},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored).ToNot(BeNil())
|
||||
Expect(stored.Response.ID).To(Equal(responseID))
|
||||
Expect(stored.Request.Model).To(Equal("test-model"))
|
||||
Expect(len(stored.Items)).To(Equal(1))
|
||||
Expect(stored.Items["msg_123"]).ToNot(BeNil())
|
||||
Expect(stored.Items["msg_123"].ID).To(Equal("msg_123"))
|
||||
})
|
||||
|
||||
It("should return error for non-existent response", func() {
|
||||
_, err := store.Get("nonexistent")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("not found"))
|
||||
})
|
||||
|
||||
It("should index all items by ID", func() {
|
||||
responseID := "resp_test456"
|
||||
request := &schema.OpenResponsesRequest{
|
||||
Model: "test-model",
|
||||
Input: "Test",
|
||||
}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{
|
||||
Type: "message",
|
||||
ID: "msg_1",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
},
|
||||
{
|
||||
Type: "function_call",
|
||||
ID: "fc_1",
|
||||
Status: "completed",
|
||||
CallID: "fc_1",
|
||||
Name: "test_function",
|
||||
Arguments: `{"arg": "value"}`,
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
ID: "msg_2",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(stored.Items)).To(Equal(3))
|
||||
Expect(stored.Items["msg_1"]).ToNot(BeNil())
|
||||
Expect(stored.Items["fc_1"]).ToNot(BeNil())
|
||||
Expect(stored.Items["msg_2"]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("should handle items without IDs", func() {
|
||||
responseID := "resp_test789"
|
||||
request := &schema.OpenResponsesRequest{
|
||||
Model: "test-model",
|
||||
Input: "Test",
|
||||
}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{
|
||||
Type: "message",
|
||||
ID: "", // No ID
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
ID: "msg_with_id",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Only items with IDs are indexed
|
||||
Expect(len(stored.Items)).To(Equal(1))
|
||||
Expect(stored.Items["msg_with_id"]).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetItem", func() {
|
||||
It("should retrieve a specific item by ID", func() {
|
||||
responseID := "resp_item_test"
|
||||
itemID := "msg_specific"
|
||||
request := &schema.OpenResponsesRequest{
|
||||
Model: "test-model",
|
||||
Input: "Test",
|
||||
}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{
|
||||
Type: "message",
|
||||
ID: itemID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{{
|
||||
Type: "output_text",
|
||||
Text: "Specific message",
|
||||
Annotations: []schema.ORAnnotation{},
|
||||
Logprobs: []schema.ORLogProb{},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
item, err := store.GetItem(responseID, itemID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(item).ToNot(BeNil())
|
||||
Expect(item.ID).To(Equal(itemID))
|
||||
Expect(item.Type).To(Equal("message"))
|
||||
})
|
||||
|
||||
It("should return error for non-existent item", func() {
|
||||
responseID := "resp_item_test2"
|
||||
request := &schema.OpenResponsesRequest{
|
||||
Model: "test-model",
|
||||
Input: "Test",
|
||||
}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{
|
||||
Type: "message",
|
||||
ID: "msg_existing",
|
||||
Status: "completed",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
_, err := store.GetItem(responseID, "nonexistent_item")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("item not found"))
|
||||
})
|
||||
|
||||
It("should return error for non-existent response when getting item", func() {
|
||||
_, err := store.GetItem("nonexistent_response", "any_item")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("response not found"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("FindItem", func() {
|
||||
It("should find an item across all stored responses", func() {
|
||||
// Store first response
|
||||
responseID1 := "resp_find_1"
|
||||
itemID1 := "msg_find_1"
|
||||
store.Store(responseID1, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{
|
||||
ID: responseID1,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{Type: "message", ID: itemID1, Status: "completed"},
|
||||
},
|
||||
})
|
||||
|
||||
// Store second response
|
||||
responseID2 := "resp_find_2"
|
||||
itemID2 := "msg_find_2"
|
||||
store.Store(responseID2, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{
|
||||
ID: responseID2,
|
||||
Object: "response",
|
||||
Output: []schema.ORItemField{
|
||||
{Type: "message", ID: itemID2, Status: "completed"},
|
||||
},
|
||||
})
|
||||
|
||||
// Find item from first response
|
||||
item, foundResponseID, err := store.FindItem(itemID1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(item).ToNot(BeNil())
|
||||
Expect(item.ID).To(Equal(itemID1))
|
||||
Expect(foundResponseID).To(Equal(responseID1))
|
||||
|
||||
// Find item from second response
|
||||
item, foundResponseID, err = store.FindItem(itemID2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(item).ToNot(BeNil())
|
||||
Expect(item.ID).To(Equal(itemID2))
|
||||
Expect(foundResponseID).To(Equal(responseID2))
|
||||
})
|
||||
|
||||
It("should return error when item not found in any response", func() {
|
||||
_, _, err := store.FindItem("nonexistent_item")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("item not found in any stored response"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Delete", func() {
|
||||
It("should delete a stored response", func() {
|
||||
responseID := "resp_delete_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
Expect(store.Count()).To(Equal(1))
|
||||
|
||||
store.Delete(responseID)
|
||||
Expect(store.Count()).To(Equal(0))
|
||||
|
||||
_, err := store.Get(responseID)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should handle deleting non-existent response gracefully", func() {
|
||||
// Should not panic
|
||||
store.Delete("nonexistent")
|
||||
Expect(store.Count()).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("should return correct count of stored responses", func() {
|
||||
Expect(store.Count()).To(Equal(0))
|
||||
|
||||
store.Store("resp_1", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_1", Object: "response"})
|
||||
Expect(store.Count()).To(Equal(1))
|
||||
|
||||
store.Store("resp_2", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_2", Object: "response"})
|
||||
Expect(store.Count()).To(Equal(2))
|
||||
|
||||
store.Delete("resp_1")
|
||||
Expect(store.Count()).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("TTL and Expiration", func() {
|
||||
It("should set expiration when TTL is configured", func() {
|
||||
ttlStore := NewResponseStore(100 * time.Millisecond)
|
||||
responseID := "resp_ttl_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response"}
|
||||
|
||||
ttlStore.Store(responseID, request, response)
|
||||
|
||||
stored, err := ttlStore.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored.ExpiresAt).ToNot(BeNil())
|
||||
Expect(stored.ExpiresAt.After(time.Now())).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should not set expiration when TTL is 0", func() {
|
||||
responseID := "resp_no_ttl"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response"}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored.ExpiresAt).To(BeNil())
|
||||
})
|
||||
|
||||
It("should clean up expired responses", func() {
|
||||
ttlStore := NewResponseStore(50 * time.Millisecond)
|
||||
responseID := "resp_expire_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response"}
|
||||
|
||||
ttlStore.Store(responseID, request, response)
|
||||
Expect(ttlStore.Count()).To(Equal(1))
|
||||
|
||||
// Wait for expiration (longer than TTL and cleanup interval)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Cleanup should remove expired response (may have already been cleaned by goroutine)
|
||||
count := ttlStore.Cleanup()
|
||||
// Count might be 0 if cleanup goroutine already ran, or 1 if we're first
|
||||
Expect(count).To(BeNumerically(">=", 0))
|
||||
Expect(ttlStore.Count()).To(Equal(0))
|
||||
|
||||
_, err := ttlStore.Get(responseID)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should return error for expired response", func() {
|
||||
ttlStore := NewResponseStore(50 * time.Millisecond)
|
||||
responseID := "resp_expire_error"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response"}
|
||||
|
||||
ttlStore.Store(responseID, request, response)
|
||||
|
||||
// Wait for expiration (but not long enough for cleanup goroutine to remove it)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
// Try to get before cleanup goroutine removes it
|
||||
_, err := ttlStore.Get(responseID)
|
||||
// Error could be "expired" or "not found" (if cleanup already ran)
|
||||
Expect(err).To(HaveOccurred())
|
||||
// Either error message is acceptable
|
||||
errMsg := err.Error()
|
||||
Expect(errMsg).To(Or(ContainSubstring("expired"), ContainSubstring("not found")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Thread Safety", func() {
	It("should handle concurrent stores and gets", func() {
		// Exercise the store from several goroutines at once; run the suite
		// with `go test -race` for this spec to be meaningful.
		done := make(chan bool, 10)
		for i := 0; i < 10; i++ {
			go func(id int) {
				// Gomega assertions in a spawned goroutine panic on failure;
				// GinkgoRecover converts that panic into a proper spec
				// failure instead of crashing the whole test process.
				defer GinkgoRecover()
				// Signal completion even if an assertion fails, so the
				// collecting loop below can never deadlock.
				defer func() { done <- true }()

				responseID := fmt.Sprintf("resp_concurrent_%d", id)
				request := &schema.OpenResponsesRequest{Model: "test"}
				response := &schema.ORResponseResource{
					ID:     responseID,
					Object: "response",
					Output: []schema.ORItemField{
						{Type: "message", ID: fmt.Sprintf("msg_%d", id), Status: "completed"},
					},
				}
				store.Store(responseID, request, response)

				// Retrieve immediately to interleave reads with writes.
				stored, err := store.Get(responseID)
				Expect(err).ToNot(HaveOccurred())
				Expect(stored).ToNot(BeNil())
			}(i)
		}

		// Wait for all goroutines before checking the aggregate count.
		for i := 0; i < 10; i++ {
			<-done
		}

		Expect(store.Count()).To(Equal(10))
	})
})
|
||||
|
||||
Describe("GetGlobalStore", func() {
|
||||
It("should return singleton instance", func() {
|
||||
store1 := GetGlobalStore()
|
||||
store2 := GetGlobalStore()
|
||||
Expect(store1).To(Equal(store2))
|
||||
})
|
||||
|
||||
It("should persist data across GetGlobalStore calls", func() {
|
||||
globalStore := GetGlobalStore()
|
||||
responseID := "resp_global_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response"}
|
||||
|
||||
globalStore.Store(responseID, request, response)
|
||||
|
||||
// Get store again
|
||||
globalStore2 := GetGlobalStore()
|
||||
stored, err := globalStore2.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Background Mode Support", func() {
|
||||
It("should store background response with cancel function", func() {
|
||||
responseID := "resp_bg_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusQueued,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, true)
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored).ToNot(BeNil())
|
||||
Expect(stored.IsBackground).To(BeTrue())
|
||||
Expect(stored.StreamEnabled).To(BeTrue())
|
||||
Expect(stored.CancelFunc).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("should update status of stored response", func() {
|
||||
responseID := "resp_status_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusQueued,
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
err := store.UpdateStatus(responseID, schema.ORStatusInProgress, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored.Response.Status).To(Equal(schema.ORStatusInProgress))
|
||||
})
|
||||
|
||||
It("should append and retrieve streaming events", func() {
|
||||
responseID := "resp_events_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusInProgress,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, true)
|
||||
|
||||
// Append events
|
||||
event1 := &schema.ORStreamEvent{
|
||||
Type: "response.created",
|
||||
SequenceNumber: 0,
|
||||
}
|
||||
event2 := &schema.ORStreamEvent{
|
||||
Type: "response.in_progress",
|
||||
SequenceNumber: 1,
|
||||
}
|
||||
event3 := &schema.ORStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
SequenceNumber: 2,
|
||||
}
|
||||
|
||||
err := store.AppendEvent(responseID, event1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = store.AppendEvent(responseID, event2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = store.AppendEvent(responseID, event3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Get all events after -1 (all events)
|
||||
events, err := store.GetEventsAfter(responseID, -1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(events).To(HaveLen(3))
|
||||
|
||||
// Get events after sequence 1
|
||||
events, err = store.GetEventsAfter(responseID, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].SequenceNumber).To(Equal(2))
|
||||
})
|
||||
|
||||
It("should cancel an in-progress response", func() {
|
||||
responseID := "resp_cancel_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusInProgress,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, false)
|
||||
|
||||
// Cancel the response
|
||||
cancelledResponse, err := store.Cancel(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCancelled))
|
||||
Expect(cancelledResponse.CompletedAt).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("should be idempotent when cancelling already completed response", func() {
|
||||
responseID := "resp_idempotent_cancel"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
completedAt := time.Now().Unix()
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusCompleted,
|
||||
CompletedAt: &completedAt,
|
||||
}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
|
||||
// Try to cancel a completed response
|
||||
cancelledResponse, err := store.Cancel(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Status should remain completed (not changed to cancelled)
|
||||
Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCompleted))
|
||||
})
|
||||
|
||||
It("should check if streaming is enabled", func() {
|
||||
responseID := "resp_stream_check"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusQueued,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, true)
|
||||
|
||||
enabled, err := store.IsStreamEnabled(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(enabled).To(BeTrue())
|
||||
|
||||
// Store another without streaming
|
||||
responseID2 := "resp_no_stream"
|
||||
store.StoreBackground(responseID2, request, response, cancel, false)
|
||||
|
||||
enabled2, err := store.IsStreamEnabled(responseID2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(enabled2).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should notify subscribers of new events", func() {
|
||||
responseID := "resp_events_chan"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusInProgress,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, true)
|
||||
|
||||
eventsChan, err := store.GetEventsChan(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(eventsChan).ToNot(BeNil())
|
||||
|
||||
// Append an event
|
||||
event := &schema.ORStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
SequenceNumber: 0,
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
store.AppendEvent(responseID, event)
|
||||
}()
|
||||
|
||||
// Wait for notification
|
||||
select {
|
||||
case <-eventsChan:
|
||||
// Event received
|
||||
case <-time.After(1 * time.Second):
|
||||
Fail("Timeout waiting for event notification")
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,13 +1,33 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
tmpdir string
|
||||
modelDir string
|
||||
)
|
||||
|
||||
func TestLocalAI(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
modelDir = filepath.Join(tmpdir, "models")
|
||||
err = os.Mkdir(modelDir, 0750)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
AfterSuite(func() {
|
||||
err := os.RemoveAll(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
RunSpecs(t, "LocalAI HTTP test suite")
|
||||
}
|
||||
|
||||
@@ -484,3 +484,103 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error {
|
||||
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID (Open Responses uses x-request-id)
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
if correlationID == "" {
|
||||
correlationID = uuid.New().String()
|
||||
}
|
||||
c.Response().Header().Set("x-request-id", correlationID)
|
||||
|
||||
// Use the request context directly - Echo properly supports context cancellation!
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
|
||||
// Cancel when request context is cancelled (client disconnects)
|
||||
go func() {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
cancel()
|
||||
case <-c1.Done():
|
||||
// Already cancelled
|
||||
}
|
||||
}()
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenResponsesRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeOpenResponsesRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error {
|
||||
// Temperature
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
// TopP
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
// MaxOutputTokens -> Maxtokens
|
||||
if input.MaxOutputTokens != nil {
|
||||
config.Maxtokens = input.MaxOutputTokens
|
||||
}
|
||||
|
||||
// Convert tools to functions - this will be handled in the endpoint handler
|
||||
// We just validate that tools are present if needed
|
||||
|
||||
// Handle tool_choice
|
||||
if input.ToolChoice != nil {
|
||||
switch tc := input.ToolChoice.(type) {
|
||||
case string:
|
||||
// "auto", "required", or "none"
|
||||
if tc == "required" {
|
||||
config.SetFunctionCallString("required")
|
||||
} else if tc == "none" {
|
||||
// Don't use tools - handled in endpoint
|
||||
}
|
||||
// "auto" is default - let model decide
|
||||
case map[string]interface{}:
|
||||
// Specific tool: {type:"function", name:"..."}
|
||||
if tcType, ok := tc["type"].(string); ok && tcType == "function" {
|
||||
if name, ok := tc["name"].(string); ok {
|
||||
config.SetFunctionCallString(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid, _ := config.Validate(); valid {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type APIExchange struct {
|
||||
var traceBuffer *circularbuffer.Queue[APIExchange]
|
||||
var mu sync.Mutex
|
||||
var logChan = make(chan APIExchange, 100)
|
||||
var initOnce sync.Once
|
||||
|
||||
type bodyWriter struct {
|
||||
http.ResponseWriter
|
||||
@@ -53,26 +54,37 @@ func (w *bodyWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// TraceMiddleware intercepts and logs JSON API requests and responses
|
||||
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
if app.ApplicationConfig().EnableTracing && traceBuffer == nil {
|
||||
traceBuffer = circularbuffer.New[APIExchange](app.ApplicationConfig().TracingMaxItems)
|
||||
func initializeTracing(maxItems int) {
|
||||
initOnce.Do(func() {
|
||||
if maxItems <= 0 {
|
||||
maxItems = 100
|
||||
}
|
||||
mu.Lock()
|
||||
traceBuffer = circularbuffer.New[APIExchange](maxItems)
|
||||
mu.Unlock()
|
||||
|
||||
go func() {
|
||||
for exchange := range logChan {
|
||||
mu.Lock()
|
||||
traceBuffer.Enqueue(exchange)
|
||||
if traceBuffer != nil {
|
||||
traceBuffer.Enqueue(exchange)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TraceMiddleware intercepts and logs JSON API requests and responses
|
||||
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if !app.ApplicationConfig().EnableTracing {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
initializeTracing(app.ApplicationConfig().TracingMaxItems)
|
||||
|
||||
if c.Request().Header.Get("Content-Type") != "application/json" {
|
||||
return next(c)
|
||||
}
|
||||
@@ -138,6 +150,10 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
// GetTraces returns a copy of the logged API exchanges for display
|
||||
func GetTraces() []APIExchange {
|
||||
mu.Lock()
|
||||
if traceBuffer == nil {
|
||||
mu.Unlock()
|
||||
return []APIExchange{}
|
||||
}
|
||||
traces := traceBuffer.Values()
|
||||
mu.Unlock()
|
||||
|
||||
@@ -151,6 +167,8 @@ func GetTraces() []APIExchange {
|
||||
// ClearTraces clears the in-memory logs
|
||||
func ClearTraces() {
|
||||
mu.Lock()
|
||||
traceBuffer.Clear()
|
||||
if traceBuffer != nil {
|
||||
traceBuffer.Clear()
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
openai "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MapOpenAIToVideo", func() {
|
||||
It("maps size and seconds correctly", func() {
|
||||
cases := []struct {
|
||||
name string
|
||||
input *schema.OpenAIRequest
|
||||
raw map[string]interface{}
|
||||
expectsW int32
|
||||
expectsH int32
|
||||
expectsF int32
|
||||
expectsN int32
|
||||
}{
|
||||
{
|
||||
name: "size in input",
|
||||
input: &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: "m"},
|
||||
},
|
||||
Size: "256x128",
|
||||
},
|
||||
expectsW: 256,
|
||||
expectsH: 128,
|
||||
},
|
||||
{
|
||||
name: "size in raw and seconds as string",
|
||||
input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
|
||||
raw: map[string]interface{}{"size": "720x480", "seconds": "2"},
|
||||
expectsW: 720,
|
||||
expectsH: 480,
|
||||
expectsF: 30,
|
||||
expectsN: 60,
|
||||
},
|
||||
{
|
||||
name: "seconds as number and fps override",
|
||||
input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
|
||||
raw: map[string]interface{}{"seconds": 3.0, "fps": 24.0},
|
||||
expectsF: 24,
|
||||
expectsN: 72,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
By(c.name)
|
||||
vr := openai.MapOpenAIToVideo(c.input, c.raw)
|
||||
if c.expectsW != 0 {
|
||||
Expect(vr.Width).To(Equal(c.expectsW))
|
||||
}
|
||||
if c.expectsH != 0 {
|
||||
Expect(vr.Height).To(Equal(c.expectsH))
|
||||
}
|
||||
if c.expectsF != 0 {
|
||||
Expect(vr.FPS).To(Equal(c.expectsF))
|
||||
}
|
||||
if c.expectsN != 0 {
|
||||
Expect(vr.NumFrames).To(Equal(c.expectsN))
|
||||
}
|
||||
|
||||
b, err := json.Marshal(vr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_ = b
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"fmt"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const testAPIKey = "joshua"
|
||||
|
||||
type fakeAI struct{}
|
||||
|
||||
func (f *fakeAI) Busy() bool { return false }
|
||||
func (f *fakeAI) Lock() {}
|
||||
func (f *fakeAI) Unlock() {}
|
||||
func (f *fakeAI) Locking() bool { return false }
|
||||
func (f *fakeAI) Predict(*pb.PredictOptions) (string, error) { return "", nil }
|
||||
func (f *fakeAI) PredictStream(*pb.PredictOptions, chan string) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeAI) Load(*pb.ModelOptions) error { return nil }
|
||||
func (f *fakeAI) Embeddings(*pb.PredictOptions) ([]float32, error) { return nil, nil }
|
||||
func (f *fakeAI) GenerateImage(*pb.GenerateImageRequest) error { return nil }
|
||||
func (f *fakeAI) GenerateVideo(*pb.GenerateVideoRequest) error { return nil }
|
||||
func (f *fakeAI) Detect(*pb.DetectOptions) (pb.DetectResponse, error) { return pb.DetectResponse{}, nil }
|
||||
func (f *fakeAI) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
return pb.TranscriptResult{}, nil
|
||||
}
|
||||
func (f *fakeAI) TTS(*pb.TTSRequest) error { return nil }
|
||||
func (f *fakeAI) SoundGeneration(*pb.SoundGenerationRequest) error { return nil }
|
||||
func (f *fakeAI) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
return pb.TokenizationResponse{}, nil
|
||||
}
|
||||
func (f *fakeAI) Status() (pb.StatusResponse, error) { return pb.StatusResponse{}, nil }
|
||||
func (f *fakeAI) StoresSet(*pb.StoresSetOptions) error { return nil }
|
||||
func (f *fakeAI) StoresDelete(*pb.StoresDeleteOptions) error { return nil }
|
||||
func (f *fakeAI) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
func (f *fakeAI) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResponse{}, nil }
|
||||
|
||||
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
var tmpdir string
|
||||
var appServer *application.Application
|
||||
var app *echo.Echo
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelDir := filepath.Join(tmpdir, "models")
|
||||
err = os.Mkdir(modelDir, 0750)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
grpc.Provide("embedded://fake", &fakeAI{})
|
||||
|
||||
appServer, err = application.New(
|
||||
config.WithContext(ctx),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithApiKeys([]string{testAPIKey}),
|
||||
config.WithGeneratedContentDir(tmpdir),
|
||||
config.WithExternalBackend("fake", "embedded://fake"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = app.Shutdown(ctx)
|
||||
}
|
||||
_ = os.RemoveAll(tmpdir)
|
||||
})
|
||||
|
||||
It("accepts OpenAI-style video create and delegates to backend", func() {
|
||||
var err error
|
||||
app, err = API(appServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
|
||||
// Log error if needed
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
Eventually(func() error {
|
||||
req, _ := http.NewRequest("GET", "http://127.0.0.1:9091/v1/models", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+testAPIKey)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("bad status: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}, "30s", "500ms").Should(Succeed())
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": "fake-model",
|
||||
"backend": "fake",
|
||||
"prompt": "a test video",
|
||||
"size": "256x256",
|
||||
"seconds": "1",
|
||||
}
|
||||
payload, err := json.Marshal(body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9091/v1/videos", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+testAPIKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var out map[string]interface{}
|
||||
err = json.Unmarshal(dat, &out)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, ok := out["data"].([]interface{})
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(len(data)).To(BeNumerically(">", 0))
|
||||
first := data[0].(map[string]interface{})
|
||||
url, ok := first["url"].(string)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(url).To(ContainSubstring("/generated-videos/"))
|
||||
Expect(url).To(ContainSubstring(".mp4"))
|
||||
})
|
||||
})
|
||||
1027
core/http/openresponses_test.go
Normal file
1027
core/http/openresponses_test.go
Normal file
File diff suppressed because it is too large
Load Diff
108
core/http/routes/anthropic.go
Normal file
108
core/http/routes/anthropic.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/anthropic"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
|
||||
// Anthropic Messages API endpoint
|
||||
messagesHandler := anthropic.MessagesEndpoint(
|
||||
application.ModelConfigLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
|
||||
messagesMiddleware := []echo.MiddlewareFunc{
|
||||
middleware.TraceMiddleware(application),
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.AnthropicRequest) }),
|
||||
setAnthropicRequestContext(application.ApplicationConfig()),
|
||||
}
|
||||
|
||||
// Main Anthropic endpoint
|
||||
app.POST("/v1/messages", messagesHandler, messagesMiddleware...)
|
||||
|
||||
// Also support without version prefix for compatibility
|
||||
app.POST("/messages", messagesHandler, messagesMiddleware...)
|
||||
}
|
||||
|
||||
// setAnthropicRequestContext sets up the context and cancel function for Anthropic requests
|
||||
func setAnthropicRequestContext(appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "model configuration not found")
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
// Anthropic uses x-request-id header
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
if correlationID == "" {
|
||||
correlationID = uuid.New().String()
|
||||
}
|
||||
c.Response().Header().Set("x-request-id", correlationID)
|
||||
|
||||
// Set up context with cancellation
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(appConfig.Context)
|
||||
|
||||
// Cancel when request context is cancelled (client disconnects)
|
||||
go func() {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
cancel()
|
||||
case <-c1.Done():
|
||||
// Already cancelled
|
||||
}
|
||||
}()
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, middleware.CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
if cfg.Model == "" {
|
||||
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
// Log the Anthropic API version if provided
|
||||
anthropicVersion := c.Request().Header.Get("anthropic-version")
|
||||
if anthropicVersion != "" {
|
||||
xlog.Debug("Anthropic API version", "version", anthropicVersion)
|
||||
}
|
||||
|
||||
// Validate max_tokens is provided
|
||||
if input.MaxTokens <= 0 {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("max_tokens is required and must be greater than 0"))
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user