#!/usr/bin/env python3
# type: ignore
# NOTE: the lines below are a paste artifact (file-browser banner: filename,
# line count, size, language). Kept as comments so the module stays importable.
# Files
# exo/bench/exo_eval.py
# 1749 lines
# 60 KiB
# Python
"""Quality evaluation for exo — matches Artificial Analysis methodology.
Runs LLM benchmarks against exo's OpenAI-compatible API using the same
prompts, temperature settings, and answer extraction as Artificial Analysis.
Supported benchmarks:
gpqa_diamond - Graduate-level science QA (198 questions, 4-choice MC)
mmlu_pro - Multi-task language understanding (12K questions, 10-choice MC)
aime_2024 - Math olympiad 2024 (30 problems, integer answers)
aime_2025 - Math olympiad 2025 (30 problems, integer answers)
humaneval - Python code generation (164 problems, pass@1)
livecodebench - Competitive programming (880+ problems, pass@1)
Model configs in eval_configs/models.toml auto-detect reasoning/non-reasoning
settings per model. Override with --reasoning / --no-reasoning.
Usage:
uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond
uv run python exo_eval.py --model <model-id> --tasks humaneval,livecodebench --limit 50
uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond --compare-concurrency 1,4
References:
https://artificialanalysis.ai/methodology/intelligence-benchmarking
"""
from __future__ import annotations
import argparse
import asyncio
import contextlib
import json
import multiprocessing
import random
import re
import sys
import time
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import httpx
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
capture_cluster_snapshot,
find_existing_instance,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
# ---------------------------------------------------------------------------
# Artificial Analysis constants
# ---------------------------------------------------------------------------
MAX_RETRIES = 30
INSTANCE_HEALTH_CHECK_AFTER = (
3 # Check instance health after this many consecutive failures
)
class InstanceFailedError(RuntimeError):
"""Raised when the exo instance is detected as failed/gone."""
DEFAULT_MAX_TOKENS = 16_384
REASONING_MAX_TOKENS = 131_072
TEMPERATURE_NON_REASONING = 0.0
TEMPERATURE_REASONING = 1.0
# MC answer extraction: 8 fallback regex patterns.
# All patterns are tried; the match at the latest text position wins
# (handles models that self-correct during reasoning).
_MC_PATTERNS: list[re.Pattern[str]] = [
re.compile(
r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])"
),
re.compile(r"\\boxed\{[^}]*([A-Z])[^}]*\}"),
re.compile(r"(?i)answer is ([a-zA-Z])"),
re.compile(r"(?i)answer is \\\(([a-zA-Z])"),
re.compile(r"([A-Z])\)\s*[^A-Z]*$"),
re.compile(r"([A-Z])\s+is\s+the\s+correct\s+answer"),
re.compile(r"([A-Z])\s*$"),
re.compile(r"([A-Z])\s*\."),
]
# Code extraction: last ```python ... ``` block (AA regex)
_CODE_BLOCK_RE = re.compile(r"```(?:python|Python)?\s*\n(.*?)```", re.DOTALL)
# ---------------------------------------------------------------------------
# Model config loading
# ---------------------------------------------------------------------------
def load_model_config(model_id: str) -> dict[str, Any] | None:
"""Look up model in eval_configs/models.toml. Returns config dict or None."""
config_path = Path(__file__).resolve().parent / "eval_configs" / "models.toml"
if not config_path.exists():
return None
with open(config_path, "rb") as f:
data = tomllib.load(f)
for entry in data.get("model", []):
patterns = entry.get("patterns", [])
if any(p in model_id for p in patterns):
return entry
return None
# ---------------------------------------------------------------------------
# Answer extraction
# ---------------------------------------------------------------------------
def extract_mc_answer(text: str, valid_letters: str = "ABCD") -> str | None:
"""Extract MC answer. Last match by text position wins."""
valid_set = set(valid_letters)
best: tuple[int, str] | None = None
for pattern in _MC_PATTERNS:
for m in pattern.finditer(text):
letter = m.group(1).upper()
if letter in valid_set:
pos = m.start()
if best is None or pos >= best[0]:
best = (pos, letter)
return best[1] if best else None
def extract_boxed_answer(text: str) -> str | None:
r"""Extract content from the last \boxed{...}."""
matches: list[str] = []
idx = 0
while True:
pos = text.find("\\boxed{", idx)
if pos < 0:
break
depth = 0
i = pos + len("\\boxed{")
start = i
while i < len(text):
if text[i] == "{":
depth += 1
elif text[i] == "}":
if depth == 0:
matches.append(text[start:i])
break
depth -= 1
i += 1
idx = i + 1 if i < len(text) else len(text)
return matches[-1].strip() if matches else None
def extract_code_block(text: str, preserve_indent: bool = False) -> str | None:
"""Extract the last Python code block from markdown response.
If preserve_indent is True, only strip trailing whitespace (keeps leading
indentation intact — needed for HumanEval function-body completions).
"""
matches = _CODE_BLOCK_RE.findall(text)
if matches:
raw = matches[-1]
return raw.rstrip() if preserve_indent else raw.strip()
# Fallback: try raw code after last ```
lines = text.split("\n")
backtick_lines = [i for i, line in enumerate(lines) if "```" in line]
if len(backtick_lines) >= 2:
return "\n".join(lines[backtick_lines[-2] + 1 : backtick_lines[-1]])
return None
def check_aime_answer(extracted: str, gold: int) -> bool:
"""Check if extracted AIME answer matches gold integer."""
try:
return int(extracted.strip()) == gold
except ValueError:
pass
try:
from math_verify import parse, verify
return verify(parse(str(gold)), parse(extracted))
except Exception:
return False
# ---------------------------------------------------------------------------
# Code execution — official evaluation harnesses
# ---------------------------------------------------------------------------
# LiveCodeBench: vendored from https://github.com/LiveCodeBench/LiveCodeBench
# run_test() must execute in a child process because reliability_guard()
# permanently disables OS functions (os.kill, subprocess.Popen, etc.).
def _lcb_worker(
sample: dict,
code: str,
timeout: int,
result_holder: list[Any],
metadata_holder: list[Any],
) -> None:
"""Target for multiprocessing.Process — runs vendored LCB run_test."""
from vendor.lcb_testing_util import run_test
try:
results, metadata = run_test(sample, test=code, debug=False, timeout=timeout)
result_holder.append(results)
metadata_holder.append(metadata)
except Exception as e:
result_holder.append([-4])
metadata_holder.append({"error_code": -4, "error_message": str(e)})
def run_livecodebench_test(
code: str,
sample: dict,
timeout: int = 6,
) -> tuple[bool, str]:
"""Run LCB evaluation in a subprocess. Returns (passed, diagnostic_info)."""
manager = multiprocessing.Manager()
result_holder = manager.list()
metadata_holder = manager.list()
proc = multiprocessing.Process(
target=_lcb_worker,
args=(sample, code, timeout, result_holder, metadata_holder),
)
proc.start()
# Global timeout: (per-test timeout + 1) * num_tests + 5
num_tests = len(json.loads(sample["input_output"]).get("inputs", []))
global_timeout = (timeout + 1) * num_tests + 5
proc.join(timeout=global_timeout)
if proc.is_alive():
proc.kill()
proc.join()
return False, "Global timeout exceeded"
if not result_holder:
return False, "No results returned from worker"
results = list(result_holder[0])
metadata = dict(metadata_holder[0]) if metadata_holder else {}
# LCB convention: True = pass, negative int = failure code
all_passed = all(r is True or r == 1 for r in results)
if all_passed:
return True, ""
diag = metadata.get("error_message", "")
if not diag and "output" in metadata:
diag = f"Got {metadata['output']}, expected {metadata.get('expected', '?')}"
return False, diag
def run_humaneval_test(
problem: dict, completion: str, timeout: float = 10.0
) -> tuple[bool, str]:
"""Run HumanEval evaluation using the official human_eval package."""
from human_eval.execution import check_correctness
result = check_correctness(problem, completion, timeout)
passed = result["passed"]
diag = "" if passed else result.get("result", "failed")
return passed, diag
# ---------------------------------------------------------------------------
# Benchmark definitions
# ---------------------------------------------------------------------------
@dataclass
class QuestionResult:
question_id: int | str
prompt: str
response: str
extracted_answer: str | None
gold_answer: str
correct: bool
error: str | None = None
prompt_tokens: int = 0
completion_tokens: int = 0
reasoning_tokens: int = 0
reasoning_content: str = ""
finish_reason: str = ""
elapsed_s: float = 0.0
power_watts: float = 0.0
energy_joules: float = 0.0
@dataclass
class BenchmarkConfig:
name: str
description: str
dataset_name: str
dataset_config: str | None
split: str
kind: str # "mc", "math", "code"
BENCHMARKS: dict[str, BenchmarkConfig] = {
"gpqa_diamond": BenchmarkConfig(
name="gpqa_diamond",
description="Graduate-level science QA (198 Q, 4-choice MC)",
dataset_name="Idavidrein/gpqa",
dataset_config="gpqa_diamond",
split="train",
kind="mc",
),
"mmlu_pro": BenchmarkConfig(
name="mmlu_pro",
description="Multi-task language understanding (12K Q, 10-choice MC)",
dataset_name="TIGER-Lab/MMLU-Pro",
dataset_config=None,
split="test",
kind="mc",
),
"aime_2024": BenchmarkConfig(
name="aime_2024",
description="Math olympiad 2024 (30 problems, integer answers)",
dataset_name="HuggingFaceH4/aime_2024",
dataset_config=None,
split="train",
kind="math",
),
"aime_2025": BenchmarkConfig(
name="aime_2025",
description="Math olympiad 2025 (30 problems, integer answers)",
dataset_name="MathArena/aime_2025",
dataset_config=None,
split="train",
kind="math",
),
"humaneval": BenchmarkConfig(
name="humaneval",
description="Python code generation (164 problems, pass@1)",
dataset_name="openai/openai_humaneval",
dataset_config=None,
split="test",
kind="code",
),
"livecodebench": BenchmarkConfig(
name="livecodebench",
description="Competitive programming (880+ problems, pass@1)",
dataset_name="livecodebench/code_generation_lite",
dataset_config=None,
split="test",
kind="code",
),
}
# ---------------------------------------------------------------------------
# Prompt formatters
# ---------------------------------------------------------------------------
_GPQA_INSTRUCTION = (
"Answer the following multiple choice question. "
"The last line of your response should be in the following format: "
"'Answer: A/B/C/D' (e.g. 'Answer: A')."
)
_MMLU_PRO_INSTRUCTION = (
"Answer the following multiple choice question. "
"The last line of your response should be in the following format: "
"'Answer: A/B/C/D/E/F/G/H/I/J' (e.g. 'Answer: A')."
)
_AIME_INSTRUCTION = (
"Solve the following math problem step by step. "
"Put your answer inside \\boxed{}.\n"
"Remember to put your answer inside \\boxed{}."
)
_HUMANEVAL_INSTRUCTION = (
"Complete the following Python function. Return only the function body "
"inside a ```python code block. Do not include the function signature."
)
# LiveCodeBench: AA uses original prompts without custom system prompts
_LCB_SYSTEM = (
"You are an expert Python programmer. You will be given a question "
"(problem specification) and will generate a correct Python program "
"that matches the specification and passes all tests."
)
_LCB_WITH_STARTER = (
"### Question:\n{question}\n\n"
"### Format: You will use the following starter code to write the "
"solution to the problem and enclose your code within delimiters.\n"
"```python\n{starter_code}\n```\n\n"
"### Answer: (use the provided format with backticks)\n"
)
_LCB_WITHOUT_STARTER = (
"### Question:\n{question}\n\n"
"### Format: Read the inputs from stdin solve the problem and write "
"the answer to stdout (do not directly test on the sample inputs). "
"Enclose your code within delimiters as follows. Ensure that when the "
"python program runs, it reads the inputs, runs the algorithm and "
"writes output to STDOUT.\n"
"```python\n# YOUR CODE HERE\n```\n\n"
"### Answer: (use the provided format with backticks)\n"
)
def format_gpqa_question(doc: dict, idx: int) -> tuple[str, str]:
"""Returns (prompt, correct_letter)."""
correct = doc["Correct Answer"]
choices = [
correct,
doc["Incorrect Answer 1"],
doc["Incorrect Answer 2"],
doc["Incorrect Answer 3"],
]
rng = random.Random(idx)
order = rng.sample(range(4), 4)
shuffled = [choices[i] for i in order]
correct_letter = "ABCD"[order.index(0)]
choices_text = "\n".join(f"{L}) {shuffled[i]}" for i, L in enumerate("ABCD"))
return f"{_GPQA_INSTRUCTION}\n\n{doc['Question']}\n\n{choices_text}", correct_letter
def format_mmlu_pro_question(doc: dict) -> tuple[str, str]:
"""Returns (prompt, correct_letter)."""
options = doc["options"]
letters = "ABCDEFGHIJ"
choices_text = "\n".join(f"{letters[i]}) {opt}" for i, opt in enumerate(options))
return f"{_MMLU_PRO_INSTRUCTION}\n\n{doc['question']}\n\n{choices_text}", doc[
"answer"
]
def format_aime_question(doc: dict) -> tuple[str, int]:
"""Returns (prompt, correct_answer_int)."""
return f"{_AIME_INSTRUCTION}\n\n{doc['problem']}", int(doc["answer"])
def format_humaneval_question(doc: dict) -> tuple[str, dict]:
"""Returns (prompt, metadata_for_execution)."""
prompt = f"{_HUMANEVAL_INSTRUCTION}\n\n```python\n{doc['prompt']}```"
# Pass the full problem dict — check_correctness needs task_id, prompt,
# test, entry_point
meta = {
"problem": {
"task_id": doc["task_id"],
"prompt": doc["prompt"],
"test": doc["test"],
"entry_point": doc["entry_point"],
},
}
return prompt, meta
def format_livecodebench_question(doc: dict) -> tuple[str, str | None, dict]:
"""Returns (prompt, system_message, metadata_for_execution)."""
starter_code = doc.get("starter_code", "")
question_content = doc["question_content"]
if starter_code and starter_code.strip():
user_msg = _LCB_WITH_STARTER.format(
question=question_content, starter_code=starter_code
)
else:
user_msg = _LCB_WITHOUT_STARTER.format(question=question_content)
# Parse test cases
public_tests = (
json.loads(doc["public_test_cases"])
if isinstance(doc["public_test_cases"], str)
else doc["public_test_cases"]
)
private_tests = doc.get("private_test_cases", "[]")
if isinstance(private_tests, str):
try:
private_tests = json.loads(private_tests)
except Exception:
import base64
import pickle
import zlib
private_tests = json.loads(
pickle.loads(
zlib.decompress(base64.b64decode(private_tests.encode("utf-8")))
)
)
all_tests = public_tests + (
private_tests if isinstance(private_tests, list) else []
)
test_inputs = [t["input"] for t in all_tests]
test_outputs = [t["output"] for t in all_tests]
metadata = doc.get("metadata", "{}")
if isinstance(metadata, str):
metadata = json.loads(metadata)
func_name = metadata.get("func_name")
# Build the sample dict in official LCB format for run_test()
input_output: dict[str, Any] = {
"inputs": test_inputs,
"outputs": test_outputs,
}
if func_name:
input_output["fn_name"] = func_name
meta = {
"sample": {"input_output": json.dumps(input_output)},
}
return user_msg, _LCB_SYSTEM, meta
# ---------------------------------------------------------------------------
# API client with retries
# ---------------------------------------------------------------------------
@dataclass
class ApiResult:
content: str
prompt_tokens: int
completion_tokens: int
reasoning_tokens: int
reasoning_content: str = ""
finish_reason: str = ""
power_watts: float = 0.0
energy_joules: float = 0.0
async def _call_api(
client: httpx.AsyncClient,
base_url: str,
model: str,
prompt: str,
temperature: float,
max_tokens: int,
timeout: float | None,
system_message: str | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
) -> ApiResult:
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
messages.append({"role": "user", "content": prompt})
body: dict[str, Any] = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if reasoning_effort is not None:
body["reasoning_effort"] = reasoning_effort
if top_p is not None:
body["top_p"] = top_p
if top_k is not None:
body["top_k"] = top_k
if min_p is not None:
body["min_p"] = min_p
if enable_thinking is not None:
body["enable_thinking"] = enable_thinking
resp = await client.post(
f"{base_url}/v1/chat/completions",
json=body,
timeout=timeout,
)
resp.raise_for_status()
data = resp.json()
choice = data["choices"][0]
message = choice["message"]
content = message.get("content") or ""
reasoning_content = message.get("reasoning_content") or ""
finish_reason = choice.get("finish_reason") or ""
# For thinking models, empty content is expected when finish_reason is "length"
if not content.strip() and finish_reason != "length" and not reasoning_content:
raise ValueError("Empty response from model")
usage = data.get("usage", {})
details = usage.get("completion_tokens_details", {})
power = data.get("power_usage") or {}
return ApiResult(
content=content,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
reasoning_tokens=details.get("reasoning_tokens", 0) if details else 0,
reasoning_content=reasoning_content,
finish_reason=finish_reason,
power_watts=power.get("total_avg_sys_power_watts", 0.0),
energy_joules=power.get("total_energy_joules", 0.0),
)
async def _check_instance_health(base_url: str) -> bool:
"""Return True if the exo instance is still reachable."""
try:
async with httpx.AsyncClient() as c:
resp = await c.get(f"{base_url}/models", timeout=5.0)
return resp.status_code == 200
except Exception:
return False
async def call_with_retries(
client: httpx.AsyncClient,
base_url: str,
model: str,
prompt: str,
temperature: float,
max_tokens: int,
timeout: float | None = None,
system_message: str | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
instance_failed: asyncio.Event | None = None,
) -> ApiResult | None:
for attempt in range(MAX_RETRIES):
if instance_failed and instance_failed.is_set():
raise InstanceFailedError("Instance already marked as failed")
try:
return await _call_api(
client,
base_url,
model,
prompt,
temperature,
max_tokens,
timeout,
system_message,
reasoning_effort,
top_p,
top_k,
min_p,
enable_thinking,
)
except Exception as e:
is_conn_error = isinstance(
e,
(
httpx.ConnectError,
httpx.RemoteProtocolError,
ConnectionRefusedError,
OSError,
),
)
if (
is_conn_error
and attempt >= INSTANCE_HEALTH_CHECK_AFTER
and not await _check_instance_health(base_url)
):
if instance_failed:
instance_failed.set()
raise InstanceFailedError(
f"Instance is down after {attempt + 1} failures: {e}"
) from e
if attempt < MAX_RETRIES - 1:
wait = min(2**attempt, 60)
logger.warning(
f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}. Retrying in {wait}s..."
)
await asyncio.sleep(wait)
else:
logger.error(f"All {MAX_RETRIES} retries exhausted. Last error: {e}")
return None
# ---------------------------------------------------------------------------
# Evaluation runners
# ---------------------------------------------------------------------------
async def evaluate_benchmark(
benchmark_name: str,
base_url: str,
model: str,
temperature: float,
max_tokens: int,
concurrency: int = 1,
limit: int | None = None,
offset: int = 0,
timeout: float | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
difficulty: str | None = None,
checkpoint_path: Path | None = None,
release_version: str | None = None,
) -> list[QuestionResult]:
"""Run a benchmark. Returns per-question results."""
import datasets
config = BENCHMARKS[benchmark_name]
logger.info(f"Loading dataset {config.dataset_name}...")
try:
if benchmark_name == "livecodebench":
ds = datasets.load_dataset(
"json",
data_files="hf://datasets/livecodebench/code_generation_lite/*.jsonl",
split="train",
)
else:
ds = datasets.load_dataset(
config.dataset_name,
config.dataset_config,
split=config.split,
)
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
if "gated" in str(e).lower() or "login" in str(e).lower():
logger.error("Dataset requires authentication. Run: huggingface-cli login")
return []
if difficulty and "difficulty" in ds.column_names:
ds = ds.filter(lambda x: x["difficulty"] == difficulty)
logger.info(f"Filtered to {len(ds)} {difficulty} problems")
if release_version and "release_version" in ds.column_names:
ds = ds.filter(lambda x: x["release_version"] == release_version)
logger.info(
f"Filtered to {len(ds)} problems with release_version={release_version}"
)
# Sort by question_id to match LCB runner ordering (scenario_router.py:60).
# This ensures [offset:offset+limit] slices select the same problems as vllm.
if "question_id" in ds.column_names:
ds = ds.sort("question_id")
total = len(ds)
if offset > 0:
ds = ds.select(range(min(offset, total), total))
total = len(ds)
if limit and limit < total:
ds = ds.select(range(limit))
total = limit
logger.info(
f"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, "
f"temperature={temperature}, max_tokens={max_tokens}"
+ (f", top_k={top_k}" if top_k is not None else "")
+ (f", min_p={min_p}" if min_p is not None else "")
+ (
f", enable_thinking={enable_thinking}"
if enable_thinking is not None
else ""
)
)
if config.kind == "code":
logger.warning(
"Code benchmarks execute model-generated code. Use a sandboxed environment."
)
# Load checkpoint for resume
checkpoint_data: dict[str | int, dict[str, Any]] = {}
if checkpoint_path and checkpoint_path.exists():
with open(checkpoint_path) as f:
for line in f:
entry = json.loads(line)
checkpoint_data[entry["question_id"]] = entry
logger.info(f"Loaded {len(checkpoint_data)} checkpointed results")
semaphore = asyncio.Semaphore(concurrency)
instance_failed = asyncio.Event()
results: list[QuestionResult | None] = [None] * total
completed = 0
lock = asyncio.Lock()
def _get_question_id(idx: int, doc: dict) -> str | int:
"""Get a stable question ID for checkpointing."""
if benchmark_name == "livecodebench":
return doc.get("question_id", idx)
elif benchmark_name == "humaneval":
return doc.get("task_id", idx)
return idx
async def process_question(
idx: int, doc: dict, http_client: httpx.AsyncClient
) -> None:
nonlocal completed
system_msg = None
question_id = _get_question_id(idx, doc)
# Bail out early if instance is already dead
if instance_failed.is_set():
return
# Check checkpoint
if question_id in checkpoint_data:
cached = checkpoint_data[question_id]
results[idx] = QuestionResult(
question_id=question_id,
prompt=cached.get("prompt", ""),
response=cached.get("response", ""),
extracted_answer=cached.get("extracted_answer"),
gold_answer=cached.get("gold_answer", ""),
correct=cached.get("correct", False),
error=cached.get("error"),
prompt_tokens=cached.get("prompt_tokens", 0),
completion_tokens=cached.get("completion_tokens", 0),
reasoning_tokens=cached.get("reasoning_tokens", 0),
reasoning_content=cached.get("reasoning_content", ""),
finish_reason=cached.get("finish_reason", ""),
elapsed_s=cached.get("elapsed_s", 0.0),
power_watts=cached.get("power_watts", 0.0),
energy_joules=cached.get("energy_joules", 0.0),
)
async with lock:
completed += 1
logger.info(f" [{completed}/{total}] {question_id} (cached)")
return
if benchmark_name == "gpqa_diamond":
prompt, gold = format_gpqa_question(doc, idx)
valid_letters = "ABCD"
elif benchmark_name == "mmlu_pro":
prompt, gold = format_mmlu_pro_question(doc)
valid_letters = "ABCDEFGHIJ"[: len(doc["options"])]
elif benchmark_name.startswith("aime"):
prompt, gold_int = format_aime_question(doc)
gold = str(gold_int)
elif benchmark_name == "humaneval":
prompt, exec_meta = format_humaneval_question(doc)
gold = "pass"
elif benchmark_name == "livecodebench":
prompt, system_msg, exec_meta = format_livecodebench_question(doc)
gold = "pass"
else:
raise ValueError(f"Unknown benchmark: {benchmark_name}")
async with semaphore:
if instance_failed.is_set():
return
t0 = time.monotonic()
try:
# Race the API call against the instance_failed event
api_task = asyncio.create_task(
call_with_retries(
http_client,
base_url,
model,
prompt,
temperature,
max_tokens,
timeout,
system_message=system_msg,
reasoning_effort=reasoning_effort,
top_p=top_p,
top_k=top_k,
min_p=min_p,
enable_thinking=enable_thinking,
instance_failed=instance_failed,
)
)
failed_waiter = asyncio.create_task(instance_failed.wait())
done, pending = await asyncio.wait(
[api_task, failed_waiter],
return_when=asyncio.FIRST_COMPLETED,
)
for p in pending:
p.cancel()
with contextlib.suppress(asyncio.CancelledError):
await p
if instance_failed.is_set() and api_task not in done:
logger.error(f"Instance failed, aborting {question_id}")
return
api_result = api_task.result()
except InstanceFailedError:
logger.error(f"Instance failed, skipping {question_id}")
return
elapsed = time.monotonic() - t0
if api_result is None:
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response="",
extracted_answer=None,
gold_answer=gold,
correct=False,
error="API failure after retries",
elapsed_s=elapsed,
)
else:
response = api_result.content
stats = {
"prompt_tokens": api_result.prompt_tokens,
"completion_tokens": api_result.completion_tokens,
"reasoning_tokens": api_result.reasoning_tokens,
"reasoning_content": api_result.reasoning_content,
"finish_reason": api_result.finish_reason,
"elapsed_s": elapsed,
"power_watts": api_result.power_watts,
"energy_joules": api_result.energy_joules,
}
if config.kind == "mc":
extracted = extract_mc_answer(response, valid_letters)
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=extracted,
gold_answer=gold,
correct=(extracted == gold) if extracted else False,
**stats,
)
elif config.kind == "math":
extracted = extract_boxed_answer(response)
correct = (
check_aime_answer(extracted, int(gold)) if extracted else False
)
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=extracted,
gold_answer=gold,
correct=correct,
**stats,
)
elif config.kind == "code":
# HumanEval needs preserved indentation (function body completion)
keep_indent = benchmark_name == "humaneval"
code = extract_code_block(response, preserve_indent=keep_indent)
if code is None:
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
gold_answer=gold,
correct=False,
error="No code block extracted",
**stats,
)
elif benchmark_name == "humaneval":
passed, diag = run_humaneval_test(
exec_meta["problem"],
code,
)
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer="pass" if passed else "fail",
gold_answer=gold,
correct=passed,
error=diag if not passed else None,
**stats,
)
elif benchmark_name == "livecodebench":
passed, diag = run_livecodebench_test(
code,
exec_meta["sample"],
)
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer="pass" if passed else "fail",
gold_answer=gold,
correct=passed,
error=diag if not passed else None,
**stats,
)
else:
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
gold_answer=gold,
correct=False,
error="Unknown code benchmark",
**stats,
)
else:
result = QuestionResult(
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
gold_answer=gold,
correct=False,
error="Unsupported kind",
**stats,
)
results[idx] = result
# Write checkpoint (skip failures so they get retried on resume)
if checkpoint_path is not None and not result.error and result.response:
_write_checkpoint(checkpoint_path, result)
async with lock:
completed += 1
n = completed
# Log progress
thinking_info = ""
if result.reasoning_content:
thinking_info = f", {len(result.reasoning_content)} chars thinking"
logger.info(
f" [{n}/{total}] {question_id}: {len(result.response)} chars{thinking_info}, "
f"tokens: {result.prompt_tokens}+{result.completion_tokens} "
f"[{result.finish_reason}]"
+ (f" {result.extracted_answer}" if result.extracted_answer else "")
)
async def _health_monitor() -> None:
"""Periodically check if the instance is still alive."""
# Wait a bit before first check to let things start
await asyncio.sleep(10)
while not instance_failed.is_set():
if not await _check_instance_health(base_url):
# Double-check to avoid false positives
await asyncio.sleep(2)
if not await _check_instance_health(base_url):
logger.error("Health monitor: instance is down!")
instance_failed.set()
return
await asyncio.sleep(5)
async with httpx.AsyncClient() as http_client:
monitor = asyncio.create_task(_health_monitor())
tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)]
await asyncio.gather(*tasks)
monitor.cancel()
with contextlib.suppress(asyncio.CancelledError):
await monitor
if instance_failed.is_set():
completed_count = sum(1 for r in results if r is not None)
logger.error(
f"Instance failed! Completed {completed_count}/{total} problems. "
f"Checkpoint saved — restart to resume remaining problems."
)
raise InstanceFailedError("Instance failed during evaluation")
return [r for r in results if r is not None]
def _write_checkpoint(path: Path, result: QuestionResult) -> None:
"""Append a single result to the JSONL checkpoint file."""
entry = {
"question_id": result.question_id,
"prompt": result.prompt,
"response": result.response,
"extracted_answer": result.extracted_answer,
"gold_answer": result.gold_answer,
"correct": result.correct,
"error": result.error,
"prompt_tokens": result.prompt_tokens,
"completion_tokens": result.completion_tokens,
"reasoning_tokens": result.reasoning_tokens,
"reasoning_content": result.reasoning_content,
"finish_reason": result.finish_reason,
"elapsed_s": round(result.elapsed_s, 2),
"power_watts": round(result.power_watts, 2),
"energy_joules": round(result.energy_joules, 2),
}
with open(path, "a") as f:
f.write(json.dumps(entry) + "\n")
# ---------------------------------------------------------------------------
# Results display
# ---------------------------------------------------------------------------
def print_results(
benchmark_name: str,
results: list[QuestionResult],
concurrency: int | None = None,
) -> dict[str, Any]:
total = len(results)
correct = sum(r.correct for r in results)
errors = sum(1 for r in results if r.error)
no_extract = sum(1 for r in results if r.extracted_answer is None and not r.error)
accuracy = correct / max(total, 1)
total_prompt_tokens = sum(r.prompt_tokens for r in results)
total_completion_tokens = sum(r.completion_tokens for r in results)
total_reasoning_tokens = sum(r.reasoning_tokens for r in results)
total_elapsed = sum(r.elapsed_s for r in results)
wall_clock = max(r.elapsed_s for r in results) if results else 0.0
avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0
total_energy = sum(r.energy_joules for r in results)
avg_power = sum(r.power_watts for r in results) / max(total, 1)
label = f"[c={concurrency}] " if concurrency is not None else ""
print(f"\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})")
tok_line = f" tokens: {total_prompt_tokens:,} prompt + {total_completion_tokens:,} completion"
if total_reasoning_tokens > 0:
tok_line += f" ({total_reasoning_tokens:,} reasoning)"
tok_line += (
f" | avg gen tps: {avg_gen_tps:.1f}"
f" | total time: {total_elapsed:.1f}s wall clock: {wall_clock:.1f}s"
)
print(tok_line)
if total_energy > 0:
print(
f" power: avg {avg_power:.1f}W | total energy: {total_energy:.1f}J ({total_energy / 3600:.2f}Wh)"
)
if errors:
print(f" API errors: {errors}")
if no_extract:
print(f" No answer extracted: {no_extract}")
return {
"benchmark": benchmark_name,
"accuracy": accuracy,
"correct": correct,
"total": total,
"errors": errors,
"no_extract": no_extract,
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_reasoning_tokens": total_reasoning_tokens,
"total_elapsed_s": total_elapsed,
"wall_clock_s": wall_clock,
"avg_gen_tps": avg_gen_tps,
"avg_power_watts": avg_power,
"total_energy_joules": total_energy,
}
def print_comparison(
benchmark_name: str,
results_by_c: dict[int, list[QuestionResult]],
) -> None:
levels = sorted(results_by_c.keys())
print(f"\n{'=' * 70}")
print(f"COMPARISON: {benchmark_name}")
print(f"{'=' * 70}")
header = f"{'Concurrency':<15} {'Accuracy':>10} {'Correct':>10} {'Total':>10} {'Comp Tokens':>12} {'Wall Clock':>12} {'Avg Gen TPS':>12}"
print(header)
print("-" * len(header))
for c in levels:
r = results_by_c[c]
correct = sum(q.correct for q in r)
total = len(r)
comp_tok = sum(q.completion_tokens for q in r)
total_elapsed = sum(q.elapsed_s for q in r)
avg_tps = comp_tok / total_elapsed if total_elapsed > 0 else 0.0
wall = max(q.elapsed_s for q in r) if r else 0.0
print(
f"c={c:<13} {correct / max(total, 1):>10.1%} {correct:>10} {total:>10}"
f" {comp_tok:>12,} {wall:>11.1f}s {avg_tps:>12.1f}"
)
if len(levels) >= 2:
base_results = results_by_c[levels[0]]
test_results = results_by_c[levels[-1]]
changed = sum(
1
for br, tr in zip(base_results, test_results, strict=True)
if br.correct != tr.correct
)
total = min(len(base_results), len(test_results))
print(
f"\nQuestions with different correctness (c={levels[0]} vs c={levels[-1]}): {changed}/{total}"
)
if changed == 0:
print("Batching produced identical quality.")
elif changed <= total * 0.01:
print("Negligible quality difference from batching.")
else:
print(
f"WARNING: {changed / max(total, 1) * 100:.1f}% of questions changed."
)
print()
# ---------------------------------------------------------------------------
# Interactive task picker
# ---------------------------------------------------------------------------
def pick_tasks_interactive() -> list[str]:
    """Interactively pick benchmarks with a raw-mode terminal checkbox menu.

    Renders an ANSI-escape-driven menu of every entry in ``BENCHMARKS``
    (up/down to move, space to toggle, enter to confirm). Returns the chosen
    benchmark names; returns an empty list when stdin is not a TTY or the
    user quits with ``q`` / Ctrl-C.
    """
    # Imported lazily: termios/tty are POSIX-only and only needed when the
    # interactive picker is actually used.
    import termios
    import tty
    if not sys.stdin.isatty():
        logger.error("No --tasks specified and stdin is not a terminal.")
        return []
    items = [(name, cfg.description) for name, cfg in BENCHMARKS.items()]
    selected: list[bool] = [False] * len(items)
    cursor = 0
    # Menu height: title + blank line + one row per item + blank + count row.
    total_lines = len(items) + 4
    def write(s: str) -> None:
        sys.stdout.write(s)
    def render(first: bool = False) -> None:
        # Repaint in place: move the cursor back up over the previous frame
        # (skipped on the very first draw), then clear to end of screen.
        if not first:
            write(f"\033[{total_lines}A")
        write("\033[J")
        write(
            "\033[1mSelect benchmarks\033[0m (up/down, space toggle, enter confirm, q quit)\r\n\r\n"
        )
        for i, (name, desc) in enumerate(items):
            marker = ">" if i == cursor else " "
            check = "x" if selected[i] else " "
            line = f" {marker} [{check}] {name:<17} {desc}"
            # Cursor row is highlighted with reverse video (ESC[7m).
            write(f"\033[7m{line}\033[0m\r\n" if i == cursor else f"{line}\r\n")
        write(f"\r\n {sum(selected)} selected\r\n")
        sys.stdout.flush()
    fd = sys.stdin.fileno()
    old = termios.tcgetattr(fd)
    try:
        # Raw mode: read single keypresses without waiting for Enter.
        tty.setraw(fd)
        write("\033[?25l")  # hide the cursor while the menu is shown
        render(first=True)
        while True:
            ch = sys.stdin.read(1)
            if ch in ("q", "\x03"):  # 'q' or Ctrl-C aborts selection
                # Restore cursor here because the post-try cleanup below is
                # skipped when returning out of the try block.
                write("\033[?25h\033[0m\r\n")
                return []
            elif ch in ("\r", "\n"):
                break
            elif ch == " ":
                selected[cursor] = not selected[cursor]
            elif ch == "\x1b":
                # Arrow keys arrive as ESC [ A / ESC [ B; wrap around the list.
                seq = sys.stdin.read(2)
                if seq == "[A":
                    cursor = (cursor - 1) % len(items)
                elif seq == "[B":
                    cursor = (cursor + 1) % len(items)
            render()
    finally:
        # Always restore the terminal settings, even on exception.
        termios.tcsetattr(fd, termios.TCSADRAIN, old)
    write("\033[?25h\033[0m\r\n")  # re-show the cursor
    sys.stdout.flush()
    chosen = [name for (name, _), sel in zip(items, selected, strict=True) if sel]
    if chosen:
        logger.info(f"Selected: {', '.join(chosen)}")
    return chosen
