From 923ebbb3440dcd105e13097a819f7a77193ab06f Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 23 Jan 2026 15:18:41 +0100 Subject: [PATCH] feat(qwen-tts): add Qwen-tts backend (#8163) * feat(qwen-tts): add Qwen-tts backend Signed-off-by: Ettore Di Giacinto * Update intel deps Signed-off-by: Ettore Di Giacinto * Drop flash-attn for cuda13 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 91 ++++ .github/workflows/test-extra.yml | 21 +- Dockerfile | 2 +- Makefile | 8 +- README.md | 5 +- backend/index.yaml | 105 ++++ backend/python/bark/requirements-intel.txt | 8 +- .../python/chatterbox/requirements-intel.txt | 7 +- backend/python/common/libbackend.sh | 2 +- .../common/template/requirements-intel.txt | 3 +- backend/python/coqui/requirements-intel.txt | 8 +- .../python/diffusers/requirements-intel.txt | 8 +- .../faster-whisper/requirements-intel.txt | 6 +- backend/python/kokoro/requirements-intel.txt | 8 +- .../python/pocket-tts/requirements-intel.txt | 4 +- backend/python/qwen-tts/Makefile | 23 + backend/python/qwen-tts/backend.py | 475 ++++++++++++++++++ backend/python/qwen-tts/install.sh | 13 + backend/python/qwen-tts/requirements-cpu.txt | 5 + .../qwen-tts/requirements-cublas12-after.txt | 1 + .../python/qwen-tts/requirements-cublas12.txt | 5 + .../python/qwen-tts/requirements-cublas13.txt | 5 + .../python/qwen-tts/requirements-hipblas.txt | 5 + .../qwen-tts/requirements-intel-after.txt | 1 + .../python/qwen-tts/requirements-intel.txt | 5 + .../python/qwen-tts/requirements-l4t12.txt | 5 + .../python/qwen-tts/requirements-l4t13.txt | 5 + backend/python/qwen-tts/requirements-mps.txt | 4 + backend/python/qwen-tts/requirements.txt | 6 + backend/python/qwen-tts/run.sh | 9 + backend/python/qwen-tts/test.py | 98 ++++ backend/python/qwen-tts/test.sh | 11 + .../python/rerankers/requirements-intel.txt | 6 +- backend/python/rfdetr/requirements-intel.txt | 8 +- .../transformers/requirements-intel.txt | 7 +- .../python/vibevoice/requirements-intel.txt | 8 +- backend/python/vllm/requirements-intel.txt | 7 +- docs/content/features/text-to-audio.md | 82 ++- 38 files changed, 996 insertions(+), 84 deletions(-) create mode 100644 backend/python/qwen-tts/Makefile create mode 100644 backend/python/qwen-tts/backend.py create mode 100755 backend/python/qwen-tts/install.sh create mode 100644 backend/python/qwen-tts/requirements-cpu.txt create mode 100644 backend/python/qwen-tts/requirements-cublas12-after.txt create mode 100644 backend/python/qwen-tts/requirements-cublas12.txt create mode 100644 backend/python/qwen-tts/requirements-cublas13.txt create mode 100644 backend/python/qwen-tts/requirements-hipblas.txt create mode 100644 backend/python/qwen-tts/requirements-intel-after.txt create mode 100644 backend/python/qwen-tts/requirements-intel.txt create mode 100644 backend/python/qwen-tts/requirements-l4t12.txt create mode 100644 backend/python/qwen-tts/requirements-l4t13.txt create mode 100644 backend/python/qwen-tts/requirements-mps.txt create mode 100644 backend/python/qwen-tts/requirements.txt create mode 100755 backend/python/qwen-tts/run.sh create mode 100644 backend/python/qwen-tts/test.py create mode 100755 backend/python/qwen-tts/test.sh diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 413d0476b..235b94a19 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -105,6 +105,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - 
build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "9" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-qwen-tts' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "9" @@ -353,6 +366,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-qwen-tts' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -431,6 +457,19 @@ jobs: backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" + - build-type: 'l4t' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-cuda-13-arm64-qwen-tts' + runs-on: 'ubuntu-24.04-arm' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + ubuntu-version: '2404' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" @@ -680,6 +719,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'hipblas' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-rocm-hipblas-qwen-tts' + runs-on: 'arc-runner-set' + base-image: "rocm/dev-ubuntu-24.04:6.4.4" + skip-drivers: 'false' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" @@ -824,6 +876,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' + - build-type: 'l4t' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-qwen-tts' + runs-on: 'ubuntu-24.04-arm' + base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" + skip-drivers: 'true' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" @@ -890,6 +955,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'intel' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-intel-qwen-tts' + runs-on: 'arc-runner-set' + base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" + skip-drivers: 'false' + backend: "qwen-tts" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" @@ -1343,6 +1421,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-qwen-tts' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "qwen-tts" + dockerfile: 
"./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index 0d01cde73..9ef7039e6 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -284,4 +284,23 @@ jobs: - name: Test pocket-tts run: | make --jobs=5 --output-sync=target -C backend/python/pocket-tts - make --jobs=5 --output-sync=target -C backend/python/pocket-tts test \ No newline at end of file + make --jobs=5 --output-sync=target -C backend/python/pocket-tts test + tests-qwen-tts: + runs-on: ubuntu-latest + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install build-essential ffmpeg + sudo apt-get install -y ca-certificates cmake curl patch python3-pip + # Install UV + curl -LsSf https://astral.sh/uv/install.sh | sh + pip install --user --no-cache-dir grpcio-tools==1.64.1 + - name: Test qwen-tts + run: | + make --jobs=5 --output-sync=target -C backend/python/qwen-tts + make --jobs=5 --output-sync=target -C backend/python/qwen-tts test \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 29570be8b..93856d329 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates curl wget espeak-ng libgomp1 \ - ffmpeg libopenblas0 libopenblas-dev && \ + ffmpeg libopenblas0 libopenblas-dev sox && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/Makefile b/Makefile index 9bc95063e..69f0e37a5 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts backends/qwen-tts GOCMD=go GOTEST=$(GOCMD) test @@ -317,6 +317,7 @@ prepare-test-extra: protogen-python $(MAKE) -C backend/python/vibevoice $(MAKE) -C backend/python/moonshine $(MAKE) -C backend/python/pocket-tts + $(MAKE) -C backend/python/qwen-tts test-extra: prepare-test-extra $(MAKE) -C backend/python/transformers test @@ -326,6 +327,7 @@ test-extra: prepare-test-extra $(MAKE) -C backend/python/vibevoice test $(MAKE) -C backend/python/moonshine test $(MAKE) -C backend/python/pocket-tts test + $(MAKE) -C backend/python/qwen-tts test DOCKER_IMAGE?=local-ai DOCKER_AIO_IMAGE?=local-ai-aio @@ -459,6 +461,7 @@ BACKEND_CHATTERBOX = chatterbox|python|.|false|true BACKEND_VIBEVOICE = 
vibevoice|python|.|--progress=plain|true BACKEND_MOONSHINE = moonshine|python|.|false|true BACKEND_POCKET_TTS = pocket-tts|python|.|false|true +BACKEND_QWEN_TTS = qwen-tts|python|.|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) @@ -505,12 +508,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX))) $(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE))) $(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE))) $(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS))) +$(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-exllama2 docker-build-moonshine docker-build-pocket-tts +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-exllama2 docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts ######################################################## ### END Backends diff --git a/README.md b/README.md index 743f9f15c..14a23fda2 100644 --- a/README.md +++ b/README.md @@ -298,6 +298,7 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration | **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU | | **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU | | **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU | +| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU | ### Image & Video Generation | Backend | Description | Acceleration Support | @@ -319,8 +320,8 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration |-------------------|-------------------|------------------| | **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware | | **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware | -| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts, vibevoice, pocket-tts | AMD Graphics | -| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark, vibevoice, pocket-tts | Intel Arc, Intel iGPUs | +| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts, vibevoice, pocket-tts, qwen-tts | AMD Graphics | +| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark, vibevoice, pocket-tts, qwen-tts | Intel Arc, Intel iGPUs | | **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ | | **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs | | **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (AGX Orin, etc.) 
| diff --git a/backend/index.yaml b/backend/index.yaml index 916d070ab..e888a98c9 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -428,6 +428,28 @@ nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice" icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4 +- &qwen-tts + urls: + - https://github.com/QwenLM/Qwen3-TTS + description: | + Qwen3-TTS is a high-quality text-to-speech model supporting custom voice, voice design, and voice cloning. + tags: + - text-to-speech + - TTS + license: apache-2.0 + name: "qwen-tts" + alias: "qwen-tts" + capabilities: + nvidia: "cuda12-qwen-tts" + intel: "intel-qwen-tts" + amd: "rocm-qwen-tts" + nvidia-l4t: "nvidia-l4t-qwen-tts" + default: "cpu-qwen-tts" + nvidia-cuda-13: "cuda13-qwen-tts" + nvidia-cuda-12: "cuda12-qwen-tts" + nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts" + icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4 - &pocket-tts urls: - https://github.com/kyutai-labs/pocket-tts @@ -1613,6 +1635,89 @@ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice +## qwen-tts +- !!merge <<: *qwen-tts + name: "qwen-tts-development" + capabilities: + nvidia: "cuda12-qwen-tts-development" + intel: "intel-qwen-tts-development" + amd: "rocm-qwen-tts-development" + nvidia-l4t: "nvidia-l4t-qwen-tts-development" + default: "cpu-qwen-tts-development" + nvidia-cuda-13: "cuda13-qwen-tts-development" + nvidia-cuda-12: "cuda12-qwen-tts-development" + nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts-development" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts-development" +- !!merge <<: *qwen-tts + name: "cpu-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-qwen-tts" + mirrors: + - localai/localai-backends:latest-cpu-qwen-tts +- !!merge <<: *qwen-tts + name: "cpu-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-qwen-tts" + mirrors: + - localai/localai-backends:master-cpu-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda12-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-qwen-tts" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-12-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda12-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-qwen-tts" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda13-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-qwen-tts" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-13-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda13-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-qwen-tts" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-13-qwen-tts +- !!merge <<: *qwen-tts + name: "intel-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-qwen-tts" + mirrors: + - localai/localai-backends:latest-gpu-intel-qwen-tts +- !!merge <<: *qwen-tts + name: "intel-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-qwen-tts" + mirrors: + - localai/localai-backends:master-gpu-intel-qwen-tts +- !!merge <<: *qwen-tts + name: "rocm-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-qwen-tts" + mirrors: + - localai/localai-backends:latest-gpu-rocm-hipblas-qwen-tts +- 
!!merge <<: *qwen-tts + name: "rocm-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-qwen-tts" + mirrors: + - localai/localai-backends:master-gpu-rocm-hipblas-qwen-tts +- !!merge <<: *qwen-tts + name: "nvidia-l4t-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-qwen-tts" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-qwen-tts +- !!merge <<: *qwen-tts + name: "nvidia-l4t-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-qwen-tts" + mirrors: + - localai/localai-backends:master-nvidia-l4t-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda13-nvidia-l4t-arm64-qwen-tts" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts +- !!merge <<: *qwen-tts + name: "cuda13-nvidia-l4t-arm64-qwen-tts-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts" + mirrors: + - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts ## pocket-tts - !!merge <<: *pocket-tts name: "pocket-tts-development" diff --git a/backend/python/bark/requirements-intel.txt b/backend/python/bark/requirements-intel.txt index ee3c20240..06cf55c5a 100644 --- a/backend/python/bark/requirements-intel.txt +++ b/backend/python/bark/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu -torch==2.3.1+cxx11.abi -torchaudio==2.3.1+cxx11.abi -oneccl_bind_pt==2.3.100+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchaudio optimum[openvino] setuptools transformers diff --git a/backend/python/chatterbox/requirements-intel.txt b/backend/python/chatterbox/requirements-intel.txt index cb88cbc27..8aef8f6cd 100644 --- a/backend/python/chatterbox/requirements-intel.txt +++ b/backend/python/chatterbox/requirements-intel.txt @@ -1,7 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.3.1+cxx11.abi -torchaudio==2.3.1+cxx11.abi +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchaudio transformers numpy>=1.24.0,<1.26.0 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 diff --git a/backend/python/common/libbackend.sh b/backend/python/common/libbackend.sh index 7956b3c10..c923c12cf 100644 --- a/backend/python/common/libbackend.sh +++ b/backend/python/common/libbackend.sh @@ -398,7 +398,7 @@ function runProtogen() { # NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index. # you may want to add the following line to a requirements-intel.txt if you use one: # -# --index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +# --index-url https://download.pytorch.org/whl/xpu # # If you need to add extra flags into the pip install command you can do so by setting the variable EXTRA_PIP_INSTALL_FLAGS # before calling installRequirements. 
For example: diff --git a/backend/python/common/template/requirements-intel.txt b/backend/python/common/template/requirements-intel.txt index 53393f6a2..bcc9c9095 100644 --- a/backend/python/common/template/requirements-intel.txt +++ b/backend/python/common/template/requirements-intel.txt @@ -1,5 +1,4 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu +--extra-index-url https://download.pytorch.org/whl/xpu torch==2.8.0 oneccl_bind_pt==2.8.0+xpu optimum[openvino] \ No newline at end of file diff --git a/backend/python/coqui/requirements-intel.txt b/backend/python/coqui/requirements-intel.txt index c45ce1660..3a854b875 100644 --- a/backend/python/coqui/requirements-intel.txt +++ b/backend/python/coqui/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.3.1+cxx11.abi -torchaudio==2.3.1+cxx11.abi -oneccl_bind_pt==2.3.100+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch==2.8.0+xpu +torchaudio==2.8.0+xpu optimum[openvino] setuptools transformers==4.48.3 diff --git a/backend/python/diffusers/requirements-intel.txt b/backend/python/diffusers/requirements-intel.txt index fec4d9df7..e0fa69fb0 100644 --- a/backend/python/diffusers/requirements-intel.txt +++ b/backend/python/diffusers/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.5.1+cxx11.abi -torchvision==0.20.1+cxx11.abi -oneccl_bind_pt==2.8.0+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchvision optimum[openvino] setuptools git+https://github.com/huggingface/diffusers diff --git a/backend/python/faster-whisper/requirements-intel.txt b/backend/python/faster-whisper/requirements-intel.txt index 417aa0b47..aa4e120d4 100644 --- a/backend/python/faster-whisper/requirements-intel.txt +++ b/backend/python/faster-whisper/requirements-intel.txt @@ -1,6 +1,4 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.3.1+cxx11.abi -oneccl_bind_pt==2.3.100+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch optimum[openvino] faster-whisper \ No newline at end of file diff --git a/backend/python/kokoro/requirements-intel.txt b/backend/python/kokoro/requirements-intel.txt index c497efd83..54a596e45 100644 --- a/backend/python/kokoro/requirements-intel.txt +++ b/backend/python/kokoro/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu -torch==2.5.1+cxx11.abi -oneccl_bind_pt==2.8.0+xpu -torchaudio==2.5.1+cxx11.abi +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchaudio optimum[openvino] setuptools transformers==4.48.3 diff --git a/backend/python/pocket-tts/requirements-intel.txt b/backend/python/pocket-tts/requirements-intel.txt index 3bb61cb73..8abae6f85 100644 --- a/backend/python/pocket-tts/requirements-intel.txt +++ b/backend/python/pocket-tts/requirements-intel.txt @@ -1,4 +1,4 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +--extra-index-url https://download.pytorch.org/whl/xpu pocket-tts scipy -torch==2.5.1+cxx11.abi +torch \ No newline at end of file diff --git a/backend/python/qwen-tts/Makefile b/backend/python/qwen-tts/Makefile new file 
mode 100644 index 000000000..e2c941b80 --- /dev/null +++ b/backend/python/qwen-tts/Makefile @@ -0,0 +1,23 @@ +.PHONY: qwen-tts +qwen-tts: + bash install.sh + +.PHONY: run +run: qwen-tts + @echo "Running qwen-tts..." + bash run.sh + @echo "qwen-tts run." + +.PHONY: test +test: qwen-tts + @echo "Testing qwen-tts..." + bash test.sh + @echo "qwen-tts tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/qwen-tts/backend.py b/backend/python/qwen-tts/backend.py new file mode 100644 index 000000000..57fbc222b --- /dev/null +++ b/backend/python/qwen-tts/backend.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +""" +This is an extra gRPC server of LocalAI for Qwen3-TTS +""" +from concurrent import futures +import time +import argparse +import signal +import sys +import os +import copy +import traceback +from pathlib import Path +import backend_pb2 +import backend_pb2_grpc +import torch +import soundfile as sf +from qwen_tts import Qwen3TTSModel + +import grpc + +def is_float(s): + """Check if a string can be converted to float.""" + try: + float(s) + return True + except ValueError: + return False + +def is_int(s): + """Check if a string can be converted to int.""" + try: + int(s) + return True + except ValueError: + return False + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + BackendServicer is the class that implements the gRPC service + """ + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + # Get device + if torch.cuda.is_available(): + print("CUDA is available", file=sys.stderr) + device = "cuda" + else: + print("CUDA is not available", file=sys.stderr) + device = "cpu" + mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + if mps_available: + device = "mps" + if not torch.cuda.is_available() and request.CUDA: + return backend_pb2.Result(success=False, message="CUDA is not available") + + # Normalize potential 'mpx' typo to 'mps' + if device == "mpx": + print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr) + device = "mps" + + # Validate mps availability if requested + if device == "mps" and not torch.backends.mps.is_available(): + print("Warning: MPS not available. 
Falling back to CPU.", file=sys.stderr) + device = "cpu" + + self.device = device + self._torch_device = torch.device(device) + + options = request.Options + + # empty dict + self.options = {} + + # The options are a list of strings in this form optname:optvalue + # We are storing all the options in a dict so we can use it later when + # generating the audio + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) # Split only on first colon + # if value is a number, convert it to the appropriate type + if is_float(value): + value = float(value) + elif is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + self.options[key] = value + + # Get model path from request + model_path = request.Model + if not model_path: + model_path = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" + + # Determine model type from model path or options + self.model_type = self.options.get("model_type", None) + if not self.model_type: + if "CustomVoice" in model_path: + self.model_type = "CustomVoice" + elif "VoiceDesign" in model_path: + self.model_type = "VoiceDesign" + elif "Base" in model_path or "0.6B" in model_path or "1.7B" in model_path: + self.model_type = "Base" # VoiceClone model + else: + # Default to CustomVoice + self.model_type = "CustomVoice" + + # Cache for voice clone prompts + self._voice_clone_cache = {} + + # Store AudioPath, ModelFile, and ModelPath from LoadModel request + # These are used later in TTS for VoiceClone mode + self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None + self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None + self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None + + # Decide dtype & attention implementation + if self.device == "mps": + load_dtype = torch.float32 # MPS requires float32 + device_map = None + attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS + elif self.device == "cuda": + load_dtype = torch.bfloat16 + device_map = "cuda" + attn_impl_primary = "flash_attention_2" + else: # cpu + load_dtype = torch.float32 + device_map = "cpu" + attn_impl_primary = "sdpa" + + print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}", file=sys.stderr) + print(f"Loading model from: {model_path}", file=sys.stderr) + + # Load model with device-specific logic + # Common parameters for all devices + load_kwargs = { + "dtype": load_dtype, + "attn_implementation": attn_impl_primary, + "trust_remote_code": True, # Required for qwen-tts models + } + + try: + if self.device == "mps": + load_kwargs["device_map"] = None # load then move + self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) + self.model.to("mps") + elif self.device == "cuda": + load_kwargs["device_map"] = device_map + self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) + else: # cpu + load_kwargs["device_map"] = device_map + self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) + except Exception as e: + error_msg = str(e) + print(f"[ERROR] Loading model: {type(e).__name__}: {error_msg}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + # Check if it's a missing feature extractor/tokenizer error + if "speech_tokenizer" in error_msg or "preprocessor_config.json" in error_msg or "feature extractor" in error_msg.lower(): + 
print("\n[ERROR] Model files appear to be incomplete. This usually means:", file=sys.stderr) + print(" 1. The model download was interrupted or incomplete", file=sys.stderr) + print(" 2. The model cache is corrupted", file=sys.stderr) + print("\nTo fix this, try:", file=sys.stderr) + print(f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*", file=sys.stderr) + print(" Then re-run to trigger a fresh download.", file=sys.stderr) + print("\nAlternatively, try using a different model variant:", file=sys.stderr) + print(" - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr) + print(" - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr) + print(" - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr) + + if attn_impl_primary == 'flash_attention_2': + print("\nTrying to use SDPA instead of flash_attention_2...", file=sys.stderr) + load_kwargs["attn_implementation"] = 'sdpa' + try: + if self.device == "mps": + load_kwargs["device_map"] = None + self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) + self.model.to("mps") + else: + load_kwargs["device_map"] = (self.device if self.device in ("cuda", "cpu") else None) + self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) + except Exception as e2: + print(f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + raise e2 + else: + raise e + + print(f"Model loaded successfully: {model_path}", file=sys.stderr) + + return backend_pb2.Result(message="Model loaded successfully", success=True) + + def _detect_mode(self, request): + """Detect which mode to use based on request parameters.""" + # Priority: VoiceClone > VoiceDesign > CustomVoice + + # model_type explicitly set + if self.model_type == "CustomVoice": + return "CustomVoice" + if self.model_type == "VoiceClone": + return "VoiceClone" + if self.model_type == "VoiceDesign": + return "VoiceDesign" + + # VoiceClone: AudioPath is provided (from LoadModel, stored in self.audio_path) + if self.audio_path: + return "VoiceClone" + + # VoiceDesign: instruct option is provided + if "instruct" in self.options and self.options["instruct"]: + return "VoiceDesign" + + # Default to CustomVoice + return "CustomVoice" + + def _get_ref_audio_path(self, request): + """Get reference audio path from stored AudioPath (from LoadModel).""" + if not self.audio_path: + return None + + # If absolute path, use as-is + if os.path.isabs(self.audio_path): + return self.audio_path + + # Try relative to ModelFile + if self.model_file: + model_file_base = os.path.dirname(self.model_file) + ref_path = os.path.join(model_file_base, self.audio_path) + if os.path.exists(ref_path): + return ref_path + + # Try relative to ModelPath + if self.model_path: + ref_path = os.path.join(self.model_path, self.audio_path) + if os.path.exists(ref_path): + return ref_path + + # Return as-is (might be URL or base64) + return self.audio_path + + def _get_voice_clone_prompt(self, request, ref_audio, ref_text): + """Get or create voice clone prompt, with caching.""" + cache_key = f"{ref_audio}:{ref_text}" + + if cache_key not in self._voice_clone_cache: + print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr) + try: + prompt_items = self.model.create_voice_clone_prompt( + ref_audio=ref_audio, + ref_text=ref_text, + x_vector_only_mode=self.options.get("x_vector_only_mode", False), + ) + self._voice_clone_cache[cache_key] = prompt_items + except Exception as e: + print(f"Error creating voice clone prompt: {e}", 
file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + return self._voice_clone_cache[cache_key] + + def TTS(self, request, context): + try: + # Check if dst is provided + if not request.dst: + return backend_pb2.Result( + success=False, + message="dst (output path) is required" + ) + + # Prepare text + text = request.text.strip() + if not text: + return backend_pb2.Result( + success=False, + message="Text is empty" + ) + + # Get language (auto-detect if not provided) + language = request.language if hasattr(request, 'language') and request.language else None + if not language or language == "": + language = "Auto" # Auto-detect language + + # Detect mode + mode = self._detect_mode(request) + print(f"Detected mode: {mode}", file=sys.stderr) + + # Get generation parameters from options + max_new_tokens = self.options.get("max_new_tokens", None) + top_p = self.options.get("top_p", None) + temperature = self.options.get("temperature", None) + do_sample = self.options.get("do_sample", None) + + # Prepare generation kwargs + generation_kwargs = {} + if max_new_tokens is not None: + generation_kwargs["max_new_tokens"] = max_new_tokens + if top_p is not None: + generation_kwargs["top_p"] = top_p + if temperature is not None: + generation_kwargs["temperature"] = temperature + if do_sample is not None: + generation_kwargs["do_sample"] = do_sample + + instruct = self.options.get("instruct", "") + if instruct is not None and instruct != "": + generation_kwargs["instruct"] = instruct + + # Generate audio based on mode + if mode == "VoiceClone": + # VoiceClone mode + ref_audio = self._get_ref_audio_path(request) + if not ref_audio: + return backend_pb2.Result( + success=False, + message="AudioPath is required for VoiceClone mode" + ) + + ref_text = self.options.get("ref_text", None) + if not ref_text: + # Try to get from request if available + if hasattr(request, 'ref_text') and request.ref_text: + ref_text = request.ref_text + else: + # x_vector_only_mode doesn't require ref_text + if not self.options.get("x_vector_only_mode", False): + return backend_pb2.Result( + success=False, + message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)" + ) + + # Check if we should use cached prompt + use_cached_prompt = self.options.get("use_cached_prompt", True) + voice_clone_prompt = None + + if use_cached_prompt: + voice_clone_prompt = self._get_voice_clone_prompt(request, ref_audio, ref_text) + if voice_clone_prompt is None: + return backend_pb2.Result( + success=False, + message="Failed to create voice clone prompt" + ) + + if voice_clone_prompt: + # Use cached prompt + wavs, sr = self.model.generate_voice_clone( + text=text, + language=language, + voice_clone_prompt=voice_clone_prompt, + **generation_kwargs + ) + else: + # Create prompt on-the-fly + wavs, sr = self.model.generate_voice_clone( + text=text, + language=language, + ref_audio=ref_audio, + ref_text=ref_text, + x_vector_only_mode=self.options.get("x_vector_only_mode", False), + **generation_kwargs + ) + + elif mode == "VoiceDesign": + # VoiceDesign mode + if not instruct: + return backend_pb2.Result( + success=False, + message="instruct option is required for VoiceDesign mode" + ) + + wavs, sr = self.model.generate_voice_design( + text=text, + language=language, + instruct=instruct, + **generation_kwargs + ) + + else: + # CustomVoice mode (default) + speaker = request.voice if request.voice else None + if not speaker: + # Try to get from options + speaker = self.options.get("speaker", None) + if 
not speaker: + # Use default speaker + speaker = "Vivian" + print(f"No speaker specified, using default: {speaker}", file=sys.stderr) + + # Validate speaker if model supports it + if hasattr(self.model, 'get_supported_speakers'): + try: + supported_speakers = self.model.get_supported_speakers() + if speaker not in supported_speakers: + print(f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}", file=sys.stderr) + # Try to find a close match (case-insensitive) + speaker_lower = speaker.lower() + for sup_speaker in supported_speakers: + if sup_speaker.lower() == speaker_lower: + speaker = sup_speaker + print(f"Using matched speaker: {speaker}", file=sys.stderr) + break + except Exception as e: + print(f"Warning: Could not get supported speakers: {e}", file=sys.stderr) + + wavs, sr = self.model.generate_custom_voice( + text=text, + language=language, + speaker=speaker, + **generation_kwargs + ) + + # Save output + if wavs is not None and len(wavs) > 0: + # wavs is a list, take first element + audio_data = wavs[0] if isinstance(wavs, list) else wavs + sf.write(request.dst, audio_data, sr) + print(f"Saved output to {request.dst}", file=sys.stderr) + else: + return backend_pb2.Result( + success=False, + message="No audio output generated" + ) + + except Exception as err: + print(f"Error in TTS: {err}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + + return backend_pb2.Result(success=True) + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB + ]) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." 
+ ) + args = parser.parse_args() + + serve(args.addr) diff --git a/backend/python/qwen-tts/install.sh b/backend/python/qwen-tts/install.sh new file mode 100755 index 000000000..b7d487873 --- /dev/null +++ b/backend/python/qwen-tts/install.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation" + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/qwen-tts/requirements-cpu.txt b/backend/python/qwen-tts/requirements-cpu.txt new file mode 100644 index 000000000..f7989d2e8 --- /dev/null +++ b/backend/python/qwen-tts/requirements-cpu.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-cublas12-after.txt b/backend/python/qwen-tts/requirements-cublas12-after.txt new file mode 100644 index 000000000..7bfe8efeb --- /dev/null +++ b/backend/python/qwen-tts/requirements-cublas12-after.txt @@ -0,0 +1 @@ +flash-attn \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-cublas12.txt b/backend/python/qwen-tts/requirements-cublas12.txt new file mode 100644 index 000000000..95afbbcba --- /dev/null +++ b/backend/python/qwen-tts/requirements-cublas12.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-cublas13.txt b/backend/python/qwen-tts/requirements-cublas13.txt new file mode 100644 index 000000000..4b5a053d1 --- /dev/null +++ b/backend/python/qwen-tts/requirements-cublas13.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cu130 +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-hipblas.txt b/backend/python/qwen-tts/requirements-hipblas.txt new file mode 100644 index 000000000..d8a3e3616 --- /dev/null +++ b/backend/python/qwen-tts/requirements-hipblas.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.7.1+rocm6.3 +torchaudio==2.7.1+rocm6.3 +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-intel-after.txt b/backend/python/qwen-tts/requirements-intel-after.txt new file mode 100644 index 000000000..7bfe8efeb --- /dev/null +++ b/backend/python/qwen-tts/requirements-intel-after.txt @@ -0,0 +1 @@ +flash-attn \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-intel.txt b/backend/python/qwen-tts/requirements-intel.txt new file mode 100644 index 000000000..58fa82307 --- /dev/null +++ b/backend/python/qwen-tts/requirements-intel.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-l4t12.txt b/backend/python/qwen-tts/requirements-l4t12.txt new file mode 100644 index 000000000..74b27915f --- /dev/null +++ b/backend/python/qwen-tts/requirements-l4t12.txt @@ -0,0 +1,5 @@ +--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-l4t13.txt b/backend/python/qwen-tts/requirements-l4t13.txt new file mode 100644 index 000000000..4b5a053d1 --- /dev/null +++ b/backend/python/qwen-tts/requirements-l4t13.txt @@ 
-0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cu130 +torch +torchaudio +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements-mps.txt b/backend/python/qwen-tts/requirements-mps.txt new file mode 100644 index 000000000..fcf446ba7 --- /dev/null +++ b/backend/python/qwen-tts/requirements-mps.txt @@ -0,0 +1,4 @@ +torch==2.7.1 +torchaudio==0.22.1 +qwen-tts +sox \ No newline at end of file diff --git a/backend/python/qwen-tts/requirements.txt b/backend/python/qwen-tts/requirements.txt new file mode 100644 index 000000000..ef46fb5ba --- /dev/null +++ b/backend/python/qwen-tts/requirements.txt @@ -0,0 +1,6 @@ +grpcio==1.71.0 +protobuf +certifi +packaging==24.1 +soundfile +setuptools \ No newline at end of file diff --git a/backend/python/qwen-tts/run.sh b/backend/python/qwen-tts/run.sh new file mode 100755 index 000000000..eae121f37 --- /dev/null +++ b/backend/python/qwen-tts/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/qwen-tts/test.py b/backend/python/qwen-tts/test.py new file mode 100644 index 000000000..8a066d3cd --- /dev/null +++ b/backend/python/qwen-tts/test.py @@ -0,0 +1,98 @@ +""" +A test script to test the gRPC service +""" +import unittest +import subprocess +import time +import os +import sys +import tempfile +import threading +import backend_pb2 +import backend_pb2_grpc + +import grpc + + +class TestBackendServicer(unittest.TestCase): + """ + TestBackendServicer is the class that tests the gRPC service + """ + def setUp(self): + """ + This method sets up the gRPC service by starting the server + """ + self.service = subprocess.Popen( + ["python3", "backend.py", "--addr", "localhost:50051"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + time.sleep(5) + + def tearDown(self) -> None: + """ + This method tears down the gRPC service by terminating the server + """ + self.service.terminate() + try: + stdout, stderr = self.service.communicate(timeout=5) + # Output should already be printed by threads, but print any remaining + if stdout: + print("=== REMAINING STDOUT ===") + print(stdout) + if stderr: + print("=== REMAINING STDERR ===") + print(stderr) + except subprocess.TimeoutExpired: + self.service.kill() + stdout, stderr = self.service.communicate() + if stdout: + print("=== REMAINING STDOUT ===") + print(stdout) + if stderr: + print("=== REMAINING STDERR ===") + print(stderr) + + def test_tts(self): + """ + This method tests if the TTS generation works successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + # Allow up to 10 minutes for model download on first run + response = stub.LoadModel( + backend_pb2.ModelOptions(Model="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"), + timeout=600.0 + ) + self.assertTrue(response.success) + + # Create temporary output file + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file: + output_path = tmp_file.name + + tts_request = backend_pb2.TTSRequest( + text="Hello, this is a test of the qwen-tts backend.", + voice="Vivian", + dst=output_path + ) + # Allow up to 2 minutes for TTS generation + tts_response = stub.TTS(tts_request, timeout=120.0) + self.assertIsNotNone(tts_response) + self.assertTrue(tts_response.success) + + # Verify output file exists and is not empty + 
self.assertTrue(os.path.exists(output_path)) + self.assertGreater(os.path.getsize(output_path), 0) + + # Cleanup + os.unlink(output_path) + except Exception as err: + print(f"Exception: {err}", file=sys.stderr) + # Give threads a moment to flush any remaining output + time.sleep(1) + self.fail("TTS service failed") + finally: + self.tearDown() diff --git a/backend/python/qwen-tts/test.sh b/backend/python/qwen-tts/test.sh new file mode 100755 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/qwen-tts/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/backend/python/rerankers/requirements-intel.txt b/backend/python/rerankers/requirements-intel.txt index 820dd8422..cfa7f0f93 100644 --- a/backend/python/rerankers/requirements-intel.txt +++ b/backend/python/rerankers/requirements-intel.txt @@ -1,9 +1,7 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu +--extra-index-url https://download.pytorch.org/whl/xpu transformers accelerate -torch==2.3.1+cxx11.abi -oneccl_bind_pt==2.8.0+xpu +torch rerankers[transformers] optimum[openvino] setuptools \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-intel.txt b/backend/python/rfdetr/requirements-intel.txt index 55fcbb318..a18ee4334 100644 --- a/backend/python/rfdetr/requirements-intel.txt +++ b/backend/python/rfdetr/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.3.1+cxx11.abi -torchvision==0.18.1+cxx11.abi -oneccl_bind_pt==2.3.100+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchvision optimum[openvino] setuptools rfdetr diff --git a/backend/python/transformers/requirements-intel.txt b/backend/python/transformers/requirements-intel.txt index 836861246..8d856e0be 100644 --- a/backend/python/transformers/requirements-intel.txt +++ b/backend/python/transformers/requirements-intel.txt @@ -1,12 +1,9 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.5.1+cxx11.abi -oneccl_bind_pt==2.8.0+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch optimum[openvino] llvmlite==0.43.0 numba==0.60.0 transformers -intel-extension-for-transformers bitsandbytes outetts sentence-transformers==5.2.0 diff --git a/backend/python/vibevoice/requirements-intel.txt b/backend/python/vibevoice/requirements-intel.txt index e040ef6b5..af061781b 100644 --- a/backend/python/vibevoice/requirements-intel.txt +++ b/backend/python/vibevoice/requirements-intel.txt @@ -1,8 +1,6 @@ ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.3.110+xpu -torch==2.5.1+cxx11.abi -torchvision==0.20.1+cxx11.abi -oneccl_bind_pt==2.8.0+xpu +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchvision optimum[openvino] setuptools git+https://github.com/huggingface/diffusers diff --git a/backend/python/vllm/requirements-intel.txt b/backend/python/vllm/requirements-intel.txt index a5a176f2f..1e8d8672a 100644 --- a/backend/python/vllm/requirements-intel.txt +++ b/backend/python/vllm/requirements-intel.txt @@ -1,10 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/xpu ---extra-index-url 
https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.7.10+xpu accelerate -torch==2.7.0+xpu +torch transformers optimum[openvino] setuptools -bitsandbytes -oneccl_bind_pt==2.7.0+xpu \ No newline at end of file +bitsandbytes \ No newline at end of file diff --git a/docs/content/features/text-to-audio.md b/docs/content/features/text-to-audio.md index 3d1794307..132c56429 100644 --- a/docs/content/features/text-to-audio.md +++ b/docs/content/features/text-to-audio.md @@ -215,50 +215,90 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ }' | aplay ``` -### Vall-E-X +### Qwen3-TTS -[VALL-E-X](https://github.com/Plachtaa/VALL-E-X) is an open source implementation of Microsoft's VALL-E X zero-shot TTS model. +[Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) is a high-quality text-to-speech model that supports three modes: custom voice (predefined speakers), voice design (natural language instructions), and voice cloning (from reference audio). #### Setup -The backend will automatically download the required files in order to run the model. - -This is an extra backend - in the container is already available and there is nothing to do for the setup. If you are building manually, you need to install Vall-E-X manually first. +Install the `qwen-tts` model in the Model gallery or run `local-ai run models install qwen-tts`. #### Usage -Use the tts endpoint by specifying the vall-e-x backend: +Use the tts endpoint by specifying the qwen-tts backend: ``` curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ - "backend": "vall-e-x", - "input":"Hello!" + "model": "qwen-tts", + "input":"Hello world, this is a test." }' | aplay ``` -#### Voice cloning +#### Custom Voice Mode -In order to use voice cloning capabilities you must create a `YAML` configuration file to setup a model: +Qwen3-TTS supports predefined speakers. You can specify a speaker using the `voice` parameter: ```yaml -name: cloned-voice -backend: vall-e-x +name: qwen-tts +backend: qwen-tts parameters: - model: "cloned-voice" + model: Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice tts: - vall-e: - # The path to the audio file to be cloned - # relative to the models directory - # Max 15s - audio_path: "audio-sample.wav" + voice: "Vivian" # Available speakers: Vivian, Serena, Uncle_Fu, Dylan, Eric, Ryan, Aiden, Ono_Anna, Sohee ``` -Then you can specify the model name in the requests: +Available speakers: +- **Chinese**: Vivian, Serena, Uncle_Fu, Dylan, Eric +- **English**: Ryan, Aiden +- **Japanese**: Ono_Anna +- **Korean**: Sohee + +#### Voice Design Mode + +Voice Design allows you to create custom voices using natural language instructions. Configure the model with an `instruct` option: + +```yaml +name: qwen-tts-design +backend: qwen-tts +parameters: + model: Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign +options: + - "instruct:体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。" +``` + +Then use the model: ``` curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ - "model": "cloned-voice", - "input":"Hello!" + "model": "qwen-tts-design", + "input":"Hello world, this is a test." + }' | aplay +``` + +#### Voice Clone Mode + +Voice Clone allows you to clone a voice from reference audio. 
Configure the model with an `audio_path` pointing at the reference audio, plus a `ref_text` transcript (required unless `x_vector_only_mode` is enabled): + +```yaml +name: qwen-tts-clone +backend: qwen-tts +parameters: + model: Qwen/Qwen3-TTS-12Hz-1.7B-Base +tts: + audio_path: "path/to/reference_audio.wav" # Reference audio file +options: + - "ref_text:This is the transcript of the reference audio." + - "x_vector_only_mode:false" # Set to true to use only the speaker embedding (ref_text not required) +``` + +You can also use URLs or base64 strings for the reference audio. The backend automatically detects the mode based on the available parameters (AudioPath → VoiceClone, instruct option → VoiceDesign, voice parameter → CustomVoice). + +Then use the model: + +``` +curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ + "model": "qwen-tts-clone", + "input":"Hello world, this is a test." + }' | aplay +```
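+
+#### Generation Options
+
+The backend also reads a few optional generation parameters from the model `options` (`max_new_tokens`, `top_p`, `temperature`, `do_sample`), as implemented in `backend.py` above. The snippet below is a minimal sketch of how they can be passed; the model reuses the CustomVoice variant shown earlier, while the config name and the numeric values are illustrative assumptions rather than tuned defaults:
+
+```yaml
+name: qwen-tts-tuned   # hypothetical config name
+backend: qwen-tts
+parameters:
+  model: Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice
+options:
+  - "do_sample:true"   # enable sampling
+  - "temperature:0.8"  # illustrative value
+  - "top_p:0.9"        # illustrative value
+```
+
+Options are plain `key:value` strings, parsed the same way as `instruct`, `ref_text`, and `x_vector_only_mode` in the examples above.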