diff --git a/backend/python/ace-step/backend.py b/backend/python/ace-step/backend.py index 6b61b30cc..805dd16e1 100644 --- a/backend/python/ace-step/backend.py +++ b/backend/python/ace-step/backend.py @@ -47,86 +47,6 @@ MODEL_REPO_MAPPING = { } DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5" - -def _can_access_google(timeout=3.0): - """Check if Google is accessible (to choose HuggingFace vs ModelScope).""" - import socket - try: - socket.setdefaulttimeout(timeout) - socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect(("www.google.com", 443)) - return True - except (socket.timeout, socket.error, OSError): - return False - - -def _download_from_huggingface(repo_id, local_dir, model_name): - """Download model from HuggingFace Hub.""" - from huggingface_hub import snapshot_download - is_unified = repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" - if is_unified: - download_dir = local_dir - print(f"[ace-step] Downloading unified repo {repo_id} to {download_dir}...", file=sys.stderr) - else: - download_dir = os.path.join(local_dir, model_name) - os.makedirs(download_dir, exist_ok=True) - print(f"[ace-step] Downloading {model_name} from {repo_id} to {download_dir}...", file=sys.stderr) - snapshot_download( - repo_id=repo_id, - local_dir=download_dir, - local_dir_use_symlinks=False, - ) - return os.path.join(local_dir, model_name) - - -def _download_from_modelscope(repo_id, local_dir, model_name): - """Download model from ModelScope.""" - from modelscope import snapshot_download - is_unified = repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" - if is_unified: - download_dir = local_dir - print(f"[ace-step] Downloading unified repo {repo_id} from ModelScope to {download_dir}...", file=sys.stderr) - else: - download_dir = os.path.join(local_dir, model_name) - os.makedirs(download_dir, exist_ok=True) - print(f"[ace-step] Downloading {model_name} from ModelScope {repo_id} to {download_dir}...", file=sys.stderr) - snapshot_download( - model_id=repo_id, - local_dir=download_dir, - ) - return os.path.join(local_dir, model_name) - - -def _ensure_model_downloaded(model_name, checkpoint_dir): - """ - Ensure model is present; download from HuggingFace or ModelScope if missing. - model_name: e.g. "acestep-v15-turbo", "vae", "acestep-5Hz-lm-0.6B" - checkpoint_dir: directory that will contain model_name as a subdir. - Returns path to the model directory. - """ - return None - if not model_name or not checkpoint_dir: - return None - model_path = os.path.join(checkpoint_dir, model_name) - if os.path.exists(model_path) and os.listdir(model_path): - print(f"[ace-step] Model {model_name} already at {model_path}", file=sys.stderr) - return model_path - repo_id = MODEL_REPO_MAPPING.get(model_name, DEFAULT_REPO_ID) - print(f"[ace-step] Model {model_name} not found, downloading...", file=sys.stderr) - use_hf = _can_access_google() - if use_hf: - try: - return _download_from_huggingface(repo_id, checkpoint_dir, model_name) - except Exception as e: - print(f"[ace-step] HuggingFace download failed: {e}, trying ModelScope", file=sys.stderr) - return _download_from_modelscope(repo_id, checkpoint_dir, model_name) - else: - try: - return _download_from_modelscope(repo_id, checkpoint_dir, model_name) - except Exception as e: - print(f"[ace-step] ModelScope download failed: {e}, trying HuggingFace", file=sys.stderr) - return _download_from_huggingface(repo_id, checkpoint_dir, model_name) - - def _is_float(s): try: float(s) @@ -402,17 +322,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): config_path = model_name or os.path.basename(model_dir.rstrip("/\\")) os.makedirs(self.checkpoint_dir, exist_ok=True) - # Auto-download DiT model and VAE if missing (same as upstream) - if config_path: - try: - _ensure_model_downloaded(config_path, self.checkpoint_dir) - except Exception as e: - print(f"[ace-step] Warning: DiT model download failed: {e}", file=sys.stderr) - try: - _ensure_model_downloaded("vae", self.checkpoint_dir) - except Exception as e: - print(f"[ace-step] Warning: VAE download failed: {e}", file=sys.stderr) - self.dit_handler = AceStepHandler() # Patch handler so it uses our model dir instead of site-packages/checkpoints self.dit_handler._get_project_root = lambda: self.project_root @@ -438,11 +347,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.llm_handler = None if self.options.get("init_lm", True): lm_model = self.options.get("lm_model_path", "acestep-5Hz-lm-0.6B") - if lm_model: - try: - _ensure_model_downloaded(lm_model, self.checkpoint_dir) - except Exception as e: - print(f"[ace-step] Warning: LM model download failed: {e}", file=sys.stderr) self.llm_handler = LLMHandler() lm_backend = (self.options.get("lm_backend") or "vllm").strip().lower() if lm_backend not in ("vllm", "pt"):