diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index b9c658df8..0b2f653b4 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -52,6 +52,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' + - build-type: 'l4t' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-ace-step' + runs-on: 'ubuntu-24.04-arm' + base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" + skip-drivers: 'true' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2204' - build-type: '' cuda-major-version: "" cuda-minor-version: "" @@ -104,6 +117,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-cpu-ace-step' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" @@ -287,6 +313,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "8" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-ace-step' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" @@ -600,6 +639,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-ace-step' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" @@ -665,6 +717,19 @@ jobs: backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" + - build-type: 'l4t' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-cuda-13-arm64-ace-step' + runs-on: 'ubuntu-24.04-arm' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + ubuntu-version: '2404' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" @@ -952,6 +1017,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'hipblas' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-rocm-hipblas-ace-step' + runs-on: 'arc-runner-set' + base-image: "rocm/dev-ubuntu-24.04:6.4.4" + skip-drivers: 'false' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' # ROCm additional backends - build-type: 'hipblas' cuda-major-version: "" @@ -1149,6 +1227,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'intel' + cuda-major-version: "" + 
cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-intel-ace-step' + runs-on: 'ubuntu-latest' + base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" + skip-drivers: 'false' + backend: "ace-step" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" @@ -1791,6 +1882,9 @@ jobs: - backend: "diffusers" tag-suffix: "-metal-darwin-arm64-diffusers" build-type: "mps" + - backend: "ace-step" + tag-suffix: "-metal-darwin-arm64-ace-step" + build-type: "mps" - backend: "mlx" tag-suffix: "-metal-darwin-arm64-mlx" build-type: "mps" diff --git a/Makefile b/Makefile index dba46f84f..92ef60b59 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/qwen-asr backends/voxcpm backends/whisperx +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/qwen-asr backends/voxcpm backends/whisperx backends/ace-step GOCMD=go GOTEST=$(GOCMD) test @@ -320,6 +320,7 @@ prepare-test-extra: protogen-python $(MAKE) -C backend/python/qwen-asr $(MAKE) -C backend/python/voxcpm $(MAKE) -C backend/python/whisperx + $(MAKE) -C backend/python/ace-step test-extra: prepare-test-extra $(MAKE) -C backend/python/transformers test @@ -335,6 +336,7 @@ test-extra: prepare-test-extra $(MAKE) -C backend/python/qwen-asr test $(MAKE) -C backend/python/voxcpm test $(MAKE) -C backend/python/whisperx test + $(MAKE) -C backend/python/ace-step test DOCKER_IMAGE?=local-ai DOCKER_AIO_IMAGE?=local-ai-aio @@ -520,12 +522,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_ASR))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX))) +$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-qwen-asr 
docker-build-voxcpm docker-build-whisperx +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-qwen-asr docker-build-voxcpm docker-build-whisperx docker-build-ace-step ######################################################## ### Mock Backend for E2E Tests diff --git a/backend/backend.proto b/backend/backend.proto index 2ba04c883..07cd095d2 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -365,6 +365,14 @@ message SoundGenerationRequest { optional bool sample = 6; optional string src = 7; optional int32 src_divisor = 8; + optional bool think = 9; + optional string caption = 10; + optional string lyrics = 11; + optional int32 bpm = 12; + optional string keyscale = 13; + optional string language = 14; + optional string timesignature = 15; + optional bool instrumental = 17; } message TokenizationResponse { diff --git a/backend/index.yaml b/backend/index.yaml index 4745fc469..a0fa9be06 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -296,6 +296,40 @@ nvidia-cuda-12: "cuda12-diffusers" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-diffusers" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-diffusers" +- &ace-step + name: "ace-step" + description: | + ACE-Step 1.5 is an open-source music generation model. It supports simple mode (natural language description) and advanced mode (caption, lyrics, think, bpm, keyscale, etc.). Uses in-process acestep (LLMHandler for metadata, DiT for audio). + urls: + - https://github.com/ace-step/ACE-Step-1.5 + tags: + - music-generation + - sound-generation + alias: "ace-step" + capabilities: + nvidia: "cuda12-ace-step" + intel: "intel-ace-step" + amd: "rocm-ace-step" + nvidia-l4t: "nvidia-l4t-ace-step" + metal: "metal-ace-step" + default: "cpu-ace-step" + nvidia-cuda-13: "cuda13-ace-step" + nvidia-cuda-12: "cuda12-ace-step" + nvidia-l4t-cuda-12: "nvidia-l4t-ace-step" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ace-step" +- !!merge <<: *ace-step + name: "ace-step-development" + capabilities: + nvidia: "cuda12-ace-step-development" + intel: "intel-ace-step-development" + amd: "rocm-ace-step-development" + nvidia-l4t: "nvidia-l4t-ace-step-development" + metal: "metal-ace-step-development" + default: "cpu-ace-step-development" + nvidia-cuda-13: "cuda13-ace-step-development" + nvidia-cuda-12: "cuda12-ace-step-development" + nvidia-l4t-cuda-12: "nvidia-l4t-ace-step-development" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ace-step-development" - &faster-whisper icon: https://avatars.githubusercontent.com/u/1520500?s=200&v=4 description: | @@ -1593,6 +1627,87 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-diffusers" mirrors: - localai/localai-backends:master-metal-darwin-arm64-diffusers +## ace-step +- !!merge <<: *ace-step + name: "cpu-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-ace-step" + mirrors: + - localai/localai-backends:latest-cpu-ace-step +- !!merge <<: *ace-step + name: "cpu-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-ace-step" + mirrors: + - localai/localai-backends:master-cpu-ace-step +- !!merge <<: *ace-step + name: "cuda12-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-ace-step" + mirrors: + - 
localai/localai-backends:latest-gpu-nvidia-cuda-12-ace-step +- !!merge <<: *ace-step + name: "cuda12-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-ace-step" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-ace-step +- !!merge <<: *ace-step + name: "cuda13-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-ace-step" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-13-ace-step +- !!merge <<: *ace-step + name: "cuda13-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-13-ace-step +- !!merge <<: *ace-step + name: "rocm-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-ace-step" + mirrors: + - localai/localai-backends:latest-gpu-rocm-hipblas-ace-step +- !!merge <<: *ace-step + name: "rocm-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-ace-step" + mirrors: + - localai/localai-backends:master-gpu-rocm-hipblas-ace-step +- !!merge <<: *ace-step + name: "intel-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-ace-step" + mirrors: + - localai/localai-backends:latest-gpu-intel-ace-step +- !!merge <<: *ace-step + name: "intel-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-ace-step" + mirrors: + - localai/localai-backends:master-gpu-intel-ace-step +- !!merge <<: *ace-step + name: "metal-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-ace-step" + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-ace-step +- !!merge <<: *ace-step + name: "metal-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ace-step" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-ace-step +- !!merge <<: *ace-step + name: "nvidia-l4t-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-ace-step" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-ace-step +- !!merge <<: *ace-step + name: "nvidia-l4t-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-ace-step" + mirrors: + - localai/localai-backends:master-nvidia-l4t-ace-step +- !!merge <<: *ace-step + name: "cuda13-nvidia-l4t-arm64-ace-step" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-ace-step" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-ace-step +- !!merge <<: *ace-step + name: "cuda13-nvidia-l4t-arm64-ace-step-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-ace-step" + mirrors: + - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-ace-step ## kokoro - !!merge <<: *kokoro name: "kokoro-development" diff --git a/backend/python/ace-step/Makefile b/backend/python/ace-step/Makefile new file mode 100644 index 000000000..8ad2368ab --- /dev/null +++ b/backend/python/ace-step/Makefile @@ -0,0 +1,16 @@ +.DEFAULT_GOAL := install + +.PHONY: install +install: + bash install.sh + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ + +test: install + bash test.sh diff --git a/backend/python/ace-step/backend.py b/backend/python/ace-step/backend.py new file mode 100644 index 000000000..a7d635a52 --- /dev/null +++ b/backend/python/ace-step/backend.py @@ -0,0 +1,546 @@ +#!/usr/bin/env python3 +""" 
+LocalAI ACE-Step Backend + +gRPC backend for ACE-Step 1.5 music generation. Aligns with upstream acestep API: +- LoadModel: initializes AceStepHandler (DiT) and LLMHandler, parses Options. +- SoundGeneration: uses create_sample (simple mode), format_sample (optional), then + generate_music from acestep.inference. Writes first output to request.dst. +- Fail hard: no fallback WAV on error; exceptions propagate to gRPC. +""" +from concurrent import futures +import argparse +import shutil +import signal +import sys +import os +import tempfile + +import backend_pb2 +import backend_pb2_grpc +import grpc +from acestep.inference import ( + GenerationParams, + GenerationConfig, + generate_music, + create_sample, + format_sample, +) +from acestep.handler import AceStepHandler +from acestep.llm_inference import LLMHandler + + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1")) + +# Model name -> HuggingFace/ModelScope repo (from upstream api_server.py) +MODEL_REPO_MAPPING = { + "acestep-v15-turbo": "ACE-Step/Ace-Step1.5", + "acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5", + "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5", + "vae": "ACE-Step/Ace-Step1.5", + "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5", + "acestep-v15-base": "ACE-Step/acestep-v15-base", + "acestep-v15-sft": "ACE-Step/acestep-v15-sft", + "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3", + "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B", +} +DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5" + + +def _can_access_google(timeout=3.0): + """Check if Google is accessible (to choose HuggingFace vs ModelScope).""" + import socket + try: + socket.setdefaulttimeout(timeout) + socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect(("www.google.com", 443)) + return True + except (socket.timeout, socket.error, OSError): + return False + + +def _download_from_huggingface(repo_id, local_dir, model_name): + """Download model from HuggingFace Hub.""" + from huggingface_hub import snapshot_download + is_unified = repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" + if is_unified: + download_dir = local_dir + print(f"[ace-step] Downloading unified repo {repo_id} to {download_dir}...", file=sys.stderr) + else: + download_dir = os.path.join(local_dir, model_name) + os.makedirs(download_dir, exist_ok=True) + print(f"[ace-step] Downloading {model_name} from {repo_id} to {download_dir}...", file=sys.stderr) + snapshot_download( + repo_id=repo_id, + local_dir=download_dir, + local_dir_use_symlinks=False, + ) + return os.path.join(local_dir, model_name) + + +def _download_from_modelscope(repo_id, local_dir, model_name): + """Download model from ModelScope.""" + from modelscope import snapshot_download + is_unified = repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" + if is_unified: + download_dir = local_dir + print(f"[ace-step] Downloading unified repo {repo_id} from ModelScope to {download_dir}...", file=sys.stderr) + else: + download_dir = os.path.join(local_dir, model_name) + os.makedirs(download_dir, exist_ok=True) + print(f"[ace-step] Downloading {model_name} from ModelScope {repo_id} to {download_dir}...", file=sys.stderr) + snapshot_download( + model_id=repo_id, + local_dir=download_dir, + ) + return os.path.join(local_dir, model_name) + + +def _ensure_model_downloaded(model_name, checkpoint_dir): + """ + Ensure model is present; download from HuggingFace or ModelScope if missing. + model_name: e.g. 
"acestep-v15-turbo", "vae", "acestep-5Hz-lm-0.6B" + checkpoint_dir: directory that will contain model_name as a subdir. + Returns path to the model directory. + """ + return None + if not model_name or not checkpoint_dir: + return None + model_path = os.path.join(checkpoint_dir, model_name) + if os.path.exists(model_path) and os.listdir(model_path): + print(f"[ace-step] Model {model_name} already at {model_path}", file=sys.stderr) + return model_path + repo_id = MODEL_REPO_MAPPING.get(model_name, DEFAULT_REPO_ID) + print(f"[ace-step] Model {model_name} not found, downloading...", file=sys.stderr) + use_hf = _can_access_google() + if use_hf: + try: + return _download_from_huggingface(repo_id, checkpoint_dir, model_name) + except Exception as e: + print(f"[ace-step] HuggingFace download failed: {e}, trying ModelScope", file=sys.stderr) + return _download_from_modelscope(repo_id, checkpoint_dir, model_name) + else: + try: + return _download_from_modelscope(repo_id, checkpoint_dir, model_name) + except Exception as e: + print(f"[ace-step] ModelScope download failed: {e}, trying HuggingFace", file=sys.stderr) + return _download_from_huggingface(repo_id, checkpoint_dir, model_name) + + +def _is_float(s): + try: + float(s) + return True + except (ValueError, TypeError): + return False + + +def _is_int(s): + try: + int(s) + return True + except (ValueError, TypeError): + return False + + +def _parse_timesteps(s): + if s is None or (isinstance(s, str) and not s.strip()): + return None + if isinstance(s, (list, tuple)): + return [float(x) for x in s] + try: + return [float(x.strip()) for x in str(s).split(",") if x.strip()] + except (ValueError, TypeError): + return None + + +def _parse_options(opts_list): + """Parse repeated 'key:value' options into a dict. Coerce numeric and bool.""" + out = {} + for opt in opts_list or []: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + key = key.strip() + value = value.strip() + if _is_int(value): + out[key] = int(value) + elif _is_float(value): + out[key] = float(value) + elif value.lower() in ("true", "false"): + out[key] = value.lower() == "true" + else: + out[key] = value + return out + + +def _generate_audio_sync(servicer, payload, dst_path): + """ + Run full ACE-Step pipeline using acestep.inference: + - If sample_mode/sample_query: create_sample() for caption/lyrics/metadata. + - If use_format and caption/lyrics: format_sample(). + - Build GenerationParams and GenerationConfig, then generate_music(). + Writes the first generated audio to dst_path. Raises on failure. 
+ """ + + opts = servicer.options + dit_handler = servicer.dit_handler + llm_handler = servicer.llm_handler + + for key, value in opts.items(): + if key not in payload: + payload[key] = value + + def _opt(name, default): + return opts.get(name, default) + + lm_temperature = _opt("temperature", 0.85) + lm_cfg_scale = _opt("lm_cfg_scale", _opt("cfg_scale", 2.0)) + lm_top_k = opts.get("top_k") + lm_top_p = _opt("top_p", 0.9) + if lm_top_p is not None and lm_top_p >= 1.0: + lm_top_p = None + inference_steps = _opt("inference_steps", 8) + guidance_scale = _opt("guidance_scale", 7.0) + batch_size = max(1, int(_opt("batch_size", 1))) + + use_simple = bool(payload.get("sample_query") or payload.get("text")) + sample_mode = use_simple and (payload.get("thinking") or payload.get("sample_mode")) + sample_query = (payload.get("sample_query") or payload.get("text") or "").strip() + use_format = bool(payload.get("use_format")) + caption = (payload.get("prompt") or payload.get("caption") or "").strip() + lyrics = (payload.get("lyrics") or "").strip() + vocal_language = (payload.get("vocal_language") or "en").strip() + instrumental = bool(payload.get("instrumental")) + bpm = payload.get("bpm") + key_scale = (payload.get("key_scale") or "").strip() + time_signature = (payload.get("time_signature") or "").strip() + audio_duration = payload.get("audio_duration") + if audio_duration is not None: + try: + audio_duration = float(audio_duration) + except (TypeError, ValueError): + audio_duration = None + + if sample_mode and llm_handler and getattr(llm_handler, "llm_initialized", False): + parsed_language = None + if sample_query: + for hint in ("english", "en", "chinese", "zh", "japanese", "ja"): + if hint in sample_query.lower(): + parsed_language = "en" if hint == "english" or hint == "en" else hint + break + vocal_lang = vocal_language if vocal_language and vocal_language != "unknown" else parsed_language + sample_result = create_sample( + llm_handler=llm_handler, + query=sample_query or "NO USER INPUT", + instrumental=instrumental, + vocal_language=vocal_lang, + temperature=lm_temperature, + top_k=lm_top_k, + top_p=lm_top_p, + use_constrained_decoding=True, + ) + if not sample_result.success: + raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}") + caption = sample_result.caption or caption + lyrics = sample_result.lyrics or lyrics + bpm = sample_result.bpm + key_scale = sample_result.keyscale or key_scale + time_signature = sample_result.timesignature or time_signature + if sample_result.duration is not None: + audio_duration = sample_result.duration + if getattr(sample_result, "language", None): + vocal_language = sample_result.language + + if use_format and (caption or lyrics) and llm_handler and getattr(llm_handler, "llm_initialized", False): + user_metadata = {} + if bpm is not None: + user_metadata["bpm"] = bpm + if audio_duration is not None and float(audio_duration) > 0: + user_metadata["duration"] = int(audio_duration) + if key_scale: + user_metadata["keyscale"] = key_scale + if time_signature: + user_metadata["timesignature"] = time_signature + if vocal_language and vocal_language != "unknown": + user_metadata["language"] = vocal_language + format_result = format_sample( + llm_handler=llm_handler, + caption=caption, + lyrics=lyrics, + user_metadata=user_metadata if user_metadata else None, + temperature=lm_temperature, + top_k=lm_top_k, + top_p=lm_top_p, + use_constrained_decoding=True, + ) + if format_result.success: + caption = format_result.caption 
or caption + lyrics = format_result.lyrics or lyrics + if format_result.duration is not None: + audio_duration = format_result.duration + if format_result.bpm is not None: + bpm = format_result.bpm + if format_result.keyscale: + key_scale = format_result.keyscale + if format_result.timesignature: + time_signature = format_result.timesignature + if getattr(format_result, "language", None): + vocal_language = format_result.language + + thinking = bool(payload.get("thinking")) + use_cot_metas = not sample_mode + params = GenerationParams( + task_type=payload.get("task_type", "text2music"), + instruction=payload.get("instruction", "Fill the audio semantic mask based on the given conditions:"), + reference_audio=payload.get("reference_audio_path"), + src_audio=payload.get("src_audio_path"), + audio_codes=payload.get("audio_code_string", ""), + caption=caption, + lyrics=lyrics, + instrumental=instrumental or (not lyrics or str(lyrics).strip().lower() in ("[inst]", "[instrumental]")), + vocal_language=vocal_language or "unknown", + bpm=bpm, + keyscale=key_scale, + timesignature=time_signature, + duration=float(audio_duration) if audio_duration and float(audio_duration) > 0 else -1.0, + inference_steps=inference_steps, + seed=int(payload.get("seed", -1)), + guidance_scale=guidance_scale, + use_adg=bool(payload.get("use_adg")), + cfg_interval_start=float(payload.get("cfg_interval_start", 0.0)), + cfg_interval_end=float(payload.get("cfg_interval_end", 1.0)), + shift=float(payload.get("shift", 1.0)), + infer_method=(payload.get("infer_method") or "ode").strip(), + timesteps=_parse_timesteps(payload.get("timesteps")), + repainting_start=float(payload.get("repainting_start", 0.0)), + repainting_end=float(payload.get("repainting_end", -1)) if payload.get("repainting_end") is not None else -1, + audio_cover_strength=float(payload.get("audio_cover_strength", 1.0)), + thinking=thinking, + lm_temperature=lm_temperature, + lm_cfg_scale=lm_cfg_scale, + lm_top_k=lm_top_k or 0, + lm_top_p=lm_top_p if lm_top_p is not None and lm_top_p < 1.0 else 0.9, + lm_negative_prompt=payload.get("lm_negative_prompt", "NO USER INPUT"), + use_cot_metas=use_cot_metas, + use_cot_caption=bool(payload.get("use_cot_caption", True)), + use_cot_language=bool(payload.get("use_cot_language", True)), + use_constrained_decoding=True, + ) + + config = GenerationConfig( + batch_size=batch_size, + allow_lm_batch=bool(payload.get("allow_lm_batch", False)), + use_random_seed=bool(payload.get("use_random_seed", True)), + seeds=payload.get("seeds"), + lm_batch_chunk_size=max(1, int(payload.get("lm_batch_chunk_size", 8))), + constrained_decoding_debug=bool(payload.get("constrained_decoding_debug")), + audio_format=(payload.get("audio_format") or "flac").strip() or "flac", + ) + + save_dir = tempfile.mkdtemp(prefix="ace_step_") + try: + result = generate_music( + dit_handler=dit_handler, + llm_handler=llm_handler if (llm_handler and getattr(llm_handler, "llm_initialized", False)) else None, + params=params, + config=config, + save_dir=save_dir, + progress=None, + ) + if not result.success: + raise RuntimeError(result.error or result.status_message or "generate_music failed") + + audios = result.audios or [] + if not audios: + raise RuntimeError("generate_music returned no audio") + + first_path = audios[0].get("path") or "" + if not first_path or not os.path.isfile(first_path): + raise RuntimeError("first generated audio path missing or not a file") + + shutil.copy2(first_path, dst_path) + finally: + try: + shutil.rmtree(save_dir, 
ignore_errors=True) + except Exception: + pass + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + def __init__(self): + self.model_path = None + self.checkpoint_dir = None + self.project_root = None + self.options = {} + self.dit_handler = None + self.llm_handler = None + + def Health(self, request, context): + return backend_pb2.Reply(message=b"OK") + + def LoadModel(self, request, context): + try: + self.options = _parse_options(list(getattr(request, "Options", []) or [])) + model_path = getattr(request, "ModelPath", None) or "" + model_name = (request.Model or "").strip() + model_file = (getattr(request, "ModelFile", None) or "").strip() + + if model_path and model_name: + self.checkpoint_dir = model_path + self.project_root = os.path.dirname(model_path) + self.model_path = os.path.join(model_path, model_name) + elif model_file: + self.model_path = model_file + self.checkpoint_dir = os.path.dirname(model_file) + self.project_root = os.path.dirname(self.checkpoint_dir) + else: + self.model_path = model_name or "." + self.checkpoint_dir = os.path.dirname(self.model_path) if self.model_path else "." + self.project_root = os.path.dirname(self.checkpoint_dir) if self.checkpoint_dir else "." + + config_path = model_name or os.path.basename(self.model_path.rstrip("/\\")) + + # Auto-download DiT model and VAE if missing (same as upstream) + if config_path: + try: + _ensure_model_downloaded(config_path, self.checkpoint_dir) + except Exception as e: + print(f"[ace-step] Warning: DiT model download failed: {e}", file=sys.stderr) + try: + _ensure_model_downloaded("vae", self.checkpoint_dir) + except Exception as e: + print(f"[ace-step] Warning: VAE download failed: {e}", file=sys.stderr) + + self.dit_handler = AceStepHandler() + device = self.options.get("device", "auto") + use_flash = self.options.get("use_flash_attention", True) + if isinstance(use_flash, str): + use_flash = str(use_flash).lower() in ("1", "true", "yes") + offload = self.options.get("offload_to_cpu", False) + if isinstance(offload, str): + offload = str(offload).lower() in ("1", "true", "yes") + status_msg, ok = self.dit_handler.initialize_service( + project_root=self.project_root, + config_path=config_path, + device=device, + use_flash_attention=use_flash, + compile_model=False, + offload_to_cpu=offload, + offload_dit_to_cpu=bool(self.options.get("offload_dit_to_cpu", False)), + ) + if not ok: + return backend_pb2.Result(success=False, message=f"DiT init failed: {status_msg}") + + self.llm_handler = None + if self.options.get("init_lm", True): + lm_model = self.options.get("lm_model_path", "acestep-5Hz-lm-0.6B") + if lm_model: + try: + _ensure_model_downloaded(lm_model, self.checkpoint_dir) + except Exception as e: + print(f"[ace-step] Warning: LM model download failed: {e}", file=sys.stderr) + self.llm_handler = LLMHandler() + lm_backend = (self.options.get("lm_backend") or "vllm").strip().lower() + if lm_backend not in ("vllm", "pt"): + lm_backend = "vllm" + lm_status, lm_ok = self.llm_handler.initialize( + checkpoint_dir=self.checkpoint_dir, + lm_model_path=lm_model, + backend=lm_backend, + device=device, + offload_to_cpu=offload, + dtype=getattr(self.dit_handler, "dtype", None), + ) + if not lm_ok: + self.llm_handler = None + print(f"[ace-step] LM init failed (optional): {lm_status}", file=sys.stderr) + + print(f"[ace-step] LoadModel: model={self.model_path}, options={list(self.options.keys())}", file=sys.stderr) + return backend_pb2.Result(success=True, message="Model loaded successfully") + except Exception as 
err: + return backend_pb2.Result(success=False, message=f"LoadModel error: {err}") + + def SoundGeneration(self, request, context): + if not request.dst: + return backend_pb2.Result(success=False, message="request.dst is required") + + use_simple = bool(request.text) + if use_simple: + payload = { + "sample_query": request.text or "", + "sample_mode": True, + "thinking": True, + "vocal_language": request.language or request.GetLanguage() or "en", + "instrumental": request.GetInstrumental() if request.HasField("instrumental") else False, + } + else: + caption = request.caption or request.GetCaption() or request.text + payload = { + "prompt": caption, + "lyrics": request.lyrics or request.GetLyrics() or "", + "thinking": request.GetThink() if request.HasField("think") else False, + "vocal_language": request.language or request.GetLanguage() or "en", + } + if request.HasField("bpm"): + payload["bpm"] = request.GetBpm() + if request.HasField("keyscale") and request.GetKeyscale(): + payload["key_scale"] = request.GetKeyscale() + if request.HasField("timesignature") and request.GetTimesignature(): + payload["time_signature"] = request.GetTimesignature() + if request.HasField("duration") and request.duration: + payload["audio_duration"] = int(request.duration) if request.duration else None + if request.Src: + payload["src_audio_path"] = request.Src + + _generate_audio_sync(self, payload, request.dst) + return backend_pb2.Result(success=True, message="Sound generated successfully") + + def TTS(self, request, context): + if not request.dst: + return backend_pb2.Result(success=False, message="request.dst is required") + payload = { + "sample_query": request.text, + "sample_mode": True, + "thinking": False, + "vocal_language": (request.language if request.language else "") or "en", + "instrumental": False, + } + _generate_audio_sync(self, payload, request.dst) + return backend_pb2.Result(success=True, message="TTS (music fallback) generated successfully") + + +def serve(address): + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ("grpc.max_message_length", 50 * 1024 * 1024), + ("grpc.max_send_message_length", 50 * 1024 * 1024), + ("grpc.max_receive_message_length", 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print(f"[ace-step] Server listening on {address}", file=sys.stderr) + + def shutdown(sig, frame): + server.stop(0) + sys.exit(0) + + signal.signal(signal.SIGINT, shutdown) + signal.signal(signal.SIGTERM, shutdown) + + try: + while True: + import time + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--addr", default="localhost:50051", help="Listen address") + args = parser.parse_args() + serve(args.addr) diff --git a/backend/python/ace-step/install.sh b/backend/python/ace-step/install.sh new file mode 100644 index 000000000..52a7d6ff1 --- /dev/null +++ b/backend/python/ace-step/install.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +PYTHON_VERSION="3.11" +PYTHON_PATCH="14" +PY_STANDALONE_TAG="20260203" + +installRequirements + +if [ ! 
-d ACE-Step-1.5 ]; then + git clone https://github.com/ace-step/ACE-Step-1.5 + cd ACE-Step-1.5/ + if [ "x${USE_PIP}" == "xtrue" ]; then + pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps . + else + uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps . + fi +fi + diff --git a/backend/python/ace-step/requirements-cpu.txt b/backend/python/ace-step/requirements-cpu.txt new file mode 100644 index 000000000..e986de1b0 --- /dev/null +++ b/backend/python/ace-step/requirements-cpu.txt @@ -0,0 +1,22 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch +torchaudio +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1 +torchao +modelscope \ No newline at end of file diff --git a/backend/python/ace-step/requirements-cublas12.txt b/backend/python/ace-step/requirements-cublas12.txt new file mode 100644 index 000000000..92323b040 --- /dev/null +++ b/backend/python/ace-step/requirements-cublas12.txt @@ -0,0 +1,30 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.7.1 +torchaudio==2.7.1 +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio>=6.5.1 +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1; sys_platform != 'darwin' or platform_machine == 'arm64' +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 + +triton>=3.0.0 +flash-attn +xxhash \ No newline at end of file diff --git a/backend/python/ace-step/requirements-cublas13.txt b/backend/python/ace-step/requirements-cublas13.txt new file mode 100644 index 000000000..7f232fee4 --- /dev/null +++ b/backend/python/ace-step/requirements-cublas13.txt @@ -0,0 +1,30 @@ +--extra-index-url https://download.pytorch.org/whl/cu130 +torch==2.7.1 +torchaudio==2.7.1 +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio>=6.5.1 +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1; sys_platform != 'darwin' or platform_machine == 'arm64' +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 + +triton>=3.0.0 +flash-attn +xxhash \ No newline at end of file diff --git a/backend/python/ace-step/requirements-hipblas.txt b/backend/python/ace-step/requirements-hipblas.txt new file mode 100644 index 000000000..893497ea2 --- /dev/null +++ b/backend/python/ace-step/requirements-hipblas.txt @@ -0,0 +1,30 @@ +--extra-index-url https://download.pytorch.org/whl/rocm6.4 +torch==2.8.0+rocm6.4 +torchaudio +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio>=6.5.1 +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1; sys_platform != 'darwin' or platform_machine == 'arm64' +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 + +triton>=3.0.0 +flash-attn +xxhash \ No newline at end of file diff --git 
a/backend/python/ace-step/requirements-intel.txt b/backend/python/ace-step/requirements-intel.txt new file mode 100644 index 000000000..d9136331c --- /dev/null +++ b/backend/python/ace-step/requirements-intel.txt @@ -0,0 +1,26 @@ +--extra-index-url https://download.pytorch.org/whl/xpu +torch +torchaudio +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1 +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 \ No newline at end of file diff --git a/backend/python/ace-step/requirements-l4t12.txt b/backend/python/ace-step/requirements-l4t12.txt new file mode 100644 index 000000000..5c291dd3c --- /dev/null +++ b/backend/python/ace-step/requirements-l4t12.txt @@ -0,0 +1,30 @@ +--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ +torch +torchaudio +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio>=6.5.1 +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1; sys_platform != 'darwin' or platform_machine == 'arm64' +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 + +triton>=3.0.0 +#flash-attn +xxhash \ No newline at end of file diff --git a/backend/python/ace-step/requirements-l4t13.txt b/backend/python/ace-step/requirements-l4t13.txt new file mode 100644 index 000000000..c71fb6f0b --- /dev/null +++ b/backend/python/ace-step/requirements-l4t13.txt @@ -0,0 +1,29 @@ +--extra-index-url https://download.pytorch.org/whl/cu130 +torch +torchaudio +torchvision +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio>=6.5.1 +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1; sys_platform != 'darwin' or platform_machine == 'arm64' +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 + +triton>=3.0.0 +#flash-attn +xxhash \ No newline at end of file diff --git a/backend/python/ace-step/requirements-mps.txt b/backend/python/ace-step/requirements-mps.txt new file mode 100644 index 000000000..ab7620136 --- /dev/null +++ b/backend/python/ace-step/requirements-mps.txt @@ -0,0 +1,25 @@ +torch +torchaudio +torchvision + +# Core dependencies +transformers>=4.51.0,<4.58.0 +diffusers +gradio +matplotlib>=3.7.5 +scipy>=1.10.1 +soundfile>=0.13.1 +loguru>=0.7.3 +einops>=0.8.1 +accelerate>=1.12.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +numba>=0.63.1 +vector-quantize-pytorch>=1.27.15 +torchcodec>=0.9.1 +torchao +modelscope + +# LoRA Training dependencies (optional) +peft>=0.7.0 +lightning>=2.0.0 \ No newline at end of file diff --git a/backend/python/ace-step/requirements.txt b/backend/python/ace-step/requirements.txt new file mode 100644 index 000000000..f9784fe6c --- /dev/null +++ b/backend/python/ace-step/requirements.txt @@ -0,0 +1,4 @@ +setuptools +grpcio==1.76.0 +protobuf +certifi diff --git a/backend/python/ace-step/run.sh b/backend/python/ace-step/run.sh new file mode 100644 index 000000000..eae121f37 --- /dev/null +++ b/backend/python/ace-step/run.sh @@ -0,0 +1,9 
@@ +#!/bin/bash +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/ace-step/test.py b/backend/python/ace-step/test.py new file mode 100644 index 000000000..748169d1f --- /dev/null +++ b/backend/python/ace-step/test.py @@ -0,0 +1,53 @@ +""" +Tests for the ACE-Step gRPC backend. +""" +import os +import tempfile +import unittest + +import backend_pb2 +import backend_pb2_grpc +import grpc + + +class TestACEStepBackend(unittest.TestCase): + """Test Health, LoadModel, and SoundGeneration (minimal; no real model required).""" + + @classmethod + def setUpClass(cls): + port = os.environ.get("BACKEND_PORT", "50051") + cls.channel = grpc.insecure_channel(f"localhost:{port}") + cls.stub = backend_pb2_grpc.BackendStub(cls.channel) + + @classmethod + def tearDownClass(cls): + cls.channel.close() + + def test_health(self): + response = self.stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b"OK") + + def test_load_model(self): + response = self.stub.LoadModel(backend_pb2.ModelOptions(Model="ace-step-test")) + self.assertTrue(response.success, response.message) + + def test_sound_generation_minimal(self): + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + dst = f.name + try: + req = backend_pb2.SoundGenerationRequest( + text="upbeat pop song", + model="ace-step-test", + dst=dst, + ) + response = self.stub.SoundGeneration(req) + self.assertTrue(response.success, response.message) + self.assertTrue(os.path.exists(dst), f"Output file not created: {dst}") + self.assertGreater(os.path.getsize(dst), 0) + finally: + if os.path.exists(dst): + os.unlink(dst) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/python/ace-step/test.sh b/backend/python/ace-step/test.sh new file mode 100644 index 000000000..1ede318a4 --- /dev/null +++ b/backend/python/ace-step/test.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +# Start backend in background (use env to avoid port conflict in parallel tests) +export PYTHONUNBUFFERED=1 +BACKEND_PORT=${BACKEND_PORT:-50051} +python backend.py --addr "localhost:${BACKEND_PORT}" & +BACKEND_PID=$! 
+trap "kill $BACKEND_PID 2>/dev/null || true" EXIT +sleep 3 +export BACKEND_PORT +runUnittests diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index ca78b2db9..60fdad231 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -19,6 +19,14 @@ func SoundGeneration( doSample *bool, sourceFile *string, sourceDivisor *int32, + think *bool, + caption string, + lyrics string, + bpm *int32, + keyscale string, + language string, + timesignature string, + instrumental *bool, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, @@ -45,8 +53,11 @@ func SoundGeneration( fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav") filePath := filepath.Join(audioDir, fileName) + if filePath, err = filepath.Abs(filePath); err != nil { + return "", nil, fmt.Errorf("failed resolving sound generation path: %w", err) + } - res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{ + req := &proto.SoundGenerationRequest{ Text: text, Model: modelConfig.Model, Dst: filePath, @@ -55,12 +66,38 @@ func SoundGeneration( Temperature: temperature, Src: sourceFile, SrcDivisor: sourceDivisor, - }) - - // return RPC error if any - if !res.Success { - return "", nil, fmt.Errorf("error during sound generation: %s", res.Message) + } + if think != nil { + req.Think = think + } + if caption != "" { + req.Caption = &caption + } + if lyrics != "" { + req.Lyrics = &lyrics + } + if bpm != nil { + req.Bpm = bpm + } + if keyscale != "" { + req.Keyscale = &keyscale + } + if language != "" { + req.Language = &language + } + if timesignature != "" { + req.Timesignature = ×ignature + } + if instrumental != nil { + req.Instrumental = instrumental } - return filePath, res, err + res, err := soundGenModel.SoundGeneration(context.Background(), req) + if err != nil { + return "", nil, err + } + if res != nil && !res.Success { + return "", nil, fmt.Errorf("error during sound generation: %s", res.Message) + } + return filePath, res, nil } diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 5ddf96444..7c4e2afb0 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -100,7 +100,9 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { filePath, _, err := backend.SoundGeneration(text, parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, - inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options) + inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), + nil, "", "", nil, "", "", "", nil, + ml, opts, options) if err != nil { return err diff --git a/core/config/model_config.go b/core/config/model_config.go index 3f45e64ff..b8a2a4ffa 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -671,7 +671,8 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { } if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION { - if c.Backend != "transformers-musicgen" { + soundGenBackends := []string{"transformers-musicgen", "ace-step", "mock-backend"} + if !slices.Contains(soundGenBackends, c.Backend) { return false } } diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index d292b81cd..d634bf81d 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -32,8 +32,22 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml 
*model.ModelLoader xlog.Debug("Sound Generation Request about to be sent to backend", "modelFile", "modelFile", "backend", cfg.Backend) - // TODO: Support uploading files? - filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) + language := input.Language + if language == "" { + language = input.VocalLanguage + } + var bpm *int32 + if input.BPM != nil { + b := int32(*input.BPM) + bpm = &b + } + filePath, _, err := backend.SoundGeneration( + input.Text, input.Duration, input.Temperature, input.DoSample, + nil, nil, + input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale, + language, input.Timesignature, + input.Instrumental, + ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index bfe93224a..b238c01ab 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -318,6 +318,49 @@ func RegisterUIRoutes(app *echo.Echo, return c.Render(200, "views/tts", summary) }) + app.GET("/sound/:model", func(c echo.Context) error { + modelConfigs := cl.GetAllModelsConfigs() + modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) + + summary := map[string]interface{}{ + "Title": "LocalAI - Generate sound with " + c.Param("model"), + "BaseURL": middleware.BaseURL(c), + "ModelsConfig": modelConfigs, + "ModelsWithoutConfig": modelsWithoutConfig, + "Model": c.Param("model"), + "Version": internal.PrintableVersion(), + } + return c.Render(200, "views/sound", summary) + }) + + app.GET("/sound", func(c echo.Context) error { + modelConfigs := cl.GetAllModelsConfigs() + modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) + + if len(modelConfigs)+len(modelsWithoutConfig) == 0 { + return c.Redirect(302, middleware.BaseURL(c)) + } + + modelThatCanBeUsed := "" + title := "LocalAI - Generate sound" + for _, b := range modelConfigs { + if b.HasUsecases(config.FLAG_SOUND_GENERATION) { + modelThatCanBeUsed = b.Name + title = "LocalAI - Generate sound with " + modelThatCanBeUsed + break + } + } + summary := map[string]interface{}{ + "Title": title, + "BaseURL": middleware.BaseURL(c), + "ModelsConfig": modelConfigs, + "ModelsWithoutConfig": modelsWithoutConfig, + "Model": modelThatCanBeUsed, + "Version": internal.PrintableVersion(), + } + return c.Render(200, "views/sound", summary) + }) + app.GET("/video/:model", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) diff --git a/core/http/static/sound.js b/core/http/static/sound.js new file mode 100644 index 000000000..74c2210b7 --- /dev/null +++ b/core/http/static/sound.js @@ -0,0 +1,144 @@ +function showNotification(type, message) { + const existing = document.getElementById('sound-notification'); + if (existing) existing.remove(); + + const notification = document.createElement('div'); + notification.id = 'sound-notification'; + notification.className = 'fixed top-24 right-4 z-50 p-4 rounded-lg shadow-lg flex items-center gap-2 transition-all duration-300'; + if (type === 'error') { + notification.classList.add('bg-red-900/90', 'border', 'border-red-700', 'text-red-200'); + notification.innerHTML = '' + message; + } else { + notification.classList.add('bg-green-900/90', 'border', 'border-green-700', 'text-green-200'); + notification.innerHTML = '' + message; + } + document.body.appendChild(notification); + 
setTimeout(function() { + if (document.getElementById('sound-notification')) { + document.getElementById('sound-notification').remove(); + } + }, 5000); +} + +function buildRequestBody() { + const model = document.getElementById('sound-model').value; + const body = { model_id: model }; + const isSimple = document.getElementById('mode-simple').checked; + + if (isSimple) { + const text = document.getElementById('text').value.trim(); + if (text) body.text = text; + body.instrumental = document.getElementById('instrumental').checked; + const vocal = document.getElementById('vocal_language').value.trim(); + if (vocal) body.vocal_language = vocal; + } else { + const text = document.getElementById('text').value.trim(); + if (text) body.text = text; + const caption = document.getElementById('caption').value.trim(); + if (caption) body.caption = caption; + const lyrics = document.getElementById('lyrics').value.trim(); + if (lyrics) body.lyrics = lyrics; + body.think = document.getElementById('think').checked; + const bpm = document.getElementById('bpm').value.trim(); + if (bpm) body.bpm = parseInt(bpm, 10); + const duration = document.getElementById('duration_seconds').value.trim(); + if (duration) body.duration_seconds = parseFloat(duration); + const keyscale = document.getElementById('keyscale').value.trim(); + if (keyscale) body.keyscale = keyscale; + const language = document.getElementById('language').value.trim(); + if (language) body.language = language; + const timesignature = document.getElementById('timesignature').value.trim(); + if (timesignature) body.timesignature = timesignature; + } + return body; +} + +async function generateSound(event) { + event.preventDefault(); + + const isSimple = document.getElementById('mode-simple').checked; + if (isSimple) { + const text = document.getElementById('text').value.trim(); + if (!text) { + showNotification('error', 'Please enter text (description)'); + return; + } + } else { + const text = document.getElementById('text').value.trim(); + const caption = document.getElementById('caption').value.trim(); + const lyrics = document.getElementById('lyrics').value.trim(); + if (!text && !caption && !lyrics) { + showNotification('error', 'Please enter at least text, caption, or lyrics'); + return; + } + } + + const loader = document.getElementById('loader'); + const resultDiv = document.getElementById('result'); + const generateBtn = document.getElementById('generate-btn'); + + loader.style.display = 'block'; + generateBtn.disabled = true; + resultDiv.innerHTML = '
Generating sound...
'; + + try { + const response = await fetch('v1/sound-generation', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(buildRequestBody()), + }); + + if (!response.ok) { + let errMsg = 'Request failed'; + const ct = response.headers.get('content-type'); + if (ct && ct.indexOf('application/json') !== -1) { + const json = await response.json(); + if (json && json.error && json.error.message) errMsg = json.error.message; + } + resultDiv.innerHTML = '
' + errMsg + '
'; + showNotification('error', 'Failed to generate sound'); + return; + } + + const blob = await response.blob(); + const audioUrl = window.URL.createObjectURL(blob); + + const wrap = document.createElement('div'); + wrap.className = 'flex flex-col items-center gap-4 w-full'; + + const audio = document.createElement('audio'); + audio.controls = true; + audio.src = audioUrl; + audio.className = 'w-full max-w-md'; + + const actions = document.createElement('div'); + actions.className = 'flex flex-wrap justify-center gap-3'; + + const downloadLink = document.createElement('a'); + downloadLink.href = audioUrl; + downloadLink.download = 'sound-' + new Date().toISOString().slice(0, 10) + '.wav'; + downloadLink.className = 'inline-flex items-center gap-2 px-4 py-2 rounded-lg bg-[var(--color-primary)] text-[var(--color-bg-primary)] hover:opacity-90 transition'; + downloadLink.innerHTML = ' Download'; + + actions.appendChild(downloadLink); + wrap.appendChild(audio); + wrap.appendChild(actions); + resultDiv.innerHTML = ''; + resultDiv.appendChild(wrap); + + audio.play().catch(function() {}); + showNotification('success', 'Sound generated successfully'); + } catch (err) { + console.error('Sound generation error:', err); + resultDiv.innerHTML = '
Network error
'; + showNotification('error', 'Network error'); + } finally { + loader.style.display = 'none'; + generateBtn.disabled = false; + } +} + +document.addEventListener('DOMContentLoaded', function() { + document.getElementById('sound-form').addEventListener('submit', generateSound); + document.getElementById('loader').style.display = 'none'; +}); diff --git a/core/http/views/manage.html b/core/http/views/manage.html index f117b0888..1e491ecfa 100644 --- a/core/http/views/manage.html +++ b/core/http/views/manage.html @@ -324,6 +324,11 @@ TTS {{ end }} + {{ if eq . "FLAG_SOUND_GENERATION" }} + + Sound + + {{ end }} {{ end }} diff --git a/core/http/views/partials/navbar.html b/core/http/views/partials/navbar.html index ea90bb8d1..3fb29fec6 100644 --- a/core/http/views/partials/navbar.html +++ b/core/http/views/partials/navbar.html @@ -34,6 +34,9 @@ TTS + + Sound + Talk @@ -97,6 +100,9 @@ TTS + + Sound + Talk diff --git a/core/http/views/sound.html b/core/http/views/sound.html new file mode 100644 index 000000000..e3c7a4a1a --- /dev/null +++ b/core/http/views/sound.html @@ -0,0 +1,185 @@ + + +{{template "views/partials/head" .}} + + + +
+    <!-- page body markup not recoverable from this dump; it renders: -->
+    {{template "views/partials/navbar" .}}
+    <!-- heading: Sound Generation {{ if .Model }} with {{.Model}} {{ end }} -->
+    <!-- subtitle: Generate music and audio from descriptions or structured prompts -->
+    <!-- hint: Use Simple mode (text as description + vocal language) or switch to Advanced mode for caption, lyrics, BPM, key, and more. -->
+    <!-- Simple/Advanced "Mode:" toggle, the form fields, generate button, loader and result area consumed by sound.js -->
+    {{template "views/partials/footer" .}}
+ + + + diff --git a/core/schema/elevenlabs.go b/core/schema/elevenlabs.go index df8a8d7c4..a8687b165 100644 --- a/core/schema/elevenlabs.go +++ b/core/schema/elevenlabs.go @@ -12,6 +12,17 @@ type ElevenLabsSoundGenerationRequest struct { Duration *float32 `json:"duration_seconds,omitempty" yaml:"duration_seconds,omitempty"` Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"` DoSample *bool `json:"do_sample,omitempty" yaml:"do_sample,omitempty"` + // Advanced mode + Think *bool `json:"think,omitempty" yaml:"think,omitempty"` + Caption string `json:"caption,omitempty" yaml:"caption,omitempty"` + Lyrics string `json:"lyrics,omitempty" yaml:"lyrics,omitempty"` + BPM *int `json:"bpm,omitempty" yaml:"bpm,omitempty"` + Keyscale string `json:"keyscale,omitempty" yaml:"keyscale,omitempty"` + Language string `json:"language,omitempty" yaml:"language,omitempty"` + VocalLanguage string `json:"vocal_language,omitempty" yaml:"vocal_language,omitempty"` + Timesignature string `json:"timesignature,omitempty" yaml:"timesignature,omitempty"` + // Simple mode: use text as description; optional instrumental / vocal_language + Instrumental *bool `json:"instrumental,omitempty" yaml:"instrumental,omitempty"` } func (elttsr *ElevenLabsTTSRequest) ModelName(s *string) string { diff --git a/swagger/docs.go b/swagger/docs.go index 19554e152..ff967336c 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -2123,12 +2123,30 @@ const docTemplate = `{ "schema.ElevenLabsSoundGenerationRequest": { "type": "object", "properties": { + "bpm": { + "type": "integer" + }, + "caption": { + "type": "string" + }, "do_sample": { "type": "boolean" }, "duration_seconds": { "type": "number" }, + "instrumental": { + "type": "boolean" + }, + "keyscale": { + "type": "string" + }, + "language": { + "type": "string" + }, + "lyrics": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -2137,6 +2155,15 @@ const docTemplate = `{ }, "text": { "type": "string" + }, + "think": { + "type": "boolean" + }, + "timesignature": { + "type": "string" + }, + "vocal_language": { + "type": "string" } } }, diff --git a/swagger/swagger.json b/swagger/swagger.json index 4fdba993b..07ecd372e 100644 --- a/swagger/swagger.json +++ b/swagger/swagger.json @@ -2116,12 +2116,30 @@ "schema.ElevenLabsSoundGenerationRequest": { "type": "object", "properties": { + "bpm": { + "type": "integer" + }, + "caption": { + "type": "string" + }, "do_sample": { "type": "boolean" }, "duration_seconds": { "type": "number" }, + "instrumental": { + "type": "boolean" + }, + "keyscale": { + "type": "string" + }, + "language": { + "type": "string" + }, + "lyrics": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -2130,6 +2148,15 @@ }, "text": { "type": "string" + }, + "think": { + "type": "boolean" + }, + "timesignature": { + "type": "string" + }, + "vocal_language": { + "type": "string" } } }, diff --git a/swagger/swagger.yaml b/swagger/swagger.yaml index d32a3c573..4b6889ed9 100644 --- a/swagger/swagger.yaml +++ b/swagger/swagger.yaml @@ -403,16 +403,34 @@ definitions: type: object schema.ElevenLabsSoundGenerationRequest: properties: + bpm: + type: integer + caption: + type: string do_sample: type: boolean duration_seconds: type: number + instrumental: + type: boolean + keyscale: + type: string + language: + type: string + lyrics: + type: string model_id: type: string prompt_influence: type: number text: type: string + think: + type: boolean + timesignature: + type: string + vocal_language: + type: string type: object 
schema.FunctionCall: properties: diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index 375d4d7c0..66d9d6cd7 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -109,11 +109,14 @@ var _ = BeforeSuite(func() { // Create application appCtx, appCancel = context.WithCancel(context.Background()) - // Create application instance + // Create application instance (GeneratedContentDir so sound-generation/TTS can write files the handler sends) + generatedDir := filepath.Join(tmpDir, "generated") + Expect(os.MkdirAll(generatedDir, 0750)).To(Succeed()) application, err := application.New( config.WithContext(appCtx), config.WithSystemState(systemState), config.WithDebug(true), + config.WithGeneratedContentDir(generatedDir), ) Expect(err).ToNot(HaveOccurred()) diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go index 824474613..79a71fd5b 100644 --- a/tests/e2e/mock-backend/main.go +++ b/tests/e2e/mock-backend/main.go @@ -2,12 +2,14 @@ package main import ( "context" + "encoding/binary" "encoding/json" "flag" "fmt" "log" "net" "os" + "path/filepath" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/xlog" @@ -142,13 +144,64 @@ func (m *MockBackend) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamSe } func (m *MockBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) { - xlog.Debug("SoundGeneration called", "text", in.Text) + xlog.Debug("SoundGeneration called", + "text", in.Text, + "caption", in.GetCaption(), + "lyrics", in.GetLyrics(), + "think", in.GetThink(), + "bpm", in.GetBpm(), + "keyscale", in.GetKeyscale(), + "language", in.GetLanguage(), + "timesignature", in.GetTimesignature(), + "instrumental", in.GetInstrumental()) + dst := in.GetDst() + if dst != "" { + if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil { + return &pb.Result{Message: err.Error(), Success: false}, nil + } + if err := writeMinimalWAV(dst); err != nil { + return &pb.Result{Message: err.Error(), Success: false}, nil + } + } return &pb.Result{ Message: "Sound generated successfully (mocked)", Success: true, }, nil } +// writeMinimalWAV writes a minimal valid WAV file (short silence) so the HTTP handler can send it. 
+func writeMinimalWAV(path string) error { + const sampleRate = 16000 + const numChannels = 1 + const bitsPerSample = 16 + const numSamples = 1600 // 0.1s + dataSize := numSamples * numChannels * (bitsPerSample / 8) + const headerLen = 44 + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + // RIFF header + _, _ = f.Write([]byte("RIFF")) + _ = binary.Write(f, binary.LittleEndian, uint32(headerLen-8+dataSize)) + _, _ = f.Write([]byte("WAVE")) + // fmt chunk + _, _ = f.Write([]byte("fmt ")) + _ = binary.Write(f, binary.LittleEndian, uint32(16)) + _ = binary.Write(f, binary.LittleEndian, uint16(1)) + _ = binary.Write(f, binary.LittleEndian, uint16(numChannels)) + _ = binary.Write(f, binary.LittleEndian, uint32(sampleRate)) + _ = binary.Write(f, binary.LittleEndian, uint32(sampleRate*numChannels*(bitsPerSample/8))) + _ = binary.Write(f, binary.LittleEndian, uint16(numChannels*(bitsPerSample/8))) + _ = binary.Write(f, binary.LittleEndian, uint16(bitsPerSample)) + // data chunk + _, _ = f.Write([]byte("data")) + _ = binary.Write(f, binary.LittleEndian, uint32(dataSize)) + _, _ = f.Write(make([]byte, dataSize)) + return nil +} + func (m *MockBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { xlog.Debug("AudioTranscription called") return &pb.TranscriptResult{ diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go index 1e2f739b8..241fbae2b 100644 --- a/tests/e2e/mock_backend_test.go +++ b/tests/e2e/mock_backend_test.go @@ -96,6 +96,34 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { }) }) + Describe("Sound Generation API", func() { + It("should generate mocked sound (simple mode)", func() { + body := `{"model_id":"mock-model","text":"a soft Bengali love song for a quiet evening","instrumental":false,"vocal_language":"bn"}` + req, err := http.NewRequest("POST", apiURL+"/sound-generation", io.NopCloser(strings.NewReader(body))) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + }) + + It("should generate mocked sound (advanced mode)", func() { + body := `{"model_id":"mock-model","text":"upbeat pop","caption":"A funky Japanese disco track","lyrics":"[Verse 1]\nTest lyrics","think":true,"bpm":120,"duration_seconds":225,"keyscale":"Ab major","language":"ja","timesignature":"4"}` + req, err := http.NewRequest("POST", apiURL+"/sound-generation", io.NopCloser(strings.NewReader(body))) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + }) + }) + Describe("Image Generation API", func() { It("should generate mocked image", func() { req, err := http.NewRequest("POST", apiURL+"/images/generations", nil)