exo_bench: add --skip-instance-setup flag and delete existing instances before instance creation

This commit is contained in:
ciaranbor
2026-04-04 18:12:07 +01:00
parent a865e9fcde
commit 290bfc28db

View File

@@ -370,6 +370,11 @@ def main() -> int:
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
ap.add_argument(
"--skip-instance-setup",
action="store_true",
help="Skip planning/download/placement (assumes model instance is already running).",
)
ap.add_argument(
"--all-combinations",
action="store_true",
@@ -440,81 +445,110 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
if args.skip_instance_setup:
selected = []
download_duration_s = None
logger.info("Skipping instance setup — assuming model is already running")
else:
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
if args.dry_run:
return 0
if not selected:
logger.error("No valid placements matched your filters.")
return 1
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.info("Planning phase: checking downloads...")
download_duration_s = run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
if download_duration_s is not None:
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
else:
logger.info("Download: model already cached")
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
)
cluster_snapshot = capture_cluster_snapshot(client)
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
download_duration_s = run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
if download_duration_s is not None:
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
else:
logger.info("Download: model already cached")
cluster_snapshot = capture_cluster_snapshot(client) if not args.skip_instance_setup else {}
all_rows: list[dict[str, Any]] = []
all_system_metrics: dict[str, dict[str, dict[str, float]]] = {}
if args.skip_instance_setup:
# Run benchmark directly without instance management
selected = [None] # Single iteration, no placement info
for preview in selected:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
if preview is not None:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
# Delete any existing instances to free resources before placing
try:
state = client.request_json("GET", "/state")
for old_id in list(state.get("instances", {}).keys()):
logger.info(f"Deleting stale instance {old_id}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{old_id}")
if state.get("instances"):
time.sleep(2)
except Exception:
pass
time.sleep(1)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)
else:
instance_id = None
sharding = "unknown"
instance_meta = "unknown"
n_nodes = 0
logger.info("=" * 80)
logger.info("SKIP-INSTANCE-SETUP: using existing running instance")
sampler: SystemMetricsSampler | None = None
if not args.no_system_metrics:
if not args.no_system_metrics and preview is not None:
nids = node_ids_from_instance(instance)
sampler = SystemMetricsSampler(
ExoClient(args.host, args.port, timeout_s=30),
@@ -737,15 +771,16 @@ def main() -> int:
if placement_metrics:
all_system_metrics.update(placement_metrics)
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
if instance_id is not None:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
time.sleep(5)
time.sleep(5)
output: dict[str, Any] = {"runs": all_rows}
if cluster_snapshot: