mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 04:05:23 -04:00
Extend bench/eval tooling (#1905)
## Motivation

Extend bench/eval tooling with robustness features and streaming support, and align model configs with vllm eval for reproducible comparisons.

## Changes

- **exo_eval**: Checkpoint/resume (JSONL), instance health monitoring + early abort, `top_k`/`min_p`/`enable_thinking` params, LCB `--release-version`/`--offset`
- **exo_bench**: Streaming SSE (`--stream`), Kimi tokenizer fix for transformers 5.x
- **Both tools**: Auto-detect running instances instead of requiring `--skip-instance-setup`; `--fresh-instance` to override
- **harness**: SSE streaming client, `find_existing_instance()` shared helper, removed download timeout, settle-timeout default 0→7200s
- **models.toml**: Added `enable_thinking`, aligned `max_tokens`/temps with vllm, added new models
- **API**: Streaming SSE for `/bench/chat/completions`

## Why It Works

- Checkpoint/resume uses append-only JSONL + skip-on-load, so interrupted evals resume without re-running completed questions
- Health monitoring races an `asyncio.Event` against API calls for fast abort when the instance dies
- Auto-detection queries `/state` for existing instances matching the model ID before attempting placement
- Streaming reuses the existing `generate_chat_stream` infrastructure from the regular chat endpoint
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
# name, patterns, reasoning
|
||||
#
|
||||
# Optional per-model overrides (CLI flags take priority over these):
|
||||
# temperature, top_p, max_tokens, reasoning_effort
|
||||
# temperature, top_p, max_tokens, reasoning_effort, enable_thinking
|
||||
#
|
||||
# Fallback defaults (when no per-model config):
|
||||
# reasoning: temperature=1.0, max_tokens=131072, reasoning_effort="high"
|
||||
@@ -18,10 +18,9 @@
|
||||
|
||||
# ─── Qwen3.5 (Feb 2026) ─────────────────────────────────────────────
|
||||
# Source: HuggingFace model cards (Qwen/Qwen3.5-*)
|
||||
# 35B-A3B thinking general: temp=1.0, top_p=0.95, top_k=20
|
||||
# 397B thinking: temp=0.6, top_p=0.95, top_k=20
|
||||
# Non-thinking: temp=0.7, top_p=0.8, top_k=20
|
||||
# max_tokens: 32768 general, 81920 for complex math/code
|
||||
# Model card recommends: temp=0.6, top_p=0.95, top_k=20
|
||||
# We omit top_k to match vllm eval (which doesn't set it).
|
||||
# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin).
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 2B"
|
||||
@@ -29,7 +28,8 @@ patterns = ["Qwen3.5-2B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 9B"
|
||||
@@ -37,7 +37,8 @@ patterns = ["Qwen3.5-9B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 27B"
|
||||
@@ -45,15 +46,17 @@ patterns = ["Qwen3.5-27B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 35B A3B"
|
||||
patterns = ["Qwen3.5-35B-A3B"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 122B A10B"
|
||||
@@ -61,7 +64,8 @@ patterns = ["Qwen3.5-122B-A10B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Qwen3.5 397B A17B"
|
||||
@@ -69,12 +73,14 @@ patterns = ["Qwen3.5-397B-A17B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
max_tokens = 81920
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
# ─── Qwen3 (Apr 2025) ───────────────────────────────────────────────
|
||||
# Source: HuggingFace model cards (Qwen/Qwen3-*)
|
||||
# Thinking: temp=0.6, top_p=0.95, top_k=20
|
||||
# Non-thinking: temp=0.7, top_p=0.8, top_k=20
|
||||
# Model card recommends: temp=0.6, top_p=0.95, top_k=20
|
||||
# We omit top_k to match vllm eval (which doesn't set it).
|
||||
# Non-thinking: temp=0.7, top_p=0.8
|
||||
# max_tokens: 32768 general, 38912 for complex math/code
|
||||
|
||||
[[model]]
|
||||
@@ -83,6 +89,7 @@ patterns = ["Qwen3-0.6B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 38912
|
||||
|
||||
[[model]]
|
||||
@@ -91,6 +98,7 @@ patterns = ["Qwen3-30B-A3B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 38912
|
||||
|
||||
[[model]]
|
||||
@@ -99,6 +107,7 @@ patterns = ["Qwen3-235B-A22B"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 38912
|
||||
|
||||
[[model]]
|
||||
@@ -107,6 +116,7 @@ patterns = ["Qwen3-Next-80B-A3B-Thinking"]
|
||||
reasoning = true
|
||||
temperature = 0.6
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 38912
|
||||
|
||||
[[model]]
|
||||
@@ -129,9 +139,9 @@ max_tokens = 16384
|
||||
name = "Qwen3 Coder Next"
|
||||
patterns = ["Qwen3-Coder-Next"]
|
||||
reasoning = false
|
||||
temperature = 0.7
|
||||
top_p = 0.8
|
||||
max_tokens = 16384
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
max_tokens = 121072
|
||||
|
||||
# ─── GPT-OSS (OpenAI) ───────────────────────────────────────────────
|
||||
# Source: OpenAI GitHub README + HuggingFace discussion #21
|
||||
@@ -165,10 +175,38 @@ patterns = ["DeepSeek-V3.1"]
|
||||
reasoning = true
|
||||
temperature = 0.0
|
||||
|
||||
[[model]]
|
||||
name = "DeepSeek V3.2"
|
||||
patterns = ["DeepSeek-V3.2"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
|
||||
# ─── NVIDIA Nemotron ───────────────────────────────────────────────────
|
||||
# Source: HuggingFace model cards
|
||||
# All variants: temp=1.0, top_p=0.95, enable_thinking=true
|
||||
|
||||
[[model]]
|
||||
name = "Nemotron Cascade 2 30B A3B"
|
||||
patterns = ["Nemotron-Cascade-2-30B-A3B"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
|
||||
[[model]]
|
||||
name = "Nemotron 3 Super 120B A12B"
|
||||
patterns = ["Nemotron-3-Super-120B-A12B", "NVIDIA-Nemotron-3-Super-120B-A12B"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
|
||||
# ─── GLM (ZhipuAI / THUDM) ──────────────────────────────────────────
|
||||
# Source: HuggingFace model cards + generation_config.json + docs.z.ai
|
||||
# GLM 4.5+: temp=1.0, top_p=0.95
|
||||
# Reasoning tasks: 131072 max_tokens; coding/SWE tasks: temp=0.7
|
||||
# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin)
|
||||
|
||||
[[model]]
|
||||
name = "GLM-5"
|
||||
@@ -176,7 +214,8 @@ patterns = ["GLM-5"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
max_tokens = 131072
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "GLM 4.5 Air"
|
||||
@@ -191,7 +230,8 @@ patterns = ["GLM-4.7-"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
max_tokens = 131072
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
# Note: matches both GLM-4.7 and GLM-4.7-Flash
|
||||
|
||||
# ─── Kimi (Moonshot AI) ─────────────────────────────────────────────
|
||||
@@ -213,7 +253,8 @@ patterns = ["Kimi-K2.5"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
max_tokens = 131072
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
[[model]]
|
||||
name = "Kimi K2 Instruct"
|
||||
@@ -223,7 +264,17 @@ temperature = 0.6
|
||||
|
||||
# ─── MiniMax ─────────────────────────────────────────────────────────
|
||||
# Source: HuggingFace model cards + generation_config.json
|
||||
# All models: temp=1.0, top_p=0.95, top_k=40
|
||||
# All models: temp=1.0, top_p=0.95
|
||||
# max_tokens=90000 to match vllm eval (100000 context - 10000 safety margin)
|
||||
|
||||
[[model]]
|
||||
name = "MiniMax M2.7"
|
||||
patterns = ["MiniMax-M2.7"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 90000
|
||||
|
||||
[[model]]
|
||||
name = "MiniMax M2.5"
|
||||
@@ -231,6 +282,8 @@ patterns = ["MiniMax-M2.5"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 90000
|
||||
|
||||
[[model]]
|
||||
name = "MiniMax M2.1"
|
||||
@@ -251,6 +304,8 @@ patterns = ["Step-3.5-Flash"]
|
||||
reasoning = true
|
||||
temperature = 1.0
|
||||
top_p = 0.95
|
||||
enable_thinking = true
|
||||
max_tokens = 121072
|
||||
|
||||
# ─── Llama (Meta) ───────────────────────────────────────────────────
|
||||
# Source: generation_config.json + meta-llama/llama-models generation.py
|
||||
|
||||
@@ -35,6 +35,7 @@ from harness import (
|
||||
ExoHttpError,
|
||||
add_common_instance_args,
|
||||
capture_cluster_snapshot,
|
||||
find_existing_instance,
|
||||
instance_id_from_instance,
|
||||
node_ids_from_instance,
|
||||
nodes_used_in_instance,
|
||||
@@ -79,7 +80,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model", "*.jinja"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -277,28 +278,72 @@ def run_one_completion(
|
||||
prompt_sizer: PromptSizer,
|
||||
*,
|
||||
use_prefix_cache: bool = False,
|
||||
stream: bool = False,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
content, pp_tokens = prompt_sizer.build(pp_hint)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": tg,
|
||||
"logprobs": False,
|
||||
"use_prefix_cache": use_prefix_cache,
|
||||
}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
if not stream:
|
||||
payload["stream"] = False
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
stats = out.get("generation_stats")
|
||||
choices = out.get("choices") or [{}]
|
||||
message = choices[0].get("message", {}) if choices else {}
|
||||
content = message.get("content") or ""
|
||||
preview = content[:200] if content else ""
|
||||
else:
|
||||
tokens = 0
|
||||
first_token_time = None
|
||||
t0 = time.perf_counter()
|
||||
text_parts: list[str] = []
|
||||
stats = None
|
||||
|
||||
# Extract preview, handling None content (common for thinking models)
|
||||
choices = out.get("choices") or [{}]
|
||||
message = choices[0].get("message", {}) if choices else {}
|
||||
content = message.get("content") or ""
|
||||
preview = content[:200] if content else ""
|
||||
for raw_line in client.stream_bench_chat_completions(payload):
|
||||
line = raw_line.strip()
|
||||
if line.startswith(": generation_stats "):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
stats = json.loads(line[len(": generation_stats ") :])
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
if delta.get("content"):
|
||||
if first_token_time is None:
|
||||
first_token_time = time.perf_counter()
|
||||
tokens += 1
|
||||
text_parts.append(delta["content"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
preview = "".join(text_parts)[:200]
|
||||
|
||||
if not stats:
|
||||
ttft = (first_token_time - t0) if first_token_time else elapsed
|
||||
gen_time = elapsed - ttft if tokens > 1 else elapsed
|
||||
gen_tps = (tokens - 1) / gen_time if tokens > 1 and gen_time > 0 else 0.0
|
||||
prompt_tps = pp_tokens / ttft if ttft > 0 else 0.0
|
||||
stats = {
|
||||
"prompt_tokens": pp_tokens,
|
||||
"generation_tokens": tokens,
|
||||
"prompt_tps": round(prompt_tps, 2),
|
||||
"generation_tps": round(gen_tps, 2),
|
||||
"peak_memory_usage": {"inBytes": 0},
|
||||
}
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
@@ -425,6 +470,11 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Use /bench/chat/completions with streaming SSE response (bench=True still applies: no EOS detection, no KV cache).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-system-metrics",
|
||||
action="store_true",
|
||||
@@ -490,81 +540,124 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
)
|
||||
# Optionally reuse a running instance for this model
|
||||
reused_instance_id: str | None = None
|
||||
if args.reuse_instance:
|
||||
existing = find_existing_instance(client, full_model_id)
|
||||
if existing:
|
||||
reused_instance_id = existing
|
||||
logger.info(f"Reusing existing instance {reused_instance_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
"--reuse-instance: no existing instance found, creating a new one"
|
||||
)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
if reused_instance_id is not None:
|
||||
# Use the existing instance directly — skip placement iteration
|
||||
selected = []
|
||||
download_duration_s = None
|
||||
else:
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
download_duration_s = run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
if download_duration_s is not None:
|
||||
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
|
||||
else:
|
||||
logger.info("Download: model already cached")
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
download_duration_s = run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
if download_duration_s is not None:
|
||||
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
|
||||
else:
|
||||
logger.info("Download: model already cached")
|
||||
|
||||
cluster_snapshot = capture_cluster_snapshot(client)
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
all_system_metrics: dict[str, dict[str, dict[str, float]]] = {}
|
||||
|
||||
# If reusing an existing instance, run a single benchmark pass against it
|
||||
if reused_instance_id is not None:
|
||||
selected = [None]
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
created_instance = False
|
||||
if preview is not None:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
# Delete any existing instances to free resources before placing
|
||||
try:
|
||||
state = client.request_json("GET", "/state")
|
||||
for old_id in list(state.get("instances", {}).keys()):
|
||||
logger.info(f"Deleting stale instance {old_id}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{old_id}")
|
||||
if state.get("instances"):
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up stale instances: {e}")
|
||||
|
||||
time.sleep(1)
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
|
||||
time.sleep(1)
|
||||
created_instance = True
|
||||
else:
|
||||
instance_id = reused_instance_id
|
||||
sharding = "reused"
|
||||
instance_meta = "reused"
|
||||
n_nodes = 0
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"Using existing instance {instance_id}")
|
||||
|
||||
sampler: SystemMetricsSampler | None = None
|
||||
if not args.no_system_metrics:
|
||||
if not args.no_system_metrics and preview is not None:
|
||||
nids = node_ids_from_instance(instance)
|
||||
sampler = SystemMetricsSampler(
|
||||
ExoClient(args.host, args.port, timeout_s=30),
|
||||
@@ -573,16 +666,20 @@ def main() -> int:
|
||||
)
|
||||
sampler.start()
|
||||
|
||||
def _do_one(c: ExoClient, pp: int, tg: int) -> tuple[dict[str, Any], int]:
    """Run one completion on client ``c`` using the shared benchmark settings.

    Only the client, the prompt-size hint ``pp`` and the generation length
    ``tg`` vary per call; the model id, prompt sizer, prefix-cache flag and
    streaming flag are taken from the enclosing scope.

    Returns:
        Whatever ``run_one_completion`` returns for these arguments.
    """
    outcome = run_one_completion(
        c,
        full_model_id,
        pp,
        tg,
        prompt_sizer,
        stream=args.stream,
        use_prefix_cache=args.use_prefix_cache,
    )
    return outcome
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client,
|
||||
full_model_id,
|
||||
pp_list[0],
|
||||
tg_list[0],
|
||||
prompt_sizer,
|
||||
use_prefix_cache=args.use_prefix_cache,
|
||||
)
|
||||
_do_one(client, pp_list[0], tg_list[0])
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
# If pp and tg lists have same length, run in tandem (zip)
|
||||
@@ -604,14 +701,7 @@ def main() -> int:
|
||||
# Sequential: single request
|
||||
try:
|
||||
inf_t0 = time.monotonic()
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client,
|
||||
full_model_id,
|
||||
pp,
|
||||
tg,
|
||||
prompt_sizer,
|
||||
use_prefix_cache=args.use_prefix_cache,
|
||||
)
|
||||
row, actual_pp_tokens = _do_one(client, pp, tg)
|
||||
inference_windows.append((inf_t0, time.monotonic()))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@@ -760,10 +850,12 @@ def main() -> int:
|
||||
gen_tps = per_req_tps * concurrency
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
def _peak_bytes(s: dict[str, Any]) -> float:
|
||||
pm = s["peak_memory_usage"]
|
||||
return pm.get("inBytes") or pm.get("in_bytes", 0)
|
||||
|
||||
peak = mean(_peak_bytes(x["stats"]) for x in runs)
|
||||
summary = (
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
@@ -788,15 +880,16 @@ def main() -> int:
|
||||
if placement_metrics:
|
||||
all_system_metrics.update(placement_metrics)
|
||||
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
if created_instance and instance_id is not None:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
time.sleep(5)
|
||||
|
||||
output: dict[str, Any] = {"runs": all_rows}
|
||||
if cluster_snapshot:
|
||||
|
||||
@@ -47,6 +47,7 @@ from harness import (
|
||||
ExoHttpError,
|
||||
add_common_instance_args,
|
||||
capture_cluster_snapshot,
|
||||
find_existing_instance,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
@@ -62,6 +63,15 @@ from loguru import logger
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MAX_RETRIES = 30
|
||||
INSTANCE_HEALTH_CHECK_AFTER = (
|
||||
3 # Check instance health after this many consecutive failures
|
||||
)
|
||||
|
||||
|
||||
class InstanceFailedError(RuntimeError):
    """Signal that the exo instance has been detected as failed or gone."""
|
||||
|
||||
|
||||
DEFAULT_MAX_TOKENS = 16_384
|
||||
REASONING_MAX_TOKENS = 131_072
|
||||
TEMPERATURE_NON_REASONING = 0.0
|
||||
@@ -271,7 +281,7 @@ def run_humaneval_test(
|
||||
|
||||
@dataclass
|
||||
class QuestionResult:
|
||||
question_id: int
|
||||
question_id: int | str
|
||||
prompt: str
|
||||
response: str
|
||||
extracted_answer: str | None
|
||||
@@ -281,7 +291,11 @@ class QuestionResult:
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
reasoning_content: str = ""
|
||||
finish_reason: str = ""
|
||||
elapsed_s: float = 0.0
|
||||
power_watts: float = 0.0
|
||||
energy_joules: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -517,6 +531,10 @@ class ApiResult:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
reasoning_tokens: int
|
||||
reasoning_content: str = ""
|
||||
finish_reason: str = ""
|
||||
power_watts: float = 0.0
|
||||
energy_joules: float = 0.0
|
||||
|
||||
|
||||
async def _call_api(
|
||||
@@ -530,6 +548,9 @@ async def _call_api(
|
||||
system_message: str | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
min_p: float | None = None,
|
||||
enable_thinking: bool | None = None,
|
||||
) -> ApiResult:
|
||||
messages = []
|
||||
if system_message:
|
||||
@@ -546,6 +567,12 @@ async def _call_api(
|
||||
body["reasoning_effort"] = reasoning_effort
|
||||
if top_p is not None:
|
||||
body["top_p"] = top_p
|
||||
if top_k is not None:
|
||||
body["top_k"] = top_k
|
||||
if min_p is not None:
|
||||
body["min_p"] = min_p
|
||||
if enable_thinking is not None:
|
||||
body["enable_thinking"] = enable_thinking
|
||||
|
||||
resp = await client.post(
|
||||
f"{base_url}/v1/chat/completions",
|
||||
@@ -554,19 +581,40 @@ async def _call_api(
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
if not content or not content.strip():
|
||||
choice = data["choices"][0]
|
||||
message = choice["message"]
|
||||
content = message.get("content") or ""
|
||||
reasoning_content = message.get("reasoning_content") or ""
|
||||
finish_reason = choice.get("finish_reason") or ""
|
||||
|
||||
# For thinking models, empty content is expected when finish_reason is "length"
|
||||
if not content.strip() and finish_reason != "length" and not reasoning_content:
|
||||
raise ValueError("Empty response from model")
|
||||
usage = data.get("usage", {})
|
||||
details = usage.get("completion_tokens_details", {})
|
||||
power = data.get("power_usage") or {}
|
||||
return ApiResult(
|
||||
content=content,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
reasoning_tokens=details.get("reasoning_tokens", 0) if details else 0,
|
||||
reasoning_content=reasoning_content,
|
||||
finish_reason=finish_reason,
|
||||
power_watts=power.get("total_avg_sys_power_watts", 0.0),
|
||||
energy_joules=power.get("total_energy_joules", 0.0),
|
||||
)
|
||||
|
||||
|
||||
async def _check_instance_health(base_url: str) -> bool:
    """Probe ``{base_url}/models`` and report whether it answered HTTP 200.

    A short-lived client is opened for the probe and the request uses a
    5-second timeout. Any exception raised while probing is treated as
    "not healthy" rather than propagated.

    Args:
        base_url: Base URL to which ``/models`` is appended.

    Returns:
        ``True`` if the response status code is 200, otherwise ``False``.
    """
    probe_url = f"{base_url}/models"
    try:
        async with httpx.AsyncClient() as probe:
            reply = await probe.get(probe_url, timeout=5.0)
    except Exception:
        # Best-effort check: an unreachable server counts as unhealthy.
        return False
    return reply.status_code == 200
|
||||
|
||||
|
||||
async def call_with_retries(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
@@ -578,8 +626,14 @@ async def call_with_retries(
|
||||
system_message: str | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
min_p: float | None = None,
|
||||
enable_thinking: bool | None = None,
|
||||
instance_failed: asyncio.Event | None = None,
|
||||
) -> ApiResult | None:
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if instance_failed and instance_failed.is_set():
|
||||
raise InstanceFailedError("Instance already marked as failed")
|
||||
try:
|
||||
return await _call_api(
|
||||
client,
|
||||
@@ -592,8 +646,30 @@ async def call_with_retries(
|
||||
system_message,
|
||||
reasoning_effort,
|
||||
top_p,
|
||||
top_k,
|
||||
min_p,
|
||||
enable_thinking,
|
||||
)
|
||||
except Exception as e:
|
||||
is_conn_error = isinstance(
|
||||
e,
|
||||
(
|
||||
httpx.ConnectError,
|
||||
httpx.RemoteProtocolError,
|
||||
ConnectionRefusedError,
|
||||
OSError,
|
||||
),
|
||||
)
|
||||
if (
|
||||
is_conn_error
|
||||
and attempt >= INSTANCE_HEALTH_CHECK_AFTER
|
||||
and not await _check_instance_health(base_url)
|
||||
):
|
||||
if instance_failed:
|
||||
instance_failed.set()
|
||||
raise InstanceFailedError(
|
||||
f"Instance is down after {attempt + 1} failures: {e}"
|
||||
) from e
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
wait = min(2**attempt, 60)
|
||||
logger.warning(
|
||||
@@ -618,10 +694,16 @@ async def evaluate_benchmark(
|
||||
max_tokens: int,
|
||||
concurrency: int = 1,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
timeout: float | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
min_p: float | None = None,
|
||||
enable_thinking: bool | None = None,
|
||||
difficulty: str | None = None,
|
||||
checkpoint_path: Path | None = None,
|
||||
release_version: str | None = None,
|
||||
) -> list[QuestionResult]:
|
||||
"""Run a benchmark. Returns per-question results."""
|
||||
import datasets
|
||||
@@ -652,7 +734,21 @@ async def evaluate_benchmark(
|
||||
ds = ds.filter(lambda x: x["difficulty"] == difficulty)
|
||||
logger.info(f"Filtered to {len(ds)} {difficulty} problems")
|
||||
|
||||
if release_version and "release_version" in ds.column_names:
|
||||
ds = ds.filter(lambda x: x["release_version"] == release_version)
|
||||
logger.info(
|
||||
f"Filtered to {len(ds)} problems with release_version={release_version}"
|
||||
)
|
||||
|
||||
# Sort by question_id to match LCB runner ordering (scenario_router.py:60).
|
||||
# This ensures [offset:offset+limit] slices select the same problems as vllm.
|
||||
if "question_id" in ds.column_names:
|
||||
ds = ds.sort("question_id")
|
||||
|
||||
total = len(ds)
|
||||
if offset > 0:
|
||||
ds = ds.select(range(min(offset, total), total))
|
||||
total = len(ds)
|
||||
if limit and limit < total:
|
||||
ds = ds.select(range(limit))
|
||||
total = limit
|
||||
@@ -660,6 +756,13 @@ async def evaluate_benchmark(
|
||||
logger.info(
|
||||
f"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, "
|
||||
f"temperature={temperature}, max_tokens={max_tokens}"
|
||||
+ (f", top_k={top_k}" if top_k is not None else "")
|
||||
+ (f", min_p={min_p}" if min_p is not None else "")
|
||||
+ (
|
||||
f", enable_thinking={enable_thinking}"
|
||||
if enable_thinking is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
|
||||
if config.kind == "code":
|
||||
@@ -667,16 +770,64 @@ async def evaluate_benchmark(
|
||||
"Code benchmarks execute model-generated code. Use a sandboxed environment."
|
||||
)
|
||||
|
||||
# Load checkpoint for resume
|
||||
checkpoint_data: dict[str | int, dict[str, Any]] = {}
|
||||
if checkpoint_path and checkpoint_path.exists():
|
||||
with open(checkpoint_path) as f:
|
||||
for line in f:
|
||||
entry = json.loads(line)
|
||||
checkpoint_data[entry["question_id"]] = entry
|
||||
logger.info(f"Loaded {len(checkpoint_data)} checkpointed results")
|
||||
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
instance_failed = asyncio.Event()
|
||||
results: list[QuestionResult | None] = [None] * total
|
||||
completed = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
def _get_question_id(idx: int, doc: dict) -> str | int:
    """Return the identifier used to key a question in the checkpoint.

    For ``livecodebench`` the document's ``question_id`` field is used and
    for ``humaneval`` its ``task_id`` field; if that field is absent, or for
    any other benchmark, the positional index ``idx`` is returned.
    """
    id_field_by_benchmark = {
        "livecodebench": "question_id",
        "humaneval": "task_id",
    }
    id_field = id_field_by_benchmark.get(benchmark_name)
    if id_field is None:
        return idx
    return doc.get(id_field, idx)
|
||||
|
||||
async def process_question(
|
||||
idx: int, doc: dict, http_client: httpx.AsyncClient
|
||||
) -> None:
|
||||
nonlocal completed
|
||||
system_msg = None
|
||||
question_id = _get_question_id(idx, doc)
|
||||
|
||||
# Bail out early if instance is already dead
|
||||
if instance_failed.is_set():
|
||||
return
|
||||
|
||||
# Check checkpoint
|
||||
if question_id in checkpoint_data:
|
||||
cached = checkpoint_data[question_id]
|
||||
results[idx] = QuestionResult(
|
||||
question_id=question_id,
|
||||
prompt=cached.get("prompt", ""),
|
||||
response=cached.get("response", ""),
|
||||
extracted_answer=cached.get("extracted_answer"),
|
||||
gold_answer=cached.get("gold_answer", ""),
|
||||
correct=cached.get("correct", False),
|
||||
error=cached.get("error"),
|
||||
prompt_tokens=cached.get("prompt_tokens", 0),
|
||||
completion_tokens=cached.get("completion_tokens", 0),
|
||||
reasoning_tokens=cached.get("reasoning_tokens", 0),
|
||||
reasoning_content=cached.get("reasoning_content", ""),
|
||||
finish_reason=cached.get("finish_reason", ""),
|
||||
elapsed_s=cached.get("elapsed_s", 0.0),
|
||||
power_watts=cached.get("power_watts", 0.0),
|
||||
energy_joules=cached.get("energy_joules", 0.0),
|
||||
)
|
||||
async with lock:
|
||||
completed += 1
|
||||
logger.info(f" [{completed}/{total}] {question_id} (cached)")
|
||||
return
|
||||
|
||||
if benchmark_name == "gpqa_diamond":
|
||||
prompt, gold = format_gpqa_question(doc, idx)
|
||||
@@ -697,24 +848,50 @@ async def evaluate_benchmark(
|
||||
raise ValueError(f"Unknown benchmark: {benchmark_name}")
|
||||
|
||||
async with semaphore:
|
||||
if instance_failed.is_set():
|
||||
return
|
||||
t0 = time.monotonic()
|
||||
api_result = await call_with_retries(
|
||||
http_client,
|
||||
base_url,
|
||||
model,
|
||||
prompt,
|
||||
temperature,
|
||||
max_tokens,
|
||||
timeout,
|
||||
system_message=system_msg,
|
||||
reasoning_effort=reasoning_effort,
|
||||
top_p=top_p,
|
||||
)
|
||||
try:
|
||||
# Race the API call against the instance_failed event
|
||||
api_task = asyncio.create_task(
|
||||
call_with_retries(
|
||||
http_client,
|
||||
base_url,
|
||||
model,
|
||||
prompt,
|
||||
temperature,
|
||||
max_tokens,
|
||||
timeout,
|
||||
system_message=system_msg,
|
||||
reasoning_effort=reasoning_effort,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
enable_thinking=enable_thinking,
|
||||
instance_failed=instance_failed,
|
||||
)
|
||||
)
|
||||
failed_waiter = asyncio.create_task(instance_failed.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
[api_task, failed_waiter],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await p
|
||||
if instance_failed.is_set() and api_task not in done:
|
||||
logger.error(f"Instance failed, aborting {question_id}")
|
||||
return
|
||||
api_result = api_task.result()
|
||||
except InstanceFailedError:
|
||||
logger.error(f"Instance failed, skipping {question_id}")
|
||||
return
|
||||
elapsed = time.monotonic() - t0
|
||||
|
||||
if api_result is None:
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response="",
|
||||
extracted_answer=None,
|
||||
@@ -729,13 +906,17 @@ async def evaluate_benchmark(
|
||||
"prompt_tokens": api_result.prompt_tokens,
|
||||
"completion_tokens": api_result.completion_tokens,
|
||||
"reasoning_tokens": api_result.reasoning_tokens,
|
||||
"reasoning_content": api_result.reasoning_content,
|
||||
"finish_reason": api_result.finish_reason,
|
||||
"elapsed_s": elapsed,
|
||||
"power_watts": api_result.power_watts,
|
||||
"energy_joules": api_result.energy_joules,
|
||||
}
|
||||
|
||||
if config.kind == "mc":
|
||||
extracted = extract_mc_answer(response, valid_letters)
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer=extracted,
|
||||
@@ -749,7 +930,7 @@ async def evaluate_benchmark(
|
||||
check_aime_answer(extracted, int(gold)) if extracted else False
|
||||
)
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer=extracted,
|
||||
@@ -763,7 +944,7 @@ async def evaluate_benchmark(
|
||||
code = extract_code_block(response, preserve_indent=keep_indent)
|
||||
if code is None:
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer=None,
|
||||
@@ -778,7 +959,7 @@ async def evaluate_benchmark(
|
||||
code,
|
||||
)
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer="pass" if passed else "fail",
|
||||
@@ -793,7 +974,7 @@ async def evaluate_benchmark(
|
||||
exec_meta["sample"],
|
||||
)
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer="pass" if passed else "fail",
|
||||
@@ -804,7 +985,7 @@ async def evaluate_benchmark(
|
||||
)
|
||||
else:
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer=None,
|
||||
@@ -815,7 +996,7 @@ async def evaluate_benchmark(
|
||||
)
|
||||
else:
|
||||
result = QuestionResult(
|
||||
question_id=idx,
|
||||
question_id=question_id,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
extracted_answer=None,
|
||||
@@ -827,24 +1008,82 @@ async def evaluate_benchmark(
|
||||
|
||||
results[idx] = result
|
||||
|
||||
# Write checkpoint (skip infra failures so they get retried on resume,
|
||||
# but keep wrong answers — they are legitimate results)
|
||||
if checkpoint_path is not None and result.response:
|
||||
_write_checkpoint(checkpoint_path, result)
|
||||
|
||||
async with lock:
|
||||
completed += 1
|
||||
n = completed
|
||||
if n % max(1, total // 20) == 0 or n == total:
|
||||
correct_so_far = sum(1 for r in results if r is not None and r.correct)
|
||||
answered = sum(1 for r in results if r is not None)
|
||||
logger.info(
|
||||
f" [{n}/{total}] {correct_so_far}/{answered} correct "
|
||||
f"({correct_so_far / max(answered, 1):.1%})"
|
||||
)
|
||||
|
||||
# Log progress
|
||||
thinking_info = ""
|
||||
if result.reasoning_content:
|
||||
thinking_info = f", {len(result.reasoning_content)} chars thinking"
|
||||
logger.info(
|
||||
f" [{n}/{total}] {question_id}: {len(result.response)} chars{thinking_info}, "
|
||||
f"tokens: {result.prompt_tokens}+{result.completion_tokens} "
|
||||
f"[{result.finish_reason}]"
|
||||
+ (f" {result.extracted_answer}" if result.extracted_answer else "")
|
||||
)
|
||||
|
||||
async def _health_monitor() -> None:
    """Poll the instance and set ``instance_failed`` once it stops responding.

    Waits 10 seconds before the first probe, then probes
    ``_check_instance_health(base_url)`` every 5 seconds. A falsy probe is
    re-tried once after 2 seconds; only two consecutive falsy probes mark the
    instance as failed. The loop also ends as soon as ``instance_failed`` has
    been set elsewhere. ``instance_failed`` and ``base_url`` come from the
    enclosing scope.
    """
    # Grace period so the first requests can get going before we probe.
    await asyncio.sleep(10)
    while True:
        if instance_failed.is_set():
            return
        healthy = await _check_instance_health(base_url)
        if not healthy:
            # One failed probe is not conclusive: confirm after a short pause.
            await asyncio.sleep(2)
            healthy = await _check_instance_health(base_url)
            if not healthy:
                logger.error("Health monitor: instance is down!")
                instance_failed.set()
                return
        await asyncio.sleep(5)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
monitor = asyncio.create_task(_health_monitor())
|
||||
tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)]
|
||||
await asyncio.gather(*tasks)
|
||||
monitor.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await monitor
|
||||
|
||||
if instance_failed.is_set():
|
||||
completed_count = sum(1 for r in results if r is not None)
|
||||
logger.error(
|
||||
f"Instance failed! Completed {completed_count}/{total} problems. "
|
||||
f"Checkpoint saved — restart to resume remaining problems."
|
||||
)
|
||||
raise InstanceFailedError("Instance failed during evaluation")
|
||||
|
||||
return [r for r in results if r is not None]
|
||||
|
||||
|
||||
def _write_checkpoint(path: Path, result: QuestionResult) -> None:
|
||||
"""Append a single result to the JSONL checkpoint file."""
|
||||
entry = {
|
||||
"question_id": result.question_id,
|
||||
"prompt": result.prompt,
|
||||
"response": result.response,
|
||||
"extracted_answer": result.extracted_answer,
|
||||
"gold_answer": result.gold_answer,
|
||||
"correct": result.correct,
|
||||
"error": result.error,
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"reasoning_tokens": result.reasoning_tokens,
|
||||
"reasoning_content": result.reasoning_content,
|
||||
"finish_reason": result.finish_reason,
|
||||
"elapsed_s": round(result.elapsed_s, 2),
|
||||
"power_watts": round(result.power_watts, 2),
|
||||
"energy_joules": round(result.energy_joules, 2),
|
||||
}
|
||||
with open(path, "a") as f:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Results display
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -867,6 +1106,8 @@ def print_results(
|
||||
total_elapsed = sum(r.elapsed_s for r in results)
|
||||
wall_clock = max(r.elapsed_s for r in results) if results else 0.0
|
||||
avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0
|
||||
total_energy = sum(r.energy_joules for r in results)
|
||||
avg_power = sum(r.power_watts for r in results) / max(total, 1)
|
||||
|
||||
label = f"[c={concurrency}] " if concurrency is not None else ""
|
||||
print(f"\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})")
|
||||
@@ -878,6 +1119,10 @@ def print_results(
|
||||
f" | total time: {total_elapsed:.1f}s wall clock: {wall_clock:.1f}s"
|
||||
)
|
||||
print(tok_line)
|
||||
if total_energy > 0:
|
||||
print(
|
||||
f" power: avg {avg_power:.1f}W | total energy: {total_energy:.1f}J ({total_energy / 3600:.2f}Wh)"
|
||||
)
|
||||
if errors:
|
||||
print(f" API errors: {errors}")
|
||||
if no_extract:
|
||||
@@ -896,6 +1141,8 @@ def print_results(
|
||||
"total_elapsed_s": total_elapsed,
|
||||
"wall_clock_s": wall_clock,
|
||||
"avg_gen_tps": avg_gen_tps,
|
||||
"avg_power_watts": avg_power,
|
||||
"total_energy_joules": total_energy,
|
||||
}
|
||||
|
||||
|
||||
@@ -1053,7 +1300,11 @@ def save_results(
|
||||
"prompt_tokens": r.prompt_tokens,
|
||||
"completion_tokens": r.completion_tokens,
|
||||
"reasoning_tokens": r.reasoning_tokens,
|
||||
"reasoning_content": r.reasoning_content,
|
||||
"finish_reason": r.finish_reason,
|
||||
"elapsed_s": round(r.elapsed_s, 2),
|
||||
"power_watts": round(r.power_watts, 2),
|
||||
"energy_joules": round(r.energy_joules, 2),
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
@@ -1069,6 +1320,15 @@ def save_results(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _checkpoint_path(
|
||||
results_dir: str, benchmark: str, model: str, concurrency: int
|
||||
) -> Path:
|
||||
"""Return the JSONL checkpoint path for a benchmark run."""
|
||||
out_dir = Path(results_dir) / model.replace("/", "_") / benchmark
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
return out_dir / f"c{concurrency}.checkpoint.jsonl"
|
||||
|
||||
|
||||
def parse_int_list(values: list[str]) -> list[int]:
|
||||
items: list[int] = []
|
||||
for v in values:
|
||||
@@ -1096,6 +1356,12 @@ def main() -> int:
|
||||
default=None,
|
||||
help="Max questions per benchmark (for fast iteration).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--offset",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Skip first N questions (0-based).",
|
||||
)
|
||||
|
||||
reasoning_group = ap.add_mutually_exclusive_group()
|
||||
reasoning_group.add_argument(
|
||||
@@ -1115,6 +1381,8 @@ def main() -> int:
|
||||
"--temperature", type=float, default=None, help="Override temperature."
|
||||
)
|
||||
ap.add_argument("--top-p", type=float, default=None, help="Override top_p.")
|
||||
ap.add_argument("--top-k", type=int, default=None, help="Override top_k.")
|
||||
ap.add_argument("--min-p", type=float, default=None, help="Override min_p.")
|
||||
ap.add_argument(
|
||||
"--max-tokens", type=int, default=None, help="Override max output tokens."
|
||||
)
|
||||
@@ -1148,15 +1416,31 @@ def main() -> int:
|
||||
choices=["easy", "medium", "hard"],
|
||||
help="Filter by difficulty (livecodebench only). E.g. --difficulty hard",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--release-version",
|
||||
default=None,
|
||||
help="LCB dataset release version (livecodebench only). E.g. release_v5",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--results-dir",
|
||||
default="eval_results",
|
||||
help="Directory for result JSON files (default: eval_results).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-instance-setup",
|
||||
"--enable-thinking",
|
||||
type=lambda v: v.lower() in ("true", "1", "yes"),
|
||||
default=None,
|
||||
help="Enable thinking mode for models that support it.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Skip exo instance management (assumes model is already running).",
|
||||
help="Discard any existing checkpoint and run from scratch.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--keep-instance",
|
||||
action="store_true",
|
||||
help="Skip deleting the instance after eval (for chaining runs).",
|
||||
)
|
||||
|
||||
args, _ = ap.parse_known_args()
|
||||
@@ -1177,13 +1461,26 @@ def main() -> int:
|
||||
# Instance management
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
instance_id: str | None = None
|
||||
created_instance = False
|
||||
|
||||
if not args.skip_instance_setup:
|
||||
short_id, full_model_id = resolve_model_short_id(
|
||||
client,
|
||||
args.model,
|
||||
force_download=args.force_download,
|
||||
)
|
||||
_short_id, full_model_id = resolve_model_short_id(
|
||||
client,
|
||||
args.model,
|
||||
force_download=args.force_download,
|
||||
)
|
||||
|
||||
# Optionally reuse a running instance for this model
|
||||
if args.reuse_instance:
|
||||
existing = find_existing_instance(client, full_model_id)
|
||||
if existing:
|
||||
instance_id = existing
|
||||
logger.info(f"Reusing existing instance {instance_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
"--reuse-instance: no existing instance found, creating a new one"
|
||||
)
|
||||
|
||||
if instance_id is None:
|
||||
selected = settle_and_fetch_placements(
|
||||
client,
|
||||
full_model_id,
|
||||
@@ -1198,7 +1495,7 @@ def main() -> int:
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
@@ -1225,6 +1522,18 @@ def main() -> int:
|
||||
if download_duration is not None:
|
||||
logger.info(f"Download: {download_duration:.1f}s")
|
||||
|
||||
# Delete any existing instances to free resources before placing
|
||||
try:
|
||||
state = client.request_json("GET", "/state")
|
||||
for old_id in list(state.get("instances", {}).keys()):
|
||||
logger.info(f"Deleting stale instance {old_id}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{old_id}")
|
||||
if state.get("instances"):
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up stale instances: {e}")
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
@@ -1234,10 +1543,9 @@ def main() -> int:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
return 1
|
||||
time.sleep(1)
|
||||
cluster_snapshot = capture_cluster_snapshot(client)
|
||||
else:
|
||||
full_model_id = args.model
|
||||
cluster_snapshot = None
|
||||
created_instance = True
|
||||
|
||||
cluster_snapshot = capture_cluster_snapshot(client)
|
||||
|
||||
# Auto-detect reasoning from model config
|
||||
model_config = load_model_config(full_model_id)
|
||||
@@ -1291,16 +1599,57 @@ def main() -> int:
|
||||
reasoning_effort = str(cfg["reasoning_effort"])
|
||||
else:
|
||||
reasoning_effort = "high" if is_reasoning else None
|
||||
|
||||
if args.top_k is not None:
|
||||
top_k: int | None = args.top_k
|
||||
elif "top_k" in cfg:
|
||||
top_k = int(cfg["top_k"])
|
||||
else:
|
||||
top_k = None
|
||||
|
||||
if args.min_p is not None:
|
||||
min_p: float | None = args.min_p
|
||||
elif "min_p" in cfg:
|
||||
min_p = float(cfg["min_p"])
|
||||
else:
|
||||
min_p = None
|
||||
|
||||
if args.enable_thinking is not None:
|
||||
enable_thinking: bool | None = args.enable_thinking
|
||||
elif "enable_thinking" in cfg:
|
||||
enable_thinking = bool(cfg["enable_thinking"])
|
||||
else:
|
||||
enable_thinking = None
|
||||
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
logger.info(f"Model: {full_model_id}")
|
||||
logger.info(
|
||||
f"Settings: temperature={temperature}, max_tokens={max_tokens}, "
|
||||
+ (f"top_p={top_p}, " if top_p is not None else "")
|
||||
+ (f"top_k={top_k}, " if top_k is not None else "")
|
||||
+ (f"min_p={min_p}, " if min_p is not None else "")
|
||||
+ f"reasoning={'yes' if is_reasoning else 'no'}"
|
||||
+ (f", reasoning_effort={reasoning_effort}" if reasoning_effort else "")
|
||||
+ (
|
||||
f", enable_thinking={enable_thinking}"
|
||||
if enable_thinking is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
|
||||
# Common kwargs for evaluate_benchmark
|
||||
eval_kwargs: dict[str, Any] = {
|
||||
"reasoning_effort": reasoning_effort,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
"enable_thinking": enable_thinking,
|
||||
"difficulty": args.difficulty,
|
||||
"offset": args.offset,
|
||||
"release_version": args.release_version,
|
||||
}
|
||||
|
||||
try:
|
||||
if args.compare_concurrency:
|
||||
concurrency_levels = parse_int_list(args.compare_concurrency)
|
||||
@@ -1309,6 +1658,11 @@ def main() -> int:
|
||||
for c in concurrency_levels:
|
||||
logger.info(f"\n{'=' * 50}")
|
||||
logger.info(f"Running {task_name} at concurrency={c}")
|
||||
checkpoint_path = _checkpoint_path(
|
||||
args.results_dir, task_name, full_model_id, c
|
||||
)
|
||||
if args.force and checkpoint_path.exists():
|
||||
checkpoint_path.unlink()
|
||||
results = asyncio.run(
|
||||
evaluate_benchmark(
|
||||
task_name,
|
||||
@@ -1319,9 +1673,8 @@ def main() -> int:
|
||||
concurrency=c,
|
||||
limit=args.limit,
|
||||
timeout=args.request_timeout,
|
||||
reasoning_effort=reasoning_effort,
|
||||
top_p=top_p,
|
||||
difficulty=args.difficulty,
|
||||
checkpoint_path=checkpoint_path,
|
||||
**eval_kwargs,
|
||||
)
|
||||
)
|
||||
if results:
|
||||
@@ -1336,10 +1689,18 @@ def main() -> int:
|
||||
cluster=cluster_snapshot,
|
||||
)
|
||||
results_by_c[c] = results
|
||||
# Clean up checkpoint on success
|
||||
if checkpoint_path.exists():
|
||||
checkpoint_path.unlink()
|
||||
if len(results_by_c) >= 2:
|
||||
print_comparison(task_name, results_by_c)
|
||||
else:
|
||||
for task_name in task_names:
|
||||
checkpoint_path = _checkpoint_path(
|
||||
args.results_dir, task_name, full_model_id, args.num_concurrent
|
||||
)
|
||||
if args.force and checkpoint_path.exists():
|
||||
checkpoint_path.unlink()
|
||||
results = asyncio.run(
|
||||
evaluate_benchmark(
|
||||
task_name,
|
||||
@@ -1350,9 +1711,8 @@ def main() -> int:
|
||||
concurrency=args.num_concurrent,
|
||||
limit=args.limit,
|
||||
timeout=args.request_timeout,
|
||||
reasoning_effort=reasoning_effort,
|
||||
top_p=top_p,
|
||||
difficulty=args.difficulty,
|
||||
checkpoint_path=checkpoint_path,
|
||||
**eval_kwargs,
|
||||
)
|
||||
)
|
||||
if results:
|
||||
@@ -1366,14 +1726,25 @@ def main() -> int:
|
||||
scores,
|
||||
cluster=cluster_snapshot,
|
||||
)
|
||||
# Clean up checkpoint on success
|
||||
if checkpoint_path.exists():
|
||||
checkpoint_path.unlink()
|
||||
finally:
|
||||
if instance_id is not None:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
if created_instance and instance_id is not None:
|
||||
if args.keep_instance:
|
||||
logger.info(f"Keeping instance {instance_id} (--keep-instance)")
|
||||
else:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
try:
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"Timed out waiting for instance {instance_id} to be deleted"
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@@ -69,6 +70,30 @@ class ExoClient:
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
def stream_bench_chat_completions(self, payload: dict[str, Any]) -> Iterator[str]:
    """POST to /bench/chat/completions with ``stream=True`` and yield SSE lines.

    The caller's ``payload`` is not modified; a copy with ``"stream": True``
    is sent. Each line of the response body is decoded as UTF-8 (undecodable
    bytes are replaced) and yielded as received, including its line ending.

    Raises:
        ExoHttpError: if the response status is 400 or above; the error
            carries the status, the reason and the first 300 characters of
            the response body.
    """
    request_body = json.dumps({**payload, "stream": True}).encode("utf-8")
    sse_headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    connection = http.client.HTTPConnection(
        self.host, self.port, timeout=self.timeout_s
    )
    try:
        connection.request(
            "POST",
            "/bench/chat/completions",
            body=request_body,
            headers=sse_headers,
        )
        response = connection.getresponse()
        if response.status >= 400:
            error_text = response.read().decode("utf-8", errors="replace")
            raise ExoHttpError(response.status, response.reason, error_text[:300])
        for raw_line in response:
            yield raw_line.decode("utf-8", errors="replace")
    finally:
        # Runs on normal exhaustion, on error, and when the consumer stops
        # iterating early and the generator is closed.
        connection.close()
|
||||
|
||||
def get_state_path(self, path: str) -> Any:
|
||||
try:
|
||||
return self.request_json("GET", f"/state/{path}")
|
||||
@@ -462,9 +487,8 @@ def run_planning_phase(
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
# Wait for downloads (no timeout — poll until complete or failed)
|
||||
while True:
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
node_downloads = client.get_node_downloads(node_id) or []
|
||||
@@ -514,9 +538,24 @@ def run_planning_phase(
|
||||
if download_t0 is not None:
|
||||
return time.perf_counter() - download_t0
|
||||
return None
|
||||
time.sleep(1)
|
||||
time.sleep(10)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
def find_existing_instance(client: ExoClient, model_id: str) -> str | None:
    """Find an existing running instance for the given model.

    Queries ``GET /state`` and scans its ``"instances"`` mapping. Each
    instance value is expected to be nested by instance type, e.g.
    ``{"MlxJacclInstance": {"shardAssignments": {"modelId": ...}}}``.

    Args:
        client: Client used to issue the ``/state`` request.
        model_id: Model ID to match against ``shardAssignments.modelId``.

    Returns:
        The ID of the first instance whose model ID equals ``model_id``, or
        ``None`` if the request fails, the payload is not shaped as expected,
        or no instance matches.
    """
    try:
        state = client.request_json("GET", "/state")
    except Exception:
        # Deliberate best-effort: any failure to read the state is treated
        # as "no reusable instance".
        return None

    # The payload is external JSON, so every level is type-checked; a null
    # or otherwise malformed entry is skipped instead of raising.
    instances = state.get("instances") if isinstance(state, dict) else None
    if not isinstance(instances, dict):
        return None

    for inst_id, inst in instances.items():
        if not isinstance(inst, dict):
            continue
        for inner in inst.values():
            if not isinstance(inner, dict):
                continue
            shard_assignments = inner.get("shardAssignments")
            if not isinstance(shard_assignments, dict):
                continue
            if shard_assignments.get("modelId") == model_id:
                return inst_id
    return None
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
@@ -572,3 +611,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--reuse-instance",
|
||||
action="store_true",
|
||||
help="Reuse an existing running instance for this model instead of creating a new one.",
|
||||
)
|
||||
|
||||
@@ -336,7 +336,9 @@ class API:
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/bench/chat/completions", response_model=None)(
|
||||
self.bench_chat_completions
|
||||
)
|
||||
self.app.post("/v1/images/generations", response_model=None)(
|
||||
self.image_generations
|
||||
)
|
||||
@@ -829,7 +831,7 @@ class API:
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
) -> BenchChatCompletionResponse:
|
||||
) -> BenchChatCompletionResponse | StreamingResponse:
|
||||
task_params = await chat_request_to_text_generation(payload)
|
||||
resolved_model = await self._resolve_and_validate_text_model(
|
||||
ModelId(task_params.model)
|
||||
@@ -846,6 +848,22 @@ class API:
|
||||
|
||||
command = await self._send_text_generation_with_images(task_params)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
with_sse_keepalive(
|
||||
generate_chat_stream(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
|
||||
Reference in New Issue
Block a user