# ---------------------------------------------------------------------------
# Results persistence
# ---------------------------------------------------------------------------
def save_results(
    results_dir: str,
    benchmark_name: str,
    model: str,
    concurrency: int,
    results: list[QuestionResult],
    scores: dict[str, Any],
    cluster: dict[str, Any] | None = None,
) -> Path:
    """Write a benchmark run (scores + per-question detail) to a JSON file.

    Files land under ``<results_dir>/<model with '/'→'_'>/<benchmark>/`` with
    a timestamped ``c<concurrency>_<ts>.json`` name so repeated runs never
    overwrite each other.

    Returns:
        The path of the file that was written.
    """
    target_dir = Path(results_dir) / model.replace("/", "_") / benchmark_name
    target_dir.mkdir(parents=True, exist_ok=True)
    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_path = target_dir / f"c{concurrency}_{stamp}.json"

    def _row(r: QuestionResult) -> dict[str, Any]:
        # Per-question record; float fields rounded to keep files compact.
        return {
            "question_id": r.question_id,
            "prompt": r.prompt,
            "response": r.response,
            "extracted_answer": r.extracted_answer,
            "gold_answer": r.gold_answer,
            "correct": r.correct,
            "error": r.error,
            "prompt_tokens": r.prompt_tokens,
            "completion_tokens": r.completion_tokens,
            "reasoning_tokens": r.reasoning_tokens,
            "reasoning_content": r.reasoning_content,
            "finish_reason": r.finish_reason,
            "elapsed_s": round(r.elapsed_s, 2),
            "power_watts": round(r.power_watts, 2),
            "energy_joules": round(r.energy_joules, 2),
        }

    payload: dict[str, Any] = {
        "benchmark": benchmark_name,
        "model": model,
        "concurrency": concurrency,
        # Cluster snapshot is optional; omit the key entirely when absent.
        **({"cluster": cluster} if cluster else {}),
        "scores": scores,
        "results": [_row(r) for r in results],
    }
    with open(out_path, "w") as fh:
        json.dump(payload, fh, indent=2)
    logger.info(f"Results saved to {out_path}")
    return out_path
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _checkpoint_path(
results_dir: str, benchmark: str, model: str, concurrency: int
) -> Path:
"""Return the JSONL checkpoint path for a benchmark run."""
out_dir = Path(results_dir) / model.replace("/", "_") / benchmark
out_dir.mkdir(parents=True, exist_ok=True)
return out_dir / f"c{concurrency}.checkpoint.jsonl"
def parse_int_list(values: list[str]) -> list[int]:
    """Flatten CLI values like ``["1,4", "8"]`` into ``[1, 4, 8]``.

    Each value may contain comma-separated integers; empty segments (from
    stray or trailing commas) are skipped. ``int()`` already tolerates
    surrounding whitespace, so no explicit strip is needed on conversion.

    Raises:
        ValueError: if a non-empty segment is not a valid integer.
    """
    return [
        int(part)
        for v in values
        for part in v.split(",")
        if part.strip()
    ]
def main() -> int:
    """CLI entry point for exo-eval.

    Parses arguments, resolves the target model, provisions (or reuses) an
    exo instance, runs the requested benchmarks — optionally at multiple
    concurrency levels for comparison — persists results, and finally tears
    down any instance this run created (unless --keep-instance is set).

    Returns:
        Process exit code: 0 on success or user cancel, 1 on setup failure.
    """
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Quality evaluation for exo — matches Artificial Analysis methodology.",
    )
    add_common_instance_args(ap)
    ap.add_argument(
        "--tasks",
        default=None,
        help="Comma-separated benchmark names. Omit for interactive picker.",
    )
    ap.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Max questions per benchmark (for fast iteration).",
    )
    ap.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Skip first N questions (0-based).",
    )
    reasoning_group = ap.add_mutually_exclusive_group()
    reasoning_group.add_argument(
        "--reasoning",
        action="store_true",
        default=None,
        help="Force reasoning-model settings (temperature=0.6, max_tokens=65536).",
    )
    reasoning_group.add_argument(
        "--no-reasoning",
        action="store_true",
        default=False,
        help="Force non-reasoning settings (temperature=0, max_tokens=16384).",
    )
    ap.add_argument(
        "--temperature", type=float, default=None, help="Override temperature."
    )
    ap.add_argument("--top-p", type=float, default=None, help="Override top_p.")
    ap.add_argument("--top-k", type=int, default=None, help="Override top_k.")
    ap.add_argument("--min-p", type=float, default=None, help="Override min_p.")
    ap.add_argument(
        "--max-tokens", type=int, default=None, help="Override max output tokens."
    )
    ap.add_argument(
        "--num-concurrent",
        type=int,
        default=1,
        help="Concurrent API requests (default: 1).",
    )
    ap.add_argument(
        "--compare-concurrency",
        nargs="+",
        default=None,
        help="Run at multiple concurrency levels and compare. E.g. --compare-concurrency 1,4",
    )
    ap.add_argument(
        "--request-timeout",
        type=float,
        default=None,
        help="Per-request timeout in seconds (default: no timeout).",
    )
    ap.add_argument(
        "--reasoning-effort",
        default=None,
        choices=["low", "medium", "high"],
        help="Override reasoning effort (default: 'high' for reasoning models, none for non-reasoning).",
    )
    ap.add_argument(
        "--difficulty",
        default=None,
        choices=["easy", "medium", "hard"],
        help="Filter by difficulty (livecodebench only). E.g. --difficulty hard",
    )
    ap.add_argument(
        "--release-version",
        default=None,
        help="LCB dataset release version (livecodebench only). E.g. release_v5",
    )
    ap.add_argument(
        "--results-dir",
        default="eval_results",
        help="Directory for result JSON files (default: eval_results).",
    )
    ap.add_argument(
        "--enable-thinking",
        type=lambda v: v.lower() in ("true", "1", "yes"),
        default=None,
        help="Enable thinking mode for models that support it.",
    )
    ap.add_argument(
        "--force",
        action="store_true",
        help="Discard any existing checkpoint and run from scratch.",
    )
    ap.add_argument(
        "--keep-instance",
        action="store_true",
        help="Skip deleting the instance after eval (for chaining runs).",
    )
    # NOTE(review): parse_known_args silently ignores unrecognized flags —
    # presumably to tolerate wrapper scripts passing extra options; confirm.
    args, _ = ap.parse_known_args()
    # Resolve tasks
    if args.tasks:
        task_names = [t.strip() for t in args.tasks.split(",") if t.strip()]
    else:
        task_names = pick_tasks_interactive()
        if not task_names:
            return 0
    for t in task_names:
        if t not in BENCHMARKS:
            logger.error(f"Unknown benchmark '{t}'. Available: {', '.join(BENCHMARKS)}")
            return 1
    # Instance management
    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    instance_id: str | None = None
    # Only instances created by this run are deleted in the finally block.
    created_instance = False
    _short_id, full_model_id = resolve_model_short_id(
        client,
        args.model,
        force_download=args.force_download,
    )
    # Auto-detect a running instance for this model
    if not args.fresh_instance:
        existing = find_existing_instance(client, full_model_id)
        if existing:
            instance_id = existing
            logger.info(f"Reusing existing instance {instance_id}")
    if instance_id is None:
        selected = settle_and_fetch_placements(
            client,
            full_model_id,
            args,
            settle_timeout=args.settle_timeout,
        )
        if not selected:
            logger.error("No valid placements matched your filters.")
            return 1
        # Pick the placement that sorts highest by (meta, sharding, node count).
        selected.sort(
            key=lambda p: (
                str(p.get("instance_meta", "")),
                str(p.get("sharding", "")),
                nodes_used_in_instance(p["instance"]),
            ),
            reverse=True,
        )
        preview = selected[0]
        instance = preview["instance"]
        instance_id = instance_id_from_instance(instance)
        logger.info(
            f"PLACEMENT: {preview['sharding']} / {preview['instance_meta']} / "
            f"nodes={nodes_used_in_instance(instance)}"
        )
        settle_deadline = (
            time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
        )
        download_duration = run_planning_phase(
            client,
            full_model_id,
            preview,
            args.danger_delete_downloads,
            args.timeout,
            settle_deadline,
        )
        if download_duration is not None:
            logger.info(f"Download: {download_duration:.1f}s")
        # Delete any existing instances to free resources before placing
        try:
            state = client.request_json("GET", "/state")
            for old_id in list(state.get("instances", {}).keys()):
                logger.info(f"Deleting stale instance {old_id}")
                with contextlib.suppress(ExoHttpError):
                    client.request_json("DELETE", f"/instance/{old_id}")
            if state.get("instances"):
                # Give the cluster a moment to release resources.
                time.sleep(2)
        except Exception as e:
            # Best-effort cleanup: a failure here should not abort the eval.
            logger.warning(f"Failed to clean up stale instances: {e}")
        client.request_json("POST", "/instance", body={"instance": instance})
        try:
            wait_for_instance_ready(client, instance_id)
        except (RuntimeError, TimeoutError) as e:
            logger.error(f"Failed to initialize: {e}")
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{instance_id}")
            return 1
        time.sleep(1)
        created_instance = True
    cluster_snapshot = capture_cluster_snapshot(client)
    # Auto-detect reasoning from model config
    model_config = load_model_config(full_model_id)
    if args.reasoning:
        is_reasoning = True
    elif args.no_reasoning:
        is_reasoning = False
    elif model_config is not None:
        is_reasoning = model_config.get("reasoning", False)
        # Fixed: the two f-string fragments previously concatenated with no
        # separator ("...: <name>reasoning").
        logger.info(
            f"Auto-detected from config: {model_config['name']} "
            f"({'reasoning' if is_reasoning else 'non-reasoning'})"
        )
    else:
        is_reasoning = False
        logger.warning(
            f"Model '{full_model_id}' not found in eval_configs/models.toml. "
            f"Defaulting to non-reasoning. Use --reasoning to override."
        )
    # Resolve temperature, max_tokens, reasoning_effort
    # Priority: CLI flag > per-model config > global defaults
    cfg = model_config or {}
    if args.temperature is not None:
        temperature = args.temperature
    elif "temperature" in cfg:
        temperature = float(cfg["temperature"])
    else:
        temperature = (
            TEMPERATURE_REASONING if is_reasoning else TEMPERATURE_NON_REASONING
        )
    if args.max_tokens is not None:
        max_tokens = args.max_tokens
    elif "max_tokens" in cfg:
        max_tokens = int(cfg["max_tokens"])
    else:
        max_tokens = REASONING_MAX_TOKENS if is_reasoning else DEFAULT_MAX_TOKENS
    if args.top_p is not None:
        top_p: float | None = args.top_p
    elif "top_p" in cfg:
        top_p = float(cfg["top_p"])
    else:
        top_p = None  # let server use its default
    if args.reasoning_effort is not None:
        reasoning_effort = args.reasoning_effort
    elif "reasoning_effort" in cfg:
        reasoning_effort = str(cfg["reasoning_effort"])
    else:
        reasoning_effort = "high" if is_reasoning else None
    if args.top_k is not None:
        top_k: int | None = args.top_k
    elif "top_k" in cfg:
        top_k = int(cfg["top_k"])
    else:
        top_k = None
    if args.min_p is not None:
        min_p: float | None = args.min_p
    elif "min_p" in cfg:
        min_p = float(cfg["min_p"])
    else:
        min_p = None
    if args.enable_thinking is not None:
        enable_thinking: bool | None = args.enable_thinking
    elif "enable_thinking" in cfg:
        enable_thinking = bool(cfg["enable_thinking"])
    else:
        enable_thinking = None
    base_url = f"http://{args.host}:{args.port}"
    logger.info(f"Model: {full_model_id}")
    logger.info(
        f"Settings: temperature={temperature}, max_tokens={max_tokens}, "
        + (f"top_p={top_p}, " if top_p is not None else "")
        + (f"top_k={top_k}, " if top_k is not None else "")
        + (f"min_p={min_p}, " if min_p is not None else "")
        + f"reasoning={'yes' if is_reasoning else 'no'}"
        + (f", reasoning_effort={reasoning_effort}" if reasoning_effort else "")
        + (
            f", enable_thinking={enable_thinking}"
            if enable_thinking is not None
            else ""
        )
    )
    # Common kwargs for evaluate_benchmark
    eval_kwargs: dict[str, Any] = {
        "reasoning_effort": reasoning_effort,
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "enable_thinking": enable_thinking,
        "difficulty": args.difficulty,
        "offset": args.offset,
        "release_version": args.release_version,
    }
    try:
        if args.compare_concurrency:
            # Comparison mode: run each task at every requested concurrency
            # level, then print a quality/throughput comparison table.
            concurrency_levels = parse_int_list(args.compare_concurrency)
            for task_name in task_names:
                results_by_c: dict[int, list[QuestionResult]] = {}
                for c in concurrency_levels:
                    logger.info(f"\n{'=' * 50}")
                    logger.info(f"Running {task_name} at concurrency={c}")
                    checkpoint_path = _checkpoint_path(
                        args.results_dir, task_name, full_model_id, c
                    )
                    if args.force and checkpoint_path.exists():
                        checkpoint_path.unlink()
                    results = asyncio.run(
                        evaluate_benchmark(
                            task_name,
                            base_url,
                            full_model_id,
                            temperature,
                            max_tokens,
                            concurrency=c,
                            limit=args.limit,
                            timeout=args.request_timeout,
                            checkpoint_path=checkpoint_path,
                            **eval_kwargs,
                        )
                    )
                    if results:
                        scores = print_results(task_name, results, concurrency=c)
                        save_results(
                            args.results_dir,
                            task_name,
                            full_model_id,
                            c,
                            results,
                            scores,
                            cluster=cluster_snapshot,
                        )
                        results_by_c[c] = results
                        # Clean up checkpoint on success
                        if checkpoint_path.exists():
                            checkpoint_path.unlink()
                if len(results_by_c) >= 2:
                    print_comparison(task_name, results_by_c)
        else:
            # Single-concurrency mode: one run per task at --num-concurrent.
            for task_name in task_names:
                checkpoint_path = _checkpoint_path(
                    args.results_dir, task_name, full_model_id, args.num_concurrent
                )
                if args.force and checkpoint_path.exists():
                    checkpoint_path.unlink()
                results = asyncio.run(
                    evaluate_benchmark(
                        task_name,
                        base_url,
                        full_model_id,
                        temperature,
                        max_tokens,
                        concurrency=args.num_concurrent,
                        limit=args.limit,
                        timeout=args.request_timeout,
                        checkpoint_path=checkpoint_path,
                        **eval_kwargs,
                    )
                )
                if results:
                    scores = print_results(task_name, results)
                    save_results(
                        args.results_dir,
                        task_name,
                        full_model_id,
                        args.num_concurrent,
                        results,
                        scores,
                        cluster=cluster_snapshot,
                    )
                    # Clean up checkpoint on success
                    if checkpoint_path.exists():
                        checkpoint_path.unlink()
    finally:
        # Tear down only instances we created; pre-existing ones are left alone.
        if created_instance and instance_id is not None:
            if args.keep_instance:
                logger.info(f"Keeping instance {instance_id} (--keep-instance)")
            else:
                try:
                    client.request_json("DELETE", f"/instance/{instance_id}")
                except ExoHttpError as e:
                    # 404 means the instance is already gone — not an error.
                    if e.status != 404:
                        raise
                try:
                    wait_for_instance_gone(client, instance_id)
                except TimeoutError:
                    logger.warning(
                        f"Timed out waiting for instance {instance_id} to be deleted"
                    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())