diff --git a/bench/exo_bench.py b/bench/exo_bench.py index 7acbfb743..621282dad 100644 --- a/bench/exo_bench.py +++ b/bench/exo_bench.py @@ -370,6 +370,11 @@ def main() -> int: ap.add_argument( "--dry-run", action="store_true", help="List selected placements and exit." ) + ap.add_argument( + "--skip-instance-setup", + action="store_true", + help="Skip planning/download/placement (assumes model instance is already running).", + ) ap.add_argument( "--all-combinations", action="store_true", @@ -440,81 +445,110 @@ def main() -> int: logger.error("[exo-bench] tokenizer usable but prompt sizing failed") raise - selected = settle_and_fetch_placements( - client, full_model_id, args, settle_timeout=args.settle_timeout - ) - - if not selected: - logger.error("No valid placements matched your filters.") - return 1 - - selected.sort( - key=lambda p: ( - str(p.get("instance_meta", "")), - str(p.get("sharding", "")), - nodes_used_in_instance(p["instance"]), - ), - reverse=True, - ) - - logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}") - logger.info(f"placements: {len(selected)}") - for p in selected: - logger.info( - f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}" + if args.skip_instance_setup: + selected = [] + download_duration_s = None + logger.info("Skipping instance setup — assuming model is already running") + else: + selected = settle_and_fetch_placements( + client, full_model_id, args, settle_timeout=args.settle_timeout ) - if args.dry_run: - return 0 + if not selected: + logger.error("No valid placements matched your filters.") + return 1 - settle_deadline = ( - time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None - ) + selected.sort( + key=lambda p: ( + str(p.get("instance_meta", "")), + str(p.get("sharding", "")), + nodes_used_in_instance(p["instance"]), + ), + reverse=True, + ) - logger.info("Planning phase: checking downloads...") - download_duration_s = run_planning_phase( - client, - full_model_id, - selected[0], - args.danger_delete_downloads, - args.timeout, - settle_deadline, - ) - if download_duration_s is not None: - logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)") - else: - logger.info("Download: model already cached") + logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}") + logger.info(f"placements: {len(selected)}") + for p in selected: + logger.info( + f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}" + ) - cluster_snapshot = capture_cluster_snapshot(client) + if args.dry_run: + return 0 + + settle_deadline = ( + time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None + ) + + logger.info("Planning phase: checking downloads...") + download_duration_s = run_planning_phase( + client, + full_model_id, + selected[0], + args.danger_delete_downloads, + args.timeout, + settle_deadline, + ) + if download_duration_s is not None: + logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)") + else: + logger.info("Download: model already cached") + + cluster_snapshot = capture_cluster_snapshot(client) if not args.skip_instance_setup else {} all_rows: list[dict[str, Any]] = [] all_system_metrics: dict[str, dict[str, dict[str, float]]] = {} + if args.skip_instance_setup: + # Run benchmark directly without instance management + selected = [None] # Single iteration, no placement info + for preview in selected: - instance = preview["instance"] - instance_id = instance_id_from_instance(instance) + if preview is not None: + instance = preview["instance"] + instance_id = instance_id_from_instance(instance) - sharding = str(preview["sharding"]) - instance_meta = str(preview["instance_meta"]) - n_nodes = nodes_used_in_instance(instance) + sharding = str(preview["sharding"]) + instance_meta = str(preview["instance_meta"]) + n_nodes = nodes_used_in_instance(instance) - logger.info("=" * 80) - logger.info( - f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}" - ) + logger.info("=" * 80) + logger.info( + f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}" + ) - client.request_json("POST", "/instance", body={"instance": instance}) - try: - wait_for_instance_ready(client, instance_id) - except (RuntimeError, TimeoutError) as e: - logger.error(f"Failed to initialize placement: {e}") - with contextlib.suppress(ExoHttpError): - client.request_json("DELETE", f"/instance/{instance_id}") - continue + # Delete any existing instances to free resources before placing + try: + state = client.request_json("GET", "/state") + for old_id in list(state.get("instances", {}).keys()): + logger.info(f"Deleting stale instance {old_id}") + with contextlib.suppress(ExoHttpError): + client.request_json("DELETE", f"/instance/{old_id}") + if state.get("instances"): + time.sleep(2) + except Exception: + pass - time.sleep(1) + client.request_json("POST", "/instance", body={"instance": instance}) + try: + wait_for_instance_ready(client, instance_id) + except (RuntimeError, TimeoutError) as e: + logger.error(f"Failed to initialize placement: {e}") + with contextlib.suppress(ExoHttpError): + client.request_json("DELETE", f"/instance/{instance_id}") + continue + + time.sleep(1) + else: + instance_id = None + sharding = "unknown" + instance_meta = "unknown" + n_nodes = 0 + logger.info("=" * 80) + logger.info("SKIP-INSTANCE-SETUP: using existing running instance") sampler: SystemMetricsSampler | None = None - if not args.no_system_metrics: + if not args.no_system_metrics and preview is not None: nids = node_ids_from_instance(instance) sampler = SystemMetricsSampler( ExoClient(args.host, args.port, timeout_s=30), @@ -737,15 +771,16 @@ def main() -> int: if placement_metrics: all_system_metrics.update(placement_metrics) - try: - client.request_json("DELETE", f"/instance/{instance_id}") - except ExoHttpError as e: - if e.status != 404: - raise - wait_for_instance_gone(client, instance_id) - logger.debug(f"Deleted instance {instance_id}") + if instance_id is not None: + try: + client.request_json("DELETE", f"/instance/{instance_id}") + except ExoHttpError as e: + if e.status != 404: + raise + wait_for_instance_gone(client, instance_id) + logger.debug(f"Deleted instance {instance_id}") - time.sleep(5) + time.sleep(5) output: dict[str, Any] = {"runs": all_rows} if cluster_snapshot: