diff --git a/backend/python/ace-step/backend.py b/backend/python/ace-step/backend.py index 805dd16e1..56ce1314a 100644 --- a/backend/python/ace-step/backend.py +++ b/backend/python/ace-step/backend.py @@ -28,6 +28,7 @@ from acestep.inference import ( ) from acestep.handler import AceStepHandler from acestep.llm_inference import LLMHandler +from acestep.model_downloader import ensure_lm_model _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -347,6 +348,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.llm_handler = None if self.options.get("init_lm", True): lm_model = self.options.get("lm_model_path", "acestep-5Hz-lm-0.6B") + + # Ensure LM model is downloaded before initializing + try: + from pathlib import Path + lm_success, lm_msg = ensure_lm_model( + model_name=lm_model, + checkpoints_dir=Path(self.checkpoint_dir), + prefer_source=None, # Auto-detect HuggingFace vs ModelScope + ) + if not lm_success: + print(f"[ace-step] Warning: LM model download failed: {lm_msg}", file=sys.stderr) + # Continue anyway - LLM initialization will fail gracefully + else: + print(f"[ace-step] LM model ready: {lm_msg}", file=sys.stderr) + except Exception as e: + print(f"[ace-step] Warning: LM model download check failed: {e}", file=sys.stderr) + # Continue anyway - LLM initialization will fail gracefully + self.llm_handler = LLMHandler() lm_backend = (self.options.get("lm_backend") or "vllm").strip().lower() if lm_backend not in ("vllm", "pt"): @@ -379,26 +398,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): "sample_mode": True, "thinking": True, "vocal_language": request.language or request.GetLanguage() or "en", - "instrumental": request.GetInstrumental() if request.HasField("instrumental") else False, + "instrumental": request.instrumental if request.HasField("instrumental") else False, } else: caption = request.caption or request.GetCaption() or request.text payload = { "prompt": caption, - "lyrics": request.lyrics or request.GetLyrics() or "", - "thinking": request.GetThink() if request.HasField("think") else False, + "lyrics": request.lyrics or request.lyrics or "", + "thinking": request.think if request.HasField("think") else False, "vocal_language": request.language or request.GetLanguage() or "en", } if request.HasField("bpm"): - payload["bpm"] = request.GetBpm() - if request.HasField("keyscale") and request.GetKeyscale(): - payload["key_scale"] = request.GetKeyscale() - if request.HasField("timesignature") and request.GetTimesignature(): - payload["time_signature"] = request.GetTimesignature() + payload["bpm"] = request.bpm + if request.HasField("keyscale") and request.keyscale: + payload["key_scale"] = request.keyscale + if request.HasField("timesignature") and request.timesignature: + payload["time_signature"] = request.timesignature if request.HasField("duration") and request.duration: payload["audio_duration"] = int(request.duration) if request.duration else None - if request.Src: - payload["src_audio_path"] = request.Src + if request.src: + payload["src_audio_path"] = request.src _generate_audio_sync(self, payload, request.dst) return backend_pb2.Result(success=True, message="Sound generated successfully")