Files
exo/bench/prefill_decode_bench.py
ciaranbor fa57131374 Integration tests infra (#1995)
## Motivation

No automated integration tests exist for exo. Manual testing against
real hardware clusters is slow and error-prone. We need a pytest
framework that deploys clusters via `eco`, runs inference scenarios, and
tears down cleanly.

## Changes

- **`tools/src/exo_tools/`** — New workspace member shared by bench,
eval, and tests:
- `client.py` — `ExoClient` HTTP client (extracted from
`bench/harness.py`)
- `harness.py` — instance lifecycle helpers (placement, wait-for-ready,
etc.)
- `cluster.py` — `EcoSession` for eco cluster lifecycle
(deploy/stop/start/release/logs/exec) with unique `USER=<prefix>-<uuid>`
per session and atexit/signal cleanup
- **`tests/integration/`** — 17 pytest tests across 5 files:
- `test_1node.py` — place, chat, multi-turn, delete, state/models
endpoints, cluster snapshot, download-from-scratch
- `test_2node.py` — parametrized tensor/jaccl + pipeline/ring inference
and multi-turn
- `test_4node.py` — parametrized 4-node pipeline/ring inference, cluster
state
- `test_resilience.py` — full disconnect/reconnect cycle (2-node →
disconnect → 1-node → reconnect → 2-node)
- `test_dashboard.py` — Playwright: dashboard loads, shows node info,
chat flow
- `helpers.py` — placement/inference helpers, re-exports from
`exo_tools`
- `conftest.py` — session-scoped cluster fixtures with constraint-based
eco reservations; `--hosts` override; `EXO_REF` env var for CI
deployments from a GitHub branch
- **`bench/`** — Updated imports from `exo_tools.client` /
`exo_tools.harness`
- **`pyproject.toml`** — Added `tools` workspace member, `playwright`
dev dep, `--ignore=tests/integration`

## Why It Works

Tests use `eco` for cluster lifecycle and `ExoClient` for API
interactions — same tools humans use. Session-scoped fixtures deploy
once per file. Unique eco users prevent test runs from interfering with
each other or manual usage.

## Test Plan

### Automated Testing

- `uv run pytest tests/integration/ -v -s` — full suite (~4-5 min, 17/17
passing)
- `uv run pytest tests/integration/ -v -s --hosts s4,s9,s10,s22` — pin
specific hosts
- `EXO_REF=main uv run pytest tests/integration/ -v` — deploy from a
GitHub branch (CI)
- `uv run pytest` — confirms integration tests are excluded from default
runs
2026-05-08 17:15:08 +01:00

785 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# type: ignore
#!/usr/bin/env python3
"""Disaggregated prefill-decode benchmark for exo (MLX → MLX).
Spins up two MLX instances on the cluster, marks one as Prefill source and
the other as Decode target via /v1/instance-links, then sends chat
completions to the API. The master routes the request to the decode
instance and stamps `prefill_endpoint` pointing at the prefill instance —
the worker decides per-request whether to ship prefill remotely
(uncached_count > REMOTE_PREFILL_MIN_TOKENS).
Usage:
uv run python bench/prefill_decode_bench.py --model <id> --pp 2048,8192 --tg 128
uv run python bench/prefill_decode_bench.py --model <id> --pp 4096 --tg 128 --repeat 3
uv run python bench/prefill_decode_bench.py --model <id> --pp 2048 --tg 128 --dry-run
"""
from __future__ import annotations
import argparse
import contextlib
import copy
import itertools
import json
import sys
import time
import tomllib
from pathlib import Path
from statistics import mean
from typing import Any
from exo_bench import (
PromptSizer,
format_peak_memory,
load_tokenizer_for_bench,
parse_int_list,
)
from exo_tools.client import ExoClient, ExoHttpError
from exo_tools.harness import (
add_common_instance_args,
instance_id_from_instance,
node_ids_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
unwrap_instance,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
def _node_id_to_friendly(client: ExoClient) -> dict[str, str]:
identities = client.get_node_identities() or {}
out: dict[str, str] = {}
for node_id, identity in identities.items():
if isinstance(identity, dict):
name = identity.get("friendlyName") or identity.get("friendly_name")
if isinstance(name, str):
out[str(node_id)] = name
return out
def _placement_node_friendly_names(
    placement: dict[str, Any], id_to_friendly: dict[str, str]
) -> list[str]:
    """Return display names for the nodes of ``placement["instance"]``.

    Node ids come from ``node_ids_from_instance``; an id with no entry in
    ``id_to_friendly`` is returned unchanged.
    """
    names: list[str] = []
    for node_id in node_ids_from_instance(placement["instance"]):
        names.append(id_to_friendly.get(node_id, node_id))
    return names
def _filter_by_node(
    placements: list[dict[str, Any]],
    friendly_name: str,
    id_to_friendly: dict[str, str],
) -> list[dict[str, Any]]:
    """Keep the placements that involve a node whose name matches ``friendly_name``.

    Matching is case-insensitive and by substring (an exact match is simply
    the substring case where both strings are equal). Input order is kept.
    """
    needle = friendly_name.lower()
    return [
        placement
        for placement in placements
        if any(
            needle in name.lower()
            for name in _placement_node_friendly_names(placement, id_to_friendly)
        )
    ]
def _node_id_by_friendly(id_to_friendly: dict[str, str], target: str) -> str | None:
target_lc = target.lower()
for nid, name in id_to_friendly.items():
if target_lc == name.lower() or target_lc in name.lower():
return nid
return None
def _load_toml(path: str) -> dict[str, Any]:
with Path(path).open("rb") as f:
return tomllib.load(f)
# Top-level TOML keys that _merge_toml_into_args may copy onto the parsed-args
# namespace. Any other top-level key in the config file is skipped there.
_TOP_LEVEL_TOML_KEYS = {
    "host",
    "port",
    "timeout",
    "settle_timeout",
    "model",
    "pp",
    "tg",
    "repeat",
    "warmup",
    "json_out",
    "instance_meta",
    "sharding",
    "min_nodes",
    "max_nodes",
    "force_download",
    "danger_delete_downloads",
    "all_combinations",
}
def _inject_toml_into_argv() -> None:
    """Pre-load the ``--config`` TOML and inject required CLI args into ``sys.argv``.

    argparse declares ``--model``, ``--pp`` and ``--tg`` as required, so values
    that live only in the config file must be on the command line before
    ``parse_args()`` runs. For each of those flags that is absent from
    ``sys.argv`` and has a truthy value in the config, the flag and its
    value(s) are appended to ``sys.argv`` in place.

    Both spellings argparse accepts are recognised: ``--config PATH`` and
    ``--config=PATH``. The first occurrence of either is used. Nothing happens
    when no config flag is present or when ``--config`` is the last token.
    """
    argv = sys.argv

    cfg_path: str | None = None
    for pos, arg in enumerate(argv):
        if arg == "--config":
            if pos + 1 < len(argv):
                cfg_path = argv[pos + 1]
            break
        if arg.startswith("--config="):
            cfg_path = arg.partition("=")[2]
            break
    if cfg_path is None:
        return

    cfg = _load_toml(cfg_path)
    decode = cfg.get("decode", {})

    def _has(flag: str) -> bool:
        """Report whether ``flag`` is already on the command line in either form."""
        return any(a == flag or a.startswith(flag + "=") for a in argv)

    # --model: prefer top-level, then [decode].model
    if not _has("--model"):
        model = cfg.get("model") or decode.get("model")
        if model:
            argv.extend(["--model", str(model)])

    # --pp / --tg accept one or more values; a TOML scalar becomes a single value.
    for flag, key in (("--pp", "pp"), ("--tg", "tg")):
        if _has(flag):
            continue
        value = cfg.get(key)
        if not value:
            continue
        items = value if isinstance(value, list) else [value]
        argv.extend([flag, *(str(x) for x in items)])
def _merge_toml_into_args(
    args: argparse.Namespace,
    cfg: dict[str, Any],
    argv: list[str] | None = None,
) -> None:
    """Copy top-level TOML settings onto ``args`` unless the CLI set them.

    Only keys listed in ``_TOP_LEVEL_TOML_KEYS`` are considered, so the
    ``[prefill]`` / ``[decode]`` tables are never copied. A key is treated as
    set on the command line when ``--key`` or ``--key-with-dashes`` appears in
    ``argv``, either as its own token or as ``--flag=value``; such keys keep
    their CLI value.

    Looking at the command line, rather than at whether the current value is
    ``None``/``[]``/``False``, is what lets TOML override options whose
    argparse default is truthy (e.g. ``repeat`` or ``json_out``). It also stops
    an explicit ``--warmup 0`` from being replaced (``0 == False``).

    Args:
        args: Parsed namespace, mutated in place.
        cfg: Parsed TOML document.
        argv: Command-line tokens to inspect; defaults to ``sys.argv[1:]``.

    Note:
        Flags given through a short alias or an argparse prefix abbreviation
        are not detected, so the TOML value would win in that case.
    """
    cli_tokens = sys.argv[1:] if argv is None else argv

    def _given_on_cli(key: str) -> bool:
        """Report whether the long option for ``key`` appears in ``cli_tokens``."""
        flags = {f"--{key}", f"--{key.replace('_', '-')}"}
        return any(
            tok == flag or tok.startswith(flag + "=")
            for tok in cli_tokens
            for flag in flags
        )

    for key, value in cfg.items():
        if key not in _TOP_LEVEL_TOML_KEYS:
            continue
        if _given_on_cli(key):
            # CLI flags override toml values.
            continue
        setattr(args, key, value)
def _side_args(
base: argparse.Namespace, overrides: dict[str, Any]
) -> argparse.Namespace:
out = copy.copy(base)
for k in (
"instance_meta",
"sharding",
"min_nodes",
"max_nodes",
"skip_pipeline_jaccl",
"skip_tensor_ring",
):
if k in overrides:
setattr(out, k, overrides[k])
return out
def _pick_two_distinct_placements(
placements: list[dict[str, Any]],
) -> tuple[dict[str, Any], dict[str, Any]] | None:
if len(placements) < 2:
return None
seen_nodes: set[tuple[str, ...]] = set()
chosen: list[dict[str, Any]] = []
for p in placements:
nodes = tuple(sorted(str(n) for n in p.get("nodes", [])))
if nodes in seen_nodes:
continue
seen_nodes.add(nodes)
chosen.append(p)
if len(chosen) == 2:
return chosen[0], chosen[1]
return None
def _create_instance_link(
client: ExoClient,
prefill_instance_id: str,
decode_instance_id: str,
) -> str:
out = client.request_json(
"POST",
"/v1/instance-links",
body={
"prefill_instances": [prefill_instance_id],
"decode_instances": [decode_instance_id],
},
)
return str(out.get("commandId", ""))
def _list_instance_links(client: ExoClient) -> list[dict[str, Any]]:
out = client.request_json("GET", "/v1/instance-links")
return out if isinstance(out, list) else []
def _delete_instance_link(client: ExoClient, link_id: str) -> None:
client.request_json("DELETE", f"/v1/instance-links/{link_id}")
def run_one(
    client: ExoClient,
    model_id: str,
    pp_hint: int,
    tg: int,
    prompt_sizer: PromptSizer,
) -> tuple[dict[str, Any], int]:
    """Send one non-streaming bench chat completion and time it.

    The prompt is built by ``prompt_sizer.build(pp_hint)``, which also yields
    the prompt's token count.

    Returns:
        ``(row, pp_tokens)`` where ``row`` holds the wall-clock ``elapsed_s``,
        the first 200 characters of the reply (``output_text_preview``) and
        the response's ``generation_stats`` (``None`` if absent).
    """
    prompt, prompt_token_count = prompt_sizer.build(pp_hint)
    request: dict[str, Any] = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
        "max_tokens": tg,
    }

    started = time.perf_counter()
    response = client.post_bench_chat_completions(request)
    duration = time.perf_counter() - started

    # A missing or empty "choices" list is treated as one empty choice.
    first_choice = (response.get("choices") or [{}])[0]
    reply = first_choice.get("message", {}).get("content") or ""

    row = {
        "elapsed_s": duration,
        "output_text_preview": reply[:200],
        "stats": response.get("generation_stats"),
    }
    return row, prompt_token_count
def _run_phase(
    *,
    client: ExoClient,
    label: str,
    pp_tg_pairs: list[tuple[int, int]],
    model_id: str,
    prompt_sizer: PromptSizer,
    warmup: int,
    repeat: int,
    common_meta: dict[str, Any],
) -> list[dict[str, Any]]:
    """Run one benchmark phase and return its per-run result rows.

    Performs ``warmup`` unrecorded requests using the first (pp, tg) pair,
    then ``repeat`` measured requests for every pair. Each successful run
    yields a row from ``run_one`` extended with ``common_meta`` and the phase
    label, model id, actual prompt token count, ``tg`` and repeat index.
    A measured run that raises is logged and skipped.

    Args:
        client: API client passed through to ``run_one``.
        label: Phase name stored in each row and used in log lines.
        pp_tg_pairs: (prompt-size hint, generation length) pairs to run.
        model_id: Model id sent with each request.
        prompt_sizer: Builds the prompt for a given prompt-size hint.
        warmup: Number of warmup requests.
        repeat: Measured repetitions per pair.
        common_meta: Extra fields merged into every row.

    Returns:
        All successful rows, in execution order.
    """
    logger.info(f"=== phase: {label} (model={model_id}) ===")
    rows: list[dict[str, Any]] = []
    # Warmup requests are not wrapped in try/except: a failure here propagates.
    for i in range(warmup):
        run_one(client, model_id, pp_tg_pairs[0][0], pp_tg_pairs[0][1], prompt_sizer)
        logger.debug(f" warmup {i + 1}/{warmup} done")
    for pp, tg in pp_tg_pairs:
        logger.info(f"--- {label}: pp={pp} tg={tg} ---")
        runs: list[dict[str, Any]] = []
        for r in range(repeat):
            # Pause before every measured request.
            time.sleep(2)
            try:
                row, actual_pp_tokens = run_one(client, model_id, pp, tg, prompt_sizer)
            except Exception as e:
                # Best-effort: log the failure and move on to the next repetition.
                logger.error(e)
                continue
            row.update(common_meta)
            row.update(
                {
                    "phase": label,
                    "phase_model_id": model_id,
                    "pp_tokens": actual_pp_tokens,
                    "tg": tg,
                    "repeat_index": r,
                }
            )
            runs.append(row)
            rows.append(row)
        if runs:
            # Per-pair averages over the successful runs, for the log line only.
            # NOTE(review): run_one stores out.get("generation_stats"), so a
            # response without stats would make these subscripts raise —
            # confirm the endpoint always returns them.
            prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
            gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
            ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
            gtok = mean(x["stats"]["generation_tokens"] for x in runs)
            peak = mean(x["stats"]["peak_memory_usage"]["inBytes"] for x in runs)
            avg_elapsed = mean(x["elapsed_s"] for x in runs)
            logger.info(
                f"[{label}] prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
                f"prompt_tokens={ptok} gen_tokens={gtok} "
                f"peak_memory={format_peak_memory(peak)} "
                f"avg_elapsed={avg_elapsed:.2f}s"
            )
            time.sleep(2)
    return rows
def _summarise(rows: list[dict[str, Any]]) -> dict[tuple[int, int], dict[str, float]]:
grouped: dict[tuple[int, int], list[dict[str, Any]]] = {}
for r in rows:
key = (int(r["pp_tokens"]), int(r["tg"]))
grouped.setdefault(key, []).append(r)
out: dict[tuple[int, int], dict[str, float]] = {}
for key, runs in grouped.items():
out[key] = {
"prompt_tps": mean(x["stats"]["prompt_tps"] for x in runs),
"gen_tps": mean(x["stats"]["generation_tps"] for x in runs),
"elapsed_s": mean(x["elapsed_s"] for x in runs),
}
return out
def _print_diff(
    disagg_rows: list[dict[str, Any]],
    decode_alone_rows: list[dict[str, Any]],
    prefill_alone_rows: list[dict[str, Any]],
) -> None:
    """Log a per-(pp, tg) comparison table of the three benchmark phases.

    Each phase's rows are averaged with ``_summarise``. For every pair seen in
    any phase the table shows elapsed time, prompt tokens/s and generation
    tokens/s per phase (a ``-`` placeholder where a phase has no data for that
    pair), followed by the speedup of the disaggregated run over each
    baseline when both are available.
    """
    disagg = _summarise(disagg_rows)
    decode_alone = _summarise(decode_alone_rows)
    prefill_alone = _summarise(prefill_alone_rows)
    keys = set(disagg) | set(decode_alone) | set(prefill_alone)

    width = 64
    # The separator must be built from a visible character: multiplying an
    # empty string always yields "" and no rule would be drawn.
    rule = "-" * width
    missing = "-"

    for key in sorted(keys):
        pp, tg = key
        d = disagg.get(key)
        da = decode_alone.get(key)
        pa = prefill_alone.get(key)

        logger.info(rule)
        logger.info(f" pp={pp} tg={tg}")
        logger.info(rule)
        logger.info(
            f" {'phase':<16} {'elapsed':>10} {'prompt_tps':>11} {'gen_tps':>9}"
        )
        for label, summary in (
            ("disaggregated", d),
            ("decode_alone", da),
            ("prefill_alone", pa),
        ):
            if summary is None:
                logger.info(f" {label:<16} {missing:>10} {missing:>11} {missing:>9}")
                continue
            logger.info(
                f" {label:<16} "
                f"{summary['elapsed_s']:>9.2f}s "
                f"{summary['prompt_tps']:>11.1f} "
                f"{summary['gen_tps']:>9.2f}"
            )
        # Speedups are ratios against the disaggregated elapsed time, so they
        # are only reported when that time is positive.
        if d and da and d["elapsed_s"] > 0:
            logger.info(
                f" speedup vs decode_alone: {da['elapsed_s'] / d['elapsed_s']:.2f}x"
            )
        if d and pa and d["elapsed_s"] > 0:
            logger.info(
                f" speedup vs prefill_alone: {pa['elapsed_s'] / d['elapsed_s']:.2f}x"
            )
    logger.info(rule)
def main() -> int:
    """Run the disaggregated prefill/decode benchmark end to end.

    Steps: parse CLI/TOML options, resolve models and tokenizers, choose a
    prefill and a decode placement, run the planning phase for both, create
    and link the two instances, run the benchmark phase(s), tear everything
    down, and write the collected rows as JSON.

    Returns:
        0 on success (including ``--dry-run``), 1 when a node, placement or
        link cannot be resolved, 2 when pp/tg/repeat fail validation.

    Raises:
        RuntimeError: if a tokenizer cannot be loaded.
    """
    # Must run before parse_args() so TOML-only --model/--pp/--tg satisfy required=True.
    _inject_toml_into_argv()
    ap = argparse.ArgumentParser(
        prog="prefill-decode-bench",
        description="Benchmark MLX-MLX disaggregated prefill/decode via instance links.",
    )
    add_common_instance_args(ap)
    ap.add_argument(
        "--pp",
        nargs="+",
        required=True,
        help="Prompt-size hints (ints, must be >1000). Accepts commas.",
    )
    ap.add_argument(
        "--tg",
        nargs="+",
        required=True,
        help="Generation lengths (ints). Accepts commas.",
    )
    ap.add_argument(
        "--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
    )
    ap.add_argument(
        "--warmup",
        type=int,
        default=0,
        help="Warmup runs (uses first pp/tg).",
    )
    ap.add_argument(
        "--json-out",
        default="bench/prefill_decode_results.json",
        help="Write raw per-run results JSON to this path.",
    )
    ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
    ap.add_argument(
        "--dry-run", action="store_true", help="List selected placements and exit."
    )
    ap.add_argument(
        "--all-combinations",
        action="store_true",
        help="Force all pp×tg combinations even when lists have equal length.",
    )
    ap.add_argument(
        "--prefill-model",
        default=None,
        help="Model id for the prefill instance. Defaults to --model.",
    )
    ap.add_argument(
        "--prefill-node",
        default=None,
        help="friendly_name of the node hosting the prefill instance.",
    )
    ap.add_argument(
        "--decode-node",
        default=None,
        help="friendly_name of the node hosting the decode instance.",
    )
    ap.add_argument(
        "--config",
        default=None,
        help="TOML config file. CLI flags override toml values.",
    )
    ap.add_argument(
        "--compare-baseline",
        action="store_true",
        help="Also run each (pp,tg) pair without the prefill/decode link "
        "(decode instance does its own prefill) and report the diff.",
    )
    args = ap.parse_args()

    # --- Merge TOML config: top-level keys first, then [prefill]/[decode] tables.
    cfg = _load_toml(args.config) if args.config else {}
    _merge_toml_into_args(args, cfg)
    prefill_overrides = cfg.get("prefill", {}) if cfg else {}
    decode_overrides = cfg.get("decode", {}) if cfg else {}
    # Per-side values are only taken from TOML when the CLI left them unset.
    if args.prefill_model is None and "model" in prefill_overrides:
        args.prefill_model = prefill_overrides["model"]
    if args.prefill_node is None and "node" in prefill_overrides:
        args.prefill_node = prefill_overrides["node"]
    if args.decode_node is None and "node" in decode_overrides:
        args.decode_node = decode_overrides["node"]
    if "model" in decode_overrides and not args.model:
        args.model = decode_overrides["model"]

    # --- Validate pp / tg / repeat.
    pp_list = parse_int_list(args.pp)
    tg_list = parse_int_list(args.tg)
    if not pp_list or not tg_list:
        logger.error("pp and tg lists must be non-empty")
        return 2
    for pp in pp_list:
        if pp <= 1000:
            logger.error(
                f"pp={pp} must be >1000 (remote prefill triggers when uncached >1000)"
            )
            return 2
    if args.repeat <= 0:
        logger.error("--repeat must be >= 1")
        return 2
    # Equal-length lists are paired element-wise unless --all-combinations is set.
    use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
    if use_combinations:
        logger.info(
            f"pp/tg mode: combinations (product) — {len(pp_list) * len(tg_list)} pairs"
        )
    else:
        logger.info(f"pp/tg mode: tandem (zip) — {len(pp_list)} pairs")

    # --- Resolve model ids; the prefill side falls back to the decode model.
    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    decode_short_id, decode_full_id = resolve_model_short_id(
        client, args.model, force_download=args.force_download
    )
    if args.prefill_model:
        prefill_short_id, prefill_full_id = resolve_model_short_id(
            client, args.prefill_model, force_download=args.force_download
        )
    else:
        prefill_short_id, prefill_full_id = decode_short_id, decode_full_id

    # --- Tokenizers / prompt sizers (shared when both sides use the same model).
    tokenizer = load_tokenizer_for_bench(decode_full_id)
    if tokenizer is None:
        raise RuntimeError("[prefill-decode-bench] decode tokenizer load failed")
    try:
        decode_prompt_sizer = PromptSizer(tokenizer)
    except Exception:
        logger.error("[prefill-decode-bench] decode prompt sizing failed")
        raise
    if prefill_full_id == decode_full_id:
        prefill_prompt_sizer = decode_prompt_sizer
    else:
        prefill_tokenizer = load_tokenizer_for_bench(prefill_full_id)
        if prefill_tokenizer is None:
            raise RuntimeError("[prefill-decode-bench] prefill tokenizer load failed")
        prefill_prompt_sizer = PromptSizer(prefill_tokenizer)

    # --- Choose the prefill and decode placements.
    id_to_friendly = _node_id_to_friendly(client)
    prefill_args = _side_args(args, prefill_overrides)
    decode_args = _side_args(args, decode_overrides)
    if prefill_full_id == decode_full_id and prefill_overrides == decode_overrides:
        # Same model and identical per-side overrides: fetch placements once
        # and pick both sides from the same list.
        placements = settle_and_fetch_placements(
            client, decode_full_id, args, settle_timeout=args.settle_timeout
        )
        prefill_candidates = (
            _filter_by_node(placements, args.prefill_node, id_to_friendly)
            if args.prefill_node
            else placements
        )
        decode_candidates = (
            _filter_by_node(placements, args.decode_node, id_to_friendly)
            if args.decode_node
            else placements
        )
        if args.prefill_node and not prefill_candidates:
            logger.error(f"No placement on prefill node {args.prefill_node!r}.")
            return 1
        if args.decode_node and not decode_candidates:
            logger.error(f"No placement on decode node {args.decode_node!r}.")
            return 1
        if args.prefill_node and args.decode_node:
            prefill_p = prefill_candidates[0]
            decode_p = decode_candidates[0]
        else:
            pair = _pick_two_distinct_placements(placements)
            if pair is None:
                logger.error(
                    "Need at least two distinct-node MLX placements for the same model."
                )
                return 1
            prefill_p, decode_p = pair
            # A single pinned side replaces the corresponding half of the pair.
            if args.prefill_node:
                prefill_p = prefill_candidates[0]
            if args.decode_node:
                decode_p = decode_candidates[0]
    else:
        # Different models or per-side overrides: fetch placements per side,
        # optionally restricted to the requested node.
        prefill_node_id = (
            _node_id_by_friendly(id_to_friendly, args.prefill_node)
            if args.prefill_node
            else None
        )
        decode_node_id = (
            _node_id_by_friendly(id_to_friendly, args.decode_node)
            if args.decode_node
            else None
        )
        if args.prefill_node and prefill_node_id is None:
            logger.error(f"Unknown node {args.prefill_node!r}.")
            return 1
        if args.decode_node and decode_node_id is None:
            logger.error(f"Unknown node {args.decode_node!r}.")
            return 1
        prefill_placements = settle_and_fetch_placements(
            client,
            prefill_full_id,
            prefill_args,
            settle_timeout=args.settle_timeout,
            node_id=prefill_node_id,
        )
        decode_placements = settle_and_fetch_placements(
            client,
            decode_full_id,
            decode_args,
            settle_timeout=args.settle_timeout,
            node_id=decode_node_id,
        )
        if not prefill_placements:
            logger.error(
                f"No placement found for prefill model {prefill_full_id}"
                f"{f' on node {args.prefill_node!r}' if args.prefill_node else ''}."
            )
            return 1
        if not decode_placements:
            logger.error(
                f"No placement found for decode model {decode_full_id}"
                f"{f' on node {args.decode_node!r}' if args.decode_node else ''}."
            )
            return 1
        prefill_p = prefill_placements[0]
        decode_p = decode_placements[0]

    # --- Describe the chosen placements.
    prefill_node_names = _placement_node_friendly_names(prefill_p, id_to_friendly)
    decode_node_names = _placement_node_friendly_names(decode_p, id_to_friendly)
    # No-op reference to the imported name; it is not otherwise used in this function.
    _ = unwrap_instance
    prefill_instance = prefill_p["instance"]
    decode_instance = decode_p["instance"]
    prefill_id = instance_id_from_instance(prefill_instance)
    decode_id = instance_id_from_instance(decode_instance)
    prefill_meta = str(prefill_p.get("instance_meta", ""))
    decode_meta = str(decode_p.get("instance_meta", ""))
    prefill_nodes = nodes_used_in_instance(prefill_instance)
    decode_nodes = nodes_used_in_instance(decode_instance)
    logger.info("=" * 80)
    logger.info(
        f"PREFILL: {prefill_meta} / nodes={prefill_nodes} ({','.join(prefill_node_names)}) "
        f"/ {prefill_short_id} ({prefill_full_id}) / instance_id={prefill_id}"
    )
    logger.info(
        f"DECODE: {decode_meta} / nodes={decode_nodes} ({','.join(decode_node_names)}) "
        f"/ {decode_short_id} ({decode_full_id}) / instance_id={decode_id}"
    )
    if args.dry_run:
        return 0

    # --- Planning phase for both sides, sharing one settle deadline.
    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )
    logger.info("Planning phase: prefill...")
    run_planning_phase(
        client,
        prefill_full_id,
        prefill_p,
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )
    logger.info("Planning phase: decode...")
    run_planning_phase(
        client,
        decode_full_id,
        decode_p,
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )
    if use_combinations:
        pp_tg_pairs = list(itertools.product(pp_list, tg_list))
    else:
        pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
    # Fields copied into every result row.
    common_meta = {
        "decode_model_short_id": decode_short_id,
        "decode_model_id": decode_full_id,
        "prefill_model_short_id": prefill_short_id,
        "prefill_model_id": prefill_full_id,
        "prefill_instance_id": prefill_id,
        "prefill_instance_meta": prefill_meta,
        "prefill_nodes": prefill_nodes,
        "decode_instance_id": decode_id,
        "decode_instance_meta": decode_meta,
        "decode_nodes": decode_nodes,
    }
    all_rows: list[dict[str, Any]] = []
    disagg_rows: list[dict[str, Any]] = []
    decode_alone_rows: list[dict[str, Any]] = []
    prefill_alone_rows: list[dict[str, Any]] = []
    # Track what has been created so the finally block only tears down what exists.
    link_id = ""
    prefill_alive = False
    decode_alive = False
    try:
        logger.info("Creating prefill instance...")
        client.request_json("POST", "/instance", body={"instance": prefill_instance})
        wait_for_instance_ready(client, prefill_id)
        prefill_alive = True
        logger.info("Prefill instance ready")
        if args.compare_baseline:
            # Baseline 1: the prefill instance serving requests on its own,
            # before the decode instance or any link exists.
            time.sleep(2)
            prefill_alone_rows = _run_phase(
                client=client,
                label="prefill_alone",
                pp_tg_pairs=pp_tg_pairs,
                model_id=prefill_full_id,
                prompt_sizer=prefill_prompt_sizer,
                warmup=args.warmup,
                repeat=args.repeat,
                common_meta=common_meta,
            )
            all_rows.extend(prefill_alone_rows)
        logger.info("Creating decode instance...")
        client.request_json("POST", "/instance", body={"instance": decode_instance})
        wait_for_instance_ready(client, decode_id)
        decode_alive = True
        logger.info("Decode instance ready")
        logger.info("Linking instances (prefill → decode)...")
        _create_instance_link(client, prefill_id, decode_id)
        time.sleep(1)
        links = _list_instance_links(client)
        if not links:
            logger.error("Link did not appear in state.")
            return 1
        # NOTE(review): takes the last listed link as the one just created —
        # confirm the endpoint's ordering.
        link_id = str(links[-1].get("linkId") or links[-1].get("link_id") or "")
        logger.info(f"Link created: {link_id}")
        time.sleep(2)
        # Main measurement: requests go to the decode model with the link in place.
        disagg_rows = _run_phase(
            client=client,
            label="disaggregated",
            pp_tg_pairs=pp_tg_pairs,
            model_id=decode_full_id,
            prompt_sizer=decode_prompt_sizer,
            warmup=args.warmup,
            repeat=args.repeat,
            common_meta=common_meta,
        )
        all_rows.extend(disagg_rows)
        if args.compare_baseline:
            # Baseline 2: drop the link and the prefill instance, then re-run
            # against the decode instance alone.
            logger.info("Removing link and prefill instance to isolate decode_alone.")
            with contextlib.suppress(ExoHttpError):
                if link_id:
                    _delete_instance_link(client, link_id)
                    link_id = ""
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{prefill_id}")
            wait_for_instance_gone(client, prefill_id)
            prefill_alive = False
            time.sleep(2)
            decode_alone_rows = _run_phase(
                client=client,
                label="decode_alone",
                pp_tg_pairs=pp_tg_pairs,
                model_id=decode_full_id,
                prompt_sizer=decode_prompt_sizer,
                warmup=args.warmup,
                repeat=args.repeat,
                common_meta=common_meta,
            )
            all_rows.extend(decode_alone_rows)
            _print_diff(disagg_rows, decode_alone_rows, prefill_alone_rows)
    finally:
        # Best-effort teardown: HTTP errors from the delete calls are suppressed.
        with contextlib.suppress(ExoHttpError):
            if link_id:
                _delete_instance_link(client, link_id)
        if decode_alive:
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{decode_id}")
            wait_for_instance_gone(client, decode_id)
        if prefill_alive:
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{prefill_id}")
            wait_for_instance_gone(client, prefill_id)
        logger.debug("Deleted both instances")

    # --- Emit results: --stdout takes precedence over --json-out.
    if args.stdout:
        json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
    elif args.json_out:
        with open(args.json_out, "w", encoding="utf-8") as f:
            json.dump(all_rows, f, indent=2, ensure_ascii=False)
        logger.debug(f"\nWrote results JSON: {args.json_out}")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())