mirror of
https://github.com/exo-explore/exo.git
synced 2026-04-17 20:40:35 -04:00
<img width="3224" height="1476" alt="image" src="https://github.com/user-attachments/assets/d90a7d8a-9fe5-43a1-a715-1ef7ecc15422" />
1383 lines
46 KiB
Python
1383 lines
46 KiB
Python
# type: ignore
|
|
#!/usr/bin/env python3
|
|
"""Quality evaluation for exo — matches Artificial Analysis methodology.

Runs LLM benchmarks against exo's OpenAI-compatible API using the same
prompts, temperature settings, and answer extraction as Artificial Analysis.

Supported benchmarks:
    gpqa_diamond  - Graduate-level science QA (198 questions, 4-choice MC)
    mmlu_pro      - Multi-task language understanding (12K questions, 10-choice MC)
    aime_2024     - Math olympiad 2024 (30 problems, integer answers)
    aime_2025     - Math olympiad 2025 (30 problems, integer answers)
    humaneval     - Python code generation (164 problems, pass@1)
    livecodebench - Competitive programming (880+ problems, pass@1)

Model configs in eval_configs/models.toml auto-detect reasoning/non-reasoning
settings per model. Override with --reasoning / --no-reasoning.

Usage:
    uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond
    uv run python exo_eval.py --model <model-id> --tasks humaneval,livecodebench --limit 50
    uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond --compare-concurrency 1,4

References:
    https://artificialanalysis.ai/methodology/intelligence-benchmarking
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import contextlib
|
|
import json
|
|
import multiprocessing
|
|
import random
|
|
import re
|
|
import sys
|
|
import time
|
|
import tomllib
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from harness import (
|
|
ExoClient,
|
|
ExoHttpError,
|
|
add_common_instance_args,
|
|
capture_cluster_snapshot,
|
|
instance_id_from_instance,
|
|
nodes_used_in_instance,
|
|
resolve_model_short_id,
|
|
run_planning_phase,
|
|
settle_and_fetch_placements,
|
|
wait_for_instance_gone,
|
|
wait_for_instance_ready,
|
|
)
|
|
from loguru import logger
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Artificial Analysis constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Maximum number of attempts per API request before a question is recorded
# as an API failure.
MAX_RETRIES = 30

# Completion-token budgets for non-reasoning and reasoning runs.
DEFAULT_MAX_TOKENS = 16_384
REASONING_MAX_TOKENS = 131_072

# Sampling temperatures for non-reasoning and reasoning runs.
TEMPERATURE_NON_REASONING = 0.0
TEMPERATURE_REASONING = 1.0

# MC answer extraction: 8 fallback regex patterns.
# All patterns are tried; the match at the latest text position wins
# (handles models that self-correct during reasoning).
# Group 1 of every pattern is the candidate answer letter.
_MC_PATTERNS: list[re.Pattern[str]] = [
    # "Answer: X", optionally wrapped in markdown emphasis (** or __).
    re.compile(
        r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])"
    ),
    # A capital letter somewhere inside \boxed{...}.
    re.compile(r"\\boxed\{[^}]*([A-Z])[^}]*\}"),
    # "answer is X"
    re.compile(r"(?i)answer is ([a-zA-Z])"),
    # "answer is \(X" (LaTeX inline math opener)
    re.compile(r"(?i)answer is \\\(([a-zA-Z])"),
    # "X) ..." with no further capital letters up to the end of the text.
    re.compile(r"([A-Z])\)\s*[^A-Z]*$"),
    # "X is the correct answer"
    re.compile(r"([A-Z])\s+is\s+the\s+correct\s+answer"),
    # A capital letter at the very end of the text.
    re.compile(r"([A-Z])\s*$"),
    # A capital letter followed by a period.
    re.compile(r"([A-Z])\s*\."),
]

# Code extraction: last ```python ... ``` block (AA regex)
_CODE_BLOCK_RE = re.compile(r"```(?:python|Python)?\s*\n(.*?)```", re.DOTALL)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model config loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def load_model_config(model_id: str) -> dict[str, Any] | None:
|
|
"""Look up model in eval_configs/models.toml. Returns config dict or None."""
|
|
config_path = Path(__file__).resolve().parent / "eval_configs" / "models.toml"
|
|
if not config_path.exists():
|
|
return None
|
|
with open(config_path, "rb") as f:
|
|
data = tomllib.load(f)
|
|
for entry in data.get("model", []):
|
|
patterns = entry.get("patterns", [])
|
|
if any(p in model_id for p in patterns):
|
|
return entry
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Answer extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def extract_mc_answer(text: str, valid_letters: str = "ABCD") -> str | None:
    """Extract MC answer. Last match by text position wins.

    Every pattern in ``_MC_PATTERNS`` is applied to ``text``. Among the
    matches whose captured letter (upper-cased) is one of ``valid_letters``,
    the one whose match starts latest in the text is chosen. On a tie in
    start position, the pattern listed later wins.

    Args:
        text: Model response to scan.
        valid_letters: Letters accepted as answers, e.g. ``"ABCD"``.

    Returns:
        The chosen upper-case letter, or ``None`` if nothing valid matched.
    """
    valid_set = set(valid_letters)
    best: tuple[int, str] | None = None
    for pattern in _MC_PATTERNS:
        for m in pattern.finditer(text):
            letter = m.group(1).upper()
            if letter not in valid_set:
                continue
            # Position of the whole match, not of the captured letter.
            pos = m.start()
            if best is None or pos >= best[0]:
                best = (pos, letter)
    return best[1] if best else None
|
|
|
|
|
|
def extract_boxed_answer(text: str) -> str | None:
    r"""Extract content from the last \boxed{...}.

    Braces are matched by depth so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole. A ``\boxed{`` with no
    matching close brace (e.g. truncated output) is skipped and scanning
    resumes right after its opener, so a later well-formed box is still found.

    Args:
        text: Model response to scan.

    Returns:
        The stripped content of the last balanced ``\boxed{...}``, or
        ``None`` if there is none.
    """
    opener = "\\boxed{"
    matches: list[str] = []
    idx = 0
    while True:
        pos = text.find(opener, idx)
        if pos < 0:
            break
        start = pos + len(opener)
        depth = 0
        end = -1
        for i in range(start, len(text)):
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                if depth == 0:
                    end = i
                    break
                depth -= 1
        if end >= 0:
            matches.append(text[start:end])
            idx = end + 1
        else:
            # Unterminated box: keep looking inside/after it instead of
            # giving up on the rest of the text.
            idx = start
    return matches[-1].strip() if matches else None
|
|
|
|
|
|
def extract_code_block(text: str, preserve_indent: bool = False) -> str | None:
    """Extract the last Python code block from markdown response.

    If preserve_indent is True, only strip trailing whitespace (keeps leading
    indentation intact — needed for HumanEval function-body completions).

    The primary path uses ``_CODE_BLOCK_RE``. If it finds nothing, the lines
    between the last two lines containing a triple backtick are used instead
    (covers fences with an unexpected language tag). Both paths apply the
    same whitespace normalisation.

    Args:
        text: Model response in markdown.
        preserve_indent: Keep leading whitespace of the extracted code.

    Returns:
        The extracted code, or ``None`` when no fenced block can be located.
    """
    matches = _CODE_BLOCK_RE.findall(text)
    if matches:
        raw = matches[-1]
    else:
        # Fallback: try raw code between the last two fence lines.
        lines = text.split("\n")
        fence_lines = [i for i, line in enumerate(lines) if "```" in line]
        if len(fence_lines) < 2:
            return None
        raw = "\n".join(lines[fence_lines[-2] + 1 : fence_lines[-1]])
    return raw.rstrip() if preserve_indent else raw.strip()
|
|
|
|
|
|
def check_aime_answer(extracted: str, gold: int) -> bool:
    """Check if extracted AIME answer matches gold integer.

    A plain integer string is compared directly. Anything else is handed to
    the optional ``math_verify`` package; if that package is missing or
    raises, the answer counts as wrong.

    Args:
        extracted: Answer text pulled from the model response.
        gold: Reference integer answer.

    Returns:
        ``True`` when the extracted answer is judged equal to ``gold``.
    """
    try:
        return int(extracted.strip()) == gold
    except ValueError:
        # Not a bare integer (e.g. LaTeX) — fall through to math_verify.
        pass
    try:
        from math_verify import parse, verify

        return verify(parse(str(gold)), parse(extracted))
    except Exception:
        # Best-effort: import errors and parser failures both mean "no match".
        return False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Code execution — official evaluation harnesses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# LiveCodeBench: vendored from https://github.com/LiveCodeBench/LiveCodeBench
|
|
# run_test() must execute in a child process because reliability_guard()
|
|
# permanently disables OS functions (os.kill, subprocess.Popen, etc.).
|
|
|
|
|
|
def _lcb_worker(
    sample: dict,
    code: str,
    timeout: int,
    result_holder: list[Any],
    metadata_holder: list[Any],
) -> None:
    """Target for multiprocessing.Process — runs vendored LCB run_test.

    Appends exactly one entry to each holder: the per-test results and the
    metadata dict returned by ``run_test``. On any exception (including
    failure to import the vendored harness) it appends ``[-4]`` and a
    metadata dict with ``error_code`` -4 and the exception message instead.

    Args:
        sample: LCB sample dict passed through to ``run_test``.
        code: Candidate solution source.
        timeout: Per-test timeout forwarded to ``run_test``.
        result_holder: List-like sink for the results.
        metadata_holder: List-like sink for the metadata.
    """
    try:
        # Imported inside the try so an import failure is reported through
        # the holders rather than killing the child silently.
        from vendor.lcb_testing_util import run_test

        results, metadata = run_test(sample, test=code, debug=False, timeout=timeout)
        result_holder.append(results)
        metadata_holder.append(metadata)
    except Exception as e:
        result_holder.append([-4])
        metadata_holder.append({"error_code": -4, "error_message": str(e)})
|
|
|
|
|
|
def run_livecodebench_test(
    code: str,
    sample: dict,
    timeout: int = 6,
) -> tuple[bool, str]:
    """Run LCB evaluation in a subprocess. Returns (passed, diagnostic_info).

    The candidate runs in a child process via ``_lcb_worker``. The child is
    killed if it outlives a global timeout derived from the number of tests.

    Args:
        code: Candidate solution source.
        sample: Dict with an ``"input_output"`` JSON string; its ``"inputs"``
            list determines the number of tests.
        timeout: Per-test timeout in seconds.

    Returns:
        ``(True, "")`` when every result is ``True``/1, otherwise
        ``(False, diagnostic)``.
    """
    # Parse before spawning so malformed sample data cannot orphan the child.
    num_tests = len(json.loads(sample["input_output"]).get("inputs", []))
    # Global timeout: (per-test timeout + 1) * num_tests + 5
    global_timeout = (timeout + 1) * num_tests + 5

    # Manager() starts a helper server process; the context manager shuts it
    # down on exit so repeated calls do not accumulate processes.
    with multiprocessing.Manager() as manager:
        result_holder = manager.list()
        metadata_holder = manager.list()

        proc = multiprocessing.Process(
            target=_lcb_worker,
            args=(sample, code, timeout, result_holder, metadata_holder),
        )
        proc.start()
        proc.join(timeout=global_timeout)

        if proc.is_alive():
            proc.kill()
            proc.join()
            return False, "Global timeout exceeded"

        if not result_holder:
            return False, "No results returned from worker"

        # Copy out of the proxies while the manager is still alive.
        results = list(result_holder[0])
        metadata = dict(metadata_holder[0]) if metadata_holder else {}

    # LCB convention: True = pass, negative int = failure code
    all_passed = all(r is True or r == 1 for r in results)
    if all_passed:
        return True, ""

    diag = metadata.get("error_message", "")
    if not diag and "output" in metadata:
        diag = f"Got {metadata['output']}, expected {metadata.get('expected', '?')}"
    return False, diag
|
|
|
|
|
|
def run_humaneval_test(
    problem: dict, completion: str, timeout: float = 10.0
) -> tuple[bool, str]:
    """Run HumanEval evaluation using the official human_eval package.

    Args:
        problem: Problem dict passed to ``check_correctness``.
        completion: Model-generated completion to test.
        timeout: Timeout forwarded to ``check_correctness``.

    Returns:
        ``(passed, diagnostic)``. The diagnostic is empty on success,
        otherwise the harness's ``"result"`` string (``"failed"`` if absent).
    """
    # Imported lazily: the package is only needed for this benchmark.
    from human_eval.execution import check_correctness

    result = check_correctness(problem, completion, timeout)
    passed = result["passed"]
    diag = "" if passed else result.get("result", "failed")
    return passed, diag
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Benchmark definitions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class QuestionResult:
    """Outcome of evaluating a single benchmark question."""

    # Index of the question within the evaluated dataset slice.
    question_id: int
    # Prompt sent to the model.
    prompt: str
    # Raw model response ("" when the API call failed).
    response: str
    # Answer extracted from the response; "pass"/"fail" for code benchmarks.
    extracted_answer: str | None
    # Reference answer as a string ("pass" for code benchmarks).
    gold_answer: str
    # Whether the question counts as answered correctly.
    correct: bool
    # Failure description (API failure, extraction failure, test diagnostic).
    error: str | None = None
    # Token usage as reported by the API; 0 when not reported.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    reasoning_tokens: int = 0
    # Time spent on the API call including retries, in seconds.
    elapsed_s: float = 0.0
|
|
|
|
|
|
@dataclass
class BenchmarkConfig:
    """Static description of one supported benchmark."""

    # Benchmark key, as used with --tasks.
    name: str
    # Human-readable one-line summary.
    description: str
    # Hugging Face dataset identifier.
    dataset_name: str
    # Dataset configuration name, or None for the default.
    dataset_config: str | None
    # Dataset split to load.
    split: str
    kind: str  # "mc", "math", "code"
|
|
|
|
|
|
# Registry of supported benchmarks, keyed by the name used on the CLI.
BENCHMARKS: dict[str, BenchmarkConfig] = {
    "gpqa_diamond": BenchmarkConfig(
        name="gpqa_diamond",
        description="Graduate-level science QA (198 Q, 4-choice MC)",
        dataset_name="Idavidrein/gpqa",
        dataset_config="gpqa_diamond",
        split="train",
        kind="mc",
    ),
    "mmlu_pro": BenchmarkConfig(
        name="mmlu_pro",
        description="Multi-task language understanding (12K Q, 10-choice MC)",
        dataset_name="TIGER-Lab/MMLU-Pro",
        dataset_config=None,
        split="test",
        kind="mc",
    ),
    "aime_2024": BenchmarkConfig(
        name="aime_2024",
        description="Math olympiad 2024 (30 problems, integer answers)",
        dataset_name="HuggingFaceH4/aime_2024",
        dataset_config=None,
        split="train",
        kind="math",
    ),
    "aime_2025": BenchmarkConfig(
        name="aime_2025",
        description="Math olympiad 2025 (30 problems, integer answers)",
        dataset_name="MathArena/aime_2025",
        dataset_config=None,
        split="train",
        kind="math",
    ),
    "humaneval": BenchmarkConfig(
        name="humaneval",
        description="Python code generation (164 problems, pass@1)",
        dataset_name="openai/openai_humaneval",
        dataset_config=None,
        split="test",
        kind="code",
    ),
    "livecodebench": BenchmarkConfig(
        name="livecodebench",
        description="Competitive programming (880+ problems, pass@1)",
        dataset_name="livecodebench/code_generation_lite",
        dataset_config=None,
        split="test",
        kind="code",
    ),
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt formatters
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Instruction prefix for GPQA (4 answer choices).
_GPQA_INSTRUCTION = (
    "Answer the following multiple choice question. "
    "The last line of your response should be in the following format: "
    "'Answer: A/B/C/D' (e.g. 'Answer: A')."
)

# Instruction prefix for MMLU-Pro (up to 10 answer choices).
_MMLU_PRO_INSTRUCTION = (
    "Answer the following multiple choice question. "
    "The last line of your response should be in the following format: "
    "'Answer: A/B/C/D/E/F/G/H/I/J' (e.g. 'Answer: A')."
)

# Instruction prefix for AIME; answers are later read from \boxed{...}.
_AIME_INSTRUCTION = (
    "Solve the following math problem step by step. "
    "Put your answer inside \\boxed{}.\n"
    "Remember to put your answer inside \\boxed{}."
)

# Instruction prefix for HumanEval function-body completion.
_HUMANEVAL_INSTRUCTION = (
    "Complete the following Python function. Return only the function body "
    "inside a ```python code block. Do not include the function signature."
)

# LiveCodeBench: AA uses original prompts without custom system prompts
_LCB_SYSTEM = (
    "You are an expert Python programmer. You will be given a question "
    "(problem specification) and will generate a correct Python program "
    "that matches the specification and passes all tests."
)

# User-message template for problems that ship starter code.
# Placeholders: {question}, {starter_code}.
_LCB_WITH_STARTER = (
    "### Question:\n{question}\n\n"
    "### Format: You will use the following starter code to write the "
    "solution to the problem and enclose your code within delimiters.\n"
    "```python\n{starter_code}\n```\n\n"
    "### Answer: (use the provided format with backticks)\n"
)

# User-message template for stdin/stdout problems. Placeholder: {question}.
_LCB_WITHOUT_STARTER = (
    "### Question:\n{question}\n\n"
    "### Format: Read the inputs from stdin solve the problem and write "
    "the answer to stdout (do not directly test on the sample inputs). "
    "Enclose your code within delimiters as follows. Ensure that when the "
    "python program runs, it reads the inputs, runs the algorithm and "
    "writes output to STDOUT.\n"
    "```python\n# YOUR CODE HERE\n```\n\n"
    "### Answer: (use the provided format with backticks)\n"
)
|
|
|
|
|
|
def format_gpqa_question(doc: dict, idx: int) -> tuple[str, str]:
    """Returns (prompt, correct_letter).

    The four answer choices are shuffled with a PRNG seeded by ``idx``, so
    the same question index always gets the same ordering.

    Args:
        doc: Dataset row with ``"Question"``, ``"Correct Answer"`` and
            ``"Incorrect Answer 1"``..``"Incorrect Answer 3"``.
        idx: Question index, used as the shuffle seed.
    """
    correct = doc["Correct Answer"]
    choices = [
        correct,
        doc["Incorrect Answer 1"],
        doc["Incorrect Answer 2"],
        doc["Incorrect Answer 3"],
    ]
    rng = random.Random(idx)
    order = rng.sample(range(4), 4)
    shuffled = [choices[i] for i in order]
    # Index 0 is the correct choice; find where the shuffle placed it.
    correct_letter = "ABCD"[order.index(0)]
    choices_text = "\n".join(
        f"{letter}) {choice}" for letter, choice in zip("ABCD", shuffled)
    )
    return f"{_GPQA_INSTRUCTION}\n\n{doc['Question']}\n\n{choices_text}", correct_letter
|
|
|
|
|
|
def format_mmlu_pro_question(doc: dict) -> tuple[str, str]:
    """Returns (prompt, correct_letter).

    Options are labelled A, B, C, ... in dataset order (no shuffling).

    Args:
        doc: Dataset row with ``"question"``, ``"options"`` and ``"answer"``;
            the ``"answer"`` value is returned as the correct letter.
    """
    options = doc["options"]
    letters = "ABCDEFGHIJ"
    choices_text = "\n".join(f"{letters[i]}) {opt}" for i, opt in enumerate(options))
    prompt = f"{_MMLU_PRO_INSTRUCTION}\n\n{doc['question']}\n\n{choices_text}"
    return prompt, doc["answer"]
|
|
|
|
|
|
def format_aime_question(doc: dict) -> tuple[str, int]:
    """Returns (prompt, correct_answer_int).

    Args:
        doc: Dataset row with ``"problem"`` and ``"answer"``; the answer is
            converted with ``int()``.
    """
    return f"{_AIME_INSTRUCTION}\n\n{doc['problem']}", int(doc["answer"])
|
|
|
|
|
|
def format_humaneval_question(doc: dict) -> tuple[str, dict]:
    """Returns (prompt, metadata_for_execution).

    Args:
        doc: Dataset row with ``"task_id"``, ``"prompt"``, ``"test"`` and
            ``"entry_point"``.

    Returns:
        The user prompt and ``{"problem": {...}}`` holding the fields the
        execution harness needs.
    """
    prompt = f"{_HUMANEVAL_INSTRUCTION}\n\n```python\n{doc['prompt']}```"
    # Pass the full problem dict — check_correctness needs task_id, prompt,
    # test, entry_point
    meta = {
        "problem": {
            "task_id": doc["task_id"],
            "prompt": doc["prompt"],
            "test": doc["test"],
            "entry_point": doc["entry_point"],
        },
    }
    return prompt, meta
|
|
|
|
|
|
def format_livecodebench_question(doc: dict) -> tuple[str, str | None, dict]:
    """Returns (prompt, system_message, metadata_for_execution).

    Builds the user prompt (with or without starter code), merges public and
    private test cases, and packs them into the ``input_output`` JSON string
    consumed by the LCB ``run_test`` harness.

    Args:
        doc: Dataset row with ``"question_content"`` and
            ``"public_test_cases"``; ``"starter_code"``,
            ``"private_test_cases"`` and ``"metadata"`` are optional.

    Returns:
        ``(user_message, system_message, {"sample": {"input_output": json}})``.
    """
    starter_code = doc.get("starter_code", "")
    question_content = doc["question_content"]

    if starter_code and starter_code.strip():
        user_msg = _LCB_WITH_STARTER.format(
            question=question_content, starter_code=starter_code
        )
    else:
        user_msg = _LCB_WITHOUT_STARTER.format(question=question_content)

    # Parse test cases
    public_tests = doc["public_test_cases"]
    if isinstance(public_tests, str):
        public_tests = json.loads(public_tests)

    private_tests = doc.get("private_test_cases", "[]")
    if isinstance(private_tests, str):
        try:
            private_tests = json.loads(private_tests)
        except ValueError:
            # Not plain JSON: treat as base64 -> zlib -> pickle -> JSON text.
            import base64
            import pickle
            import zlib

            # SECURITY NOTE(review): pickle.loads can execute arbitrary code.
            # This data comes from the downloaded dataset; only run against
            # dataset sources you trust.
            private_tests = json.loads(
                pickle.loads(
                    zlib.decompress(base64.b64decode(private_tests.encode("utf-8")))
                )
            )

    all_tests = public_tests + (
        private_tests if isinstance(private_tests, list) else []
    )
    test_inputs = [t["input"] for t in all_tests]
    test_outputs = [t["output"] for t in all_tests]

    metadata = doc.get("metadata", "{}")
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    func_name = metadata.get("func_name")

    # Build the sample dict in official LCB format for run_test()
    input_output: dict[str, Any] = {
        "inputs": test_inputs,
        "outputs": test_outputs,
    }
    if func_name:
        input_output["fn_name"] = func_name

    meta = {
        "sample": {"input_output": json.dumps(input_output)},
    }
    return user_msg, _LCB_SYSTEM, meta
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# API client with retries
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class ApiResult:
    """Successful chat-completion response, reduced to what the eval needs."""

    # Text of the first choice's message.
    content: str
    # Token usage as reported by the API.
    prompt_tokens: int
    completion_tokens: int
    reasoning_tokens: int
|
|
|
|
|
|
async def _call_api(
    client: httpx.AsyncClient,
    base_url: str,
    model: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float | None,
    system_message: str | None = None,
    reasoning_effort: str | None = None,
    top_p: float | None = None,
) -> ApiResult:
    """Send one chat-completion request and return the parsed result.

    Args:
        client: HTTP client used for the POST.
        base_url: Server base URL; ``/v1/chat/completions`` is appended.
        model: Model identifier placed in the request body.
        prompt: User message content.
        temperature: Sampling temperature.
        max_tokens: Completion token limit.
        timeout: Request timeout passed to the client.
        system_message: Optional system message sent before the user message.
        reasoning_effort: Included in the request body only when not None.
        top_p: Included in the request body only when not None.

    Returns:
        The response content and token usage.

    Raises:
        httpx.HTTPStatusError: On a non-success HTTP status.
        ValueError: If the returned content is empty or whitespace.
    """
    messages: list[dict[str, str]] = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": prompt})

    body: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if reasoning_effort is not None:
        body["reasoning_effort"] = reasoning_effort
    if top_p is not None:
        body["top_p"] = top_p

    resp = await client.post(
        f"{base_url}/v1/chat/completions",
        json=body,
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    content = data["choices"][0]["message"]["content"]
    if not content or not content.strip():
        raise ValueError("Empty response from model")

    # "usage" and its fields may be absent or null; treat both as zero.
    usage = data.get("usage") or {}
    details = usage.get("completion_tokens_details") or {}
    return ApiResult(
        content=content,
        prompt_tokens=usage.get("prompt_tokens") or 0,
        completion_tokens=usage.get("completion_tokens") or 0,
        reasoning_tokens=details.get("reasoning_tokens") or 0,
    )
|
|
|
|
|
|
async def call_with_retries(
    client: httpx.AsyncClient,
    base_url: str,
    model: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float | None = None,
    system_message: str | None = None,
    reasoning_effort: str | None = None,
    top_p: float | None = None,
) -> ApiResult | None:
    """Call ``_call_api`` up to ``MAX_RETRIES`` times with exponential backoff.

    Any exception from an attempt triggers a retry; the wait between
    attempts is ``2**attempt`` seconds, capped at 60.

    Returns:
        The first successful result, or ``None`` once every attempt failed.
    """
    for attempt in range(MAX_RETRIES):
        try:
            return await _call_api(
                client,
                base_url,
                model,
                prompt,
                temperature,
                max_tokens,
                timeout,
                system_message,
                reasoning_effort,
                top_p,
            )
        except Exception as e:
            if attempt < MAX_RETRIES - 1:
                wait = min(2**attempt, 60)
                logger.warning(
                    f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}. Retrying in {wait}s..."
                )
                await asyncio.sleep(wait)
            else:
                logger.error(f"All {MAX_RETRIES} retries exhausted. Last error: {e}")
    return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Evaluation runners
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def evaluate_benchmark(
    benchmark_name: str,
    base_url: str,
    model: str,
    temperature: float,
    max_tokens: int,
    concurrency: int = 1,
    limit: int | None = None,
    timeout: float | None = None,
    reasoning_effort: str | None = None,
    top_p: float | None = None,
    difficulty: str | None = None,
) -> list[QuestionResult]:
    """Run a benchmark. Returns per-question results.

    Loads the dataset, optionally filters by difficulty and truncates to
    ``limit``, then queries the model for every question with at most
    ``concurrency`` requests in flight and scores each response.

    Args:
        benchmark_name: Key into ``BENCHMARKS``.
        base_url: Base URL of the OpenAI-compatible API.
        model: Model identifier sent with each request.
        temperature: Sampling temperature.
        max_tokens: Completion token limit per request.
        concurrency: Maximum number of simultaneous API requests.
        limit: Evaluate only the first ``limit`` questions when set.
        timeout: Per-request timeout forwarded to the HTTP client.
        reasoning_effort: Forwarded to the API when not None.
        top_p: Forwarded to the API when not None.
        difficulty: Keep only rows whose ``"difficulty"`` column equals this
            value (ignored when the dataset has no such column).

    Returns:
        One ``QuestionResult`` per question in dataset order, or an empty
        list if the dataset could not be loaded.
    """
    import datasets

    config = BENCHMARKS[benchmark_name]
    logger.info(f"Loading dataset {config.dataset_name}...")

    try:
        if benchmark_name == "livecodebench":
            # Loaded from the raw JSONL files rather than via config.split.
            ds = datasets.load_dataset(
                "json",
                data_files="hf://datasets/livecodebench/code_generation_lite/*.jsonl",
                split="train",
            )
        else:
            ds = datasets.load_dataset(
                config.dataset_name,
                config.dataset_config,
                split=config.split,
            )
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        if "gated" in str(e).lower() or "login" in str(e).lower():
            logger.error("Dataset requires authentication. Run: huggingface-cli login")
        return []

    if difficulty and "difficulty" in ds.column_names:
        ds = ds.filter(lambda x: x["difficulty"] == difficulty)
        logger.info(f"Filtered to {len(ds)} {difficulty} problems")

    total = len(ds)
    if limit and limit < total:
        ds = ds.select(range(limit))
        total = limit

    logger.info(
        f"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, "
        f"temperature={temperature}, max_tokens={max_tokens}"
    )

    if config.kind == "code":
        logger.warning(
            "Code benchmarks execute model-generated code. Use a sandboxed environment."
        )

    semaphore = asyncio.Semaphore(concurrency)
    results: list[QuestionResult | None] = [None] * total
    completed = 0
    lock = asyncio.Lock()

    async def process_question(
        idx: int, doc: dict, http_client: httpx.AsyncClient
    ) -> None:
        """Query the model for one question, score it, and store the result."""
        nonlocal completed
        system_msg = None

        if benchmark_name == "gpqa_diamond":
            prompt, gold = format_gpqa_question(doc, idx)
            valid_letters = "ABCD"
        elif benchmark_name == "mmlu_pro":
            prompt, gold = format_mmlu_pro_question(doc)
            valid_letters = "ABCDEFGHIJ"[: len(doc["options"])]
        elif benchmark_name.startswith("aime"):
            prompt, gold_int = format_aime_question(doc)
            gold = str(gold_int)
        elif benchmark_name == "humaneval":
            prompt, exec_meta = format_humaneval_question(doc)
            gold = "pass"
        elif benchmark_name == "livecodebench":
            prompt, system_msg, exec_meta = format_livecodebench_question(doc)
            gold = "pass"
        else:
            raise ValueError(f"Unknown benchmark: {benchmark_name}")

        # Only the API call is rate-limited. elapsed covers retries too.
        async with semaphore:
            t0 = time.monotonic()
            api_result = await call_with_retries(
                http_client,
                base_url,
                model,
                prompt,
                temperature,
                max_tokens,
                timeout,
                system_message=system_msg,
                reasoning_effort=reasoning_effort,
                top_p=top_p,
            )
            elapsed = time.monotonic() - t0

        def scored(
            extracted: str | None, correct: bool, error: str | None = None
        ) -> QuestionResult:
            """Build the result for a response that was received from the API."""
            return QuestionResult(
                question_id=idx,
                prompt=prompt,
                response=api_result.content,
                extracted_answer=extracted,
                gold_answer=gold,
                correct=correct,
                error=error,
                prompt_tokens=api_result.prompt_tokens,
                completion_tokens=api_result.completion_tokens,
                reasoning_tokens=api_result.reasoning_tokens,
                elapsed_s=elapsed,
            )

        if api_result is None:
            result = QuestionResult(
                question_id=idx,
                prompt=prompt,
                response="",
                extracted_answer=None,
                gold_answer=gold,
                correct=False,
                error="API failure after retries",
                elapsed_s=elapsed,
            )
        elif config.kind == "mc":
            extracted = extract_mc_answer(api_result.content, valid_letters)
            result = scored(extracted, (extracted == gold) if extracted else False)
        elif config.kind == "math":
            extracted = extract_boxed_answer(api_result.content)
            correct = check_aime_answer(extracted, int(gold)) if extracted else False
            result = scored(extracted, correct)
        elif config.kind == "code":
            # HumanEval needs preserved indentation (function body completion)
            keep_indent = benchmark_name == "humaneval"
            code = extract_code_block(api_result.content, preserve_indent=keep_indent)
            if code is None:
                result = scored(None, False, "No code block extracted")
            elif benchmark_name == "humaneval":
                # NOTE: runs synchronously; other coroutines do not progress
                # while the tests execute.
                passed, diag = run_humaneval_test(exec_meta["problem"], code)
                result = scored(
                    "pass" if passed else "fail", passed, diag if not passed else None
                )
            elif benchmark_name == "livecodebench":
                passed, diag = run_livecodebench_test(code, exec_meta["sample"])
                result = scored(
                    "pass" if passed else "fail", passed, diag if not passed else None
                )
            else:
                result = scored(None, False, "Unknown code benchmark")
        else:
            result = scored(None, False, "Unsupported kind")

        results[idx] = result

        async with lock:
            completed += 1
            n = completed

        # Log roughly every 5% of questions and at the very end.
        if n % max(1, total // 20) == 0 or n == total:
            correct_so_far = sum(1 for r in results if r is not None and r.correct)
            answered = sum(1 for r in results if r is not None)
            logger.info(
                f" [{n}/{total}] {correct_so_far}/{answered} correct "
                f"({correct_so_far / max(answered, 1):.1%})"
            )

    async with httpx.AsyncClient() as http_client:
        tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)]
        await asyncio.gather(*tasks)

    return [r for r in results if r is not None]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Results display
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def print_results(
    benchmark_name: str,
    results: list[QuestionResult],
    concurrency: int | None = None,
) -> dict[str, Any]:
    """Print a summary of one benchmark run and return it as a dict.

    Args:
        benchmark_name: Name shown in the summary and stored in the result.
        results: Per-question results to aggregate.
        concurrency: When given, prefixes the summary line with ``[c=N]``.

    Returns:
        Aggregate scores: accuracy, counts, token totals and timing.
    """
    total = len(results)
    correct = sum(r.correct for r in results)
    # NOTE(review): counts every result with a non-empty ``error``, which
    # includes extraction and code-test failures, not only API failures,
    # although the printed label below says "API errors".
    errors = sum(1 for r in results if r.error)
    no_extract = sum(1 for r in results if r.extracted_answer is None and not r.error)
    accuracy = correct / max(total, 1)

    total_prompt_tokens = sum(r.prompt_tokens for r in results)
    total_completion_tokens = sum(r.completion_tokens for r in results)
    total_reasoning_tokens = sum(r.reasoning_tokens for r in results)
    total_elapsed = sum(r.elapsed_s for r in results)
    # NOTE(review): this is the longest single-question latency, not the
    # end-to-end wall-clock duration of the run.
    wall_clock = max(r.elapsed_s for r in results) if results else 0.0
    avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0

    label = f"[c={concurrency}] " if concurrency is not None else ""
    print(f"\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})")
    tok_line = f" tokens: {total_prompt_tokens:,} prompt + {total_completion_tokens:,} completion"
    if total_reasoning_tokens > 0:
        tok_line += f" ({total_reasoning_tokens:,} reasoning)"
    tok_line += (
        f" | avg gen tps: {avg_gen_tps:.1f}"
        f" | total time: {total_elapsed:.1f}s wall clock: {wall_clock:.1f}s"
    )
    print(tok_line)
    if errors:
        print(f" API errors: {errors}")
    if no_extract:
        print(f" No answer extracted: {no_extract}")

    return {
        "benchmark": benchmark_name,
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "errors": errors,
        "no_extract": no_extract,
        "total_prompt_tokens": total_prompt_tokens,
        "total_completion_tokens": total_completion_tokens,
        "total_reasoning_tokens": total_reasoning_tokens,
        "total_elapsed_s": total_elapsed,
        "wall_clock_s": wall_clock,
        "avg_gen_tps": avg_gen_tps,
    }
|
|
|
|
|
|
def print_comparison(
    benchmark_name: str,
    results_by_c: dict[int, list[QuestionResult]],
) -> None:
    """Print a table comparing runs of one benchmark at several concurrencies.

    After the table, the lowest and highest concurrency levels are compared
    question by question (paired on ``question_id``) and the number of
    questions whose correctness differs is reported. Questions present in
    only one of the two runs are left out of that comparison.

    Args:
        benchmark_name: Name shown in the heading.
        results_by_c: Per-question results keyed by concurrency level.
    """
    levels = sorted(results_by_c.keys())
    print(f"\n{'=' * 70}")
    print(f"COMPARISON: {benchmark_name}")
    print(f"{'=' * 70}")

    header = f"{'Concurrency':<15} {'Accuracy':>10} {'Correct':>10} {'Total':>10} {'Comp Tokens':>12} {'Wall Clock':>12} {'Avg Gen TPS':>12}"
    print(header)
    print("-" * len(header))
    for c in levels:
        r = results_by_c[c]
        correct = sum(q.correct for q in r)
        total = len(r)
        comp_tok = sum(q.completion_tokens for q in r)
        total_elapsed = sum(q.elapsed_s for q in r)
        avg_tps = comp_tok / total_elapsed if total_elapsed > 0 else 0.0
        wall = max(q.elapsed_s for q in r) if r else 0.0
        print(
            f"c={c:<13} {correct / max(total, 1):>10.1%} {correct:>10} {total:>10}"
            f" {comp_tok:>12,} {wall:>11.1f}s {avg_tps:>12.1f}"
        )

    if len(levels) >= 2:
        base_results = results_by_c[levels[0]]
        test_results = results_by_c[levels[-1]]
        # Pair by question_id so runs of unequal length do not raise and
        # cannot be misaligned.
        test_by_id = {q.question_id: q for q in test_results}
        pairs = [
            (br, test_by_id[br.question_id])
            for br in base_results
            if br.question_id in test_by_id
        ]
        changed = sum(1 for br, tr in pairs if br.correct != tr.correct)
        total = len(pairs)
        print(
            f"\nQuestions with different correctness (c={levels[0]} vs c={levels[-1]}): {changed}/{total}"
        )
        if changed == 0:
            print("Batching produced identical quality.")
        elif changed <= total * 0.01:
            print("Negligible quality difference from batching.")
        else:
            print(
                f"WARNING: {changed / max(total, 1) * 100:.1f}% of questions changed."
            )
    print()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Interactive task picker
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def pick_tasks_interactive() -> list[str]:
    """Let the user choose benchmarks with a raw-mode terminal checklist.

    Keys: up/down arrows move the cursor, space toggles the current entry,
    enter confirms, ``q`` or Ctrl-C aborts.

    Returns:
        Names of the selected benchmarks in registry order; an empty list if
        stdin is not a terminal, the user aborted, or nothing was selected.
    """
    import termios
    import tty

    if not sys.stdin.isatty():
        logger.error("No --tasks specified and stdin is not a terminal.")
        return []

    items = [(name, cfg.description) for name, cfg in BENCHMARKS.items()]
    selected: list[bool] = [False] * len(items)
    cursor = 0
    # Header line + blank line + one line per item + blank line + counter.
    total_lines = len(items) + 4

    def write(s: str) -> None:
        sys.stdout.write(s)

    def render(first: bool = False) -> None:
        if not first:
            # Move the cursor back up to the top of the previous frame.
            write(f"\033[{total_lines}A")
        write("\033[J")  # clear from the cursor to the end of the screen
        write(
            "\033[1mSelect benchmarks\033[0m (up/down, space toggle, enter confirm, q quit)\r\n\r\n"
        )
        for i, (name, desc) in enumerate(items):
            marker = ">" if i == cursor else " "
            check = "x" if selected[i] else " "
            line = f" {marker} [{check}] {name:<17} {desc}"
            # The current row is drawn in reverse video.
            write(f"\033[7m{line}\033[0m\r\n" if i == cursor else f"{line}\r\n")
        write(f"\r\n {sum(selected)} selected\r\n")
        sys.stdout.flush()

    fd = sys.stdin.fileno()
    old = termios.tcgetattr(fd)
    try:
        tty.setraw(fd)
        write("\033[?25l")  # hide the cursor
        render(first=True)
        while True:
            ch = sys.stdin.read(1)
            if ch in ("q", "\x03"):
                write("\033[?25h\033[0m\r\n")
                return []
            elif ch in ("\r", "\n"):
                break
            elif ch == " ":
                selected[cursor] = not selected[cursor]
            elif ch == "\x1b":
                # Arrow keys arrive as ESC followed by "[A" / "[B".
                seq = sys.stdin.read(2)
                if seq == "[A":
                    cursor = (cursor - 1) % len(items)
                elif seq == "[B":
                    cursor = (cursor + 1) % len(items)
            render()
    finally:
        # Always restore the terminal mode and show the cursor again.
        termios.tcsetattr(fd, termios.TCSADRAIN, old)
        write("\033[?25h\033[0m\r\n")
        sys.stdout.flush()

    chosen = [name for (name, _), sel in zip(items, selected, strict=True) if sel]
    if chosen:
        logger.info(f"Selected: {', '.join(chosen)}")
    return chosen
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Results persistence
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def save_results(
    results_dir: str,
    benchmark_name: str,
    model: str,
    concurrency: int,
    results: list[QuestionResult],
    scores: dict[str, Any],
    cluster: dict[str, Any] | None = None,
) -> Path:
    """Write one benchmark run to a timestamped JSON file.

    The file is created at
    ``<results_dir>/<model with "/" replaced by "_">/<benchmark_name>/c<concurrency>_<YYYYmmdd_HHMMSS>.json``;
    missing directories are created.

    Args:
        results_dir: Root directory for result files.
        benchmark_name: Benchmark identifier; used as a directory name and
            stored in the file.
        model: Model identifier; stored verbatim in the file.
        concurrency: Concurrency level of the run; part of the file name.
        results: Per-question results to serialise.
        scores: Aggregate scores, stored as-is.
        cluster: Optional cluster description; the ``"cluster"`` key is
            written only when this value is truthy.

    Returns:
        Path of the JSON file that was written.
    """
    target_dir = Path(results_dir, model.replace("/", "_"), benchmark_name)
    target_dir.mkdir(parents=True, exist_ok=True)
    stamp = time.strftime("%Y%m%d_%H%M%S")
    path = target_dir / f"c{concurrency}_{stamp}.json"

    def as_record(item: Any) -> dict[str, Any]:
        # One flat JSON object per question; elapsed time is rounded to 2 dp.
        return {
            "question_id": item.question_id,
            "prompt": item.prompt,
            "response": item.response,
            "extracted_answer": item.extracted_answer,
            "gold_answer": item.gold_answer,
            "correct": item.correct,
            "error": item.error,
            "prompt_tokens": item.prompt_tokens,
            "completion_tokens": item.completion_tokens,
            "reasoning_tokens": item.reasoning_tokens,
            "elapsed_s": round(item.elapsed_s, 2),
        }

    # Keys are inserted in the order they should appear in the file.
    payload: dict[str, Any] = {
        "benchmark": benchmark_name,
        "model": model,
        "concurrency": concurrency,
    }
    if cluster:
        payload["cluster"] = cluster
    payload["scores"] = scores
    payload["results"] = [as_record(item) for item in results]

    with path.open("w") as fh:
        json.dump(payload, fh, indent=2)
    logger.info(f"Results saved to {path}")
    return path
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def parse_int_list(values: list[str]) -> list[int]:
    """Flatten comma-separated integer tokens into one list of ints.

    Each string in ``values`` is split on commas; tokens that are empty or
    whitespace-only are skipped and the rest are converted with ``int``.

    Args:
        values: Strings such as ``["1,4", "8"]``.

    Returns:
        The parsed integers in input order.

    Raises:
        ValueError: If a non-blank token is not a valid integer.
    """
    return [
        int(token.strip())
        for chunk in values
        for token in chunk.split(",")
        if token.strip()
    ]
|
|
|
|
|
|
def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Build the exo-eval argument parser (common instance flags plus eval flags)."""
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Quality evaluation for exo — matches Artificial Analysis methodology.",
    )
    add_common_instance_args(ap)

    ap.add_argument(
        "--tasks",
        default=None,
        help="Comma-separated benchmark names. Omit for interactive picker.",
    )
    ap.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Max questions per benchmark (for fast iteration).",
    )

    reasoning_group = ap.add_mutually_exclusive_group()
    reasoning_group.add_argument(
        "--reasoning",
        action="store_true",
        default=None,
        help="Force reasoning-model settings (temperature=0.6, max_tokens=65536).",
    )
    reasoning_group.add_argument(
        "--no-reasoning",
        action="store_true",
        default=False,
        help="Force non-reasoning settings (temperature=0, max_tokens=16384).",
    )

    ap.add_argument(
        "--temperature", type=float, default=None, help="Override temperature."
    )
    ap.add_argument("--top-p", type=float, default=None, help="Override top_p.")
    ap.add_argument(
        "--max-tokens", type=int, default=None, help="Override max output tokens."
    )
    ap.add_argument(
        "--num-concurrent",
        type=int,
        default=1,
        help="Concurrent API requests (default: 1).",
    )
    ap.add_argument(
        "--compare-concurrency",
        nargs="+",
        default=None,
        help="Run at multiple concurrency levels and compare. E.g. --compare-concurrency 1,4",
    )
    ap.add_argument(
        "--request-timeout",
        type=float,
        default=None,
        help="Per-request timeout in seconds (default: no timeout).",
    )
    ap.add_argument(
        "--reasoning-effort",
        default=None,
        choices=["low", "medium", "high"],
        help="Override reasoning effort (default: 'high' for reasoning models, none for non-reasoning).",
    )
    ap.add_argument(
        "--difficulty",
        default=None,
        choices=["easy", "medium", "hard"],
        help="Filter by difficulty (livecodebench only). E.g. --difficulty hard",
    )
    ap.add_argument(
        "--results-dir",
        default="eval_results",
        help="Directory for result JSON files (default: eval_results).",
    )
    ap.add_argument(
        "--skip-instance-setup",
        action="store_true",
        help="Skip exo instance management (assumes model is already running).",
    )
    return ap


def main() -> int:
    """Run the exo-eval command line and return a process exit code.

    Steps, in order:
      1. Parse CLI flags and resolve the benchmarks to run (``--tasks`` or the
         interactive picker).
      2. Validate benchmark names and ``--compare-concurrency`` before any
         cluster work is started.
      3. Unless ``--skip-instance-setup`` is given, pick a placement, run the
         planning phase, create the instance and wait until it is ready.
      4. Resolve sampling settings with priority
         CLI flag > per-model config > global defaults.
      5. Run every benchmark (once, or once per concurrency level), print and
         save the results.
      6. Delete the instance that was created, if any.

    Returns:
        0 when the run completes or no task was chosen; 1 for an unknown
        benchmark name, an invalid ``--compare-concurrency`` value, no
        matching placement, or an instance that fails to become ready.

    Raises:
        ExoHttpError: If deleting the instance during teardown fails with a
            status other than 404.
    """
    ap = _build_eval_arg_parser()
    # parse_known_args: flags this parser does not recognise are ignored
    # rather than rejected.
    args, _ = ap.parse_known_args()

    # Resolve tasks
    if args.tasks:
        task_names = [t.strip() for t in args.tasks.split(",") if t.strip()]
    else:
        task_names = pick_tasks_interactive()
    if not task_names:
        return 0

    for t in task_names:
        if t not in BENCHMARKS:
            logger.error(f"Unknown benchmark '{t}'. Available: {', '.join(BENCHMARKS)}")
            return 1

    # Parse the concurrency sweep before touching the cluster so that a typo
    # fails immediately instead of after an instance has been placed.
    concurrency_levels: list[int] | None = None
    if args.compare_concurrency:
        try:
            concurrency_levels = parse_int_list(args.compare_concurrency)
        except ValueError as e:
            logger.error(f"Invalid --compare-concurrency value: {e}")
            return 1
        if not concurrency_levels:
            logger.error("--compare-concurrency did not contain any concurrency levels.")
            return 1

    # Instance management
    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    instance_id: str | None = None

    if not args.skip_instance_setup:
        _short_id, full_model_id = resolve_model_short_id(
            client,
            args.model,
            force_download=args.force_download,
        )
        selected = settle_and_fetch_placements(
            client,
            full_model_id,
            args,
            settle_timeout=args.settle_timeout,
        )
        if not selected:
            logger.error("No valid placements matched your filters.")
            return 1

        # Descending sort on (instance_meta, sharding, -node count); the first
        # entry after sorting is the placement that gets used.
        selected.sort(
            key=lambda p: (
                str(p.get("instance_meta", "")),
                str(p.get("sharding", "")),
                -nodes_used_in_instance(p["instance"]),
            ),
            reverse=True,
        )
        preview = selected[0]
        instance = preview["instance"]
        instance_id = instance_id_from_instance(instance)

        logger.info(
            f"PLACEMENT: {preview['sharding']} / {preview['instance_meta']} / "
            f"nodes={nodes_used_in_instance(instance)}"
        )

        settle_deadline = (
            time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
        )
        download_duration = run_planning_phase(
            client,
            full_model_id,
            preview,
            args.danger_delete_downloads,
            args.timeout,
            settle_deadline,
        )
        if download_duration is not None:
            logger.info(f"Download: {download_duration:.1f}s")

        client.request_json("POST", "/instance", body={"instance": instance})
        try:
            wait_for_instance_ready(client, instance_id)
        except (RuntimeError, TimeoutError) as e:
            logger.error(f"Failed to initialize: {e}")
            # Best-effort cleanup of the instance that never became ready.
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{instance_id}")
            return 1
    else:
        full_model_id = args.model

    cluster_snapshot = None
    # From here on an instance may be running, so everything, including the
    # snapshot and settings resolution, sits inside the teardown guard.
    try:
        if not args.skip_instance_setup:
            # NOTE(review): short pause before the snapshot, presumably to let
            # cluster state settle after the instance reports ready.
            time.sleep(1)
            cluster_snapshot = capture_cluster_snapshot(client)

        # Auto-detect reasoning from model config
        model_config = load_model_config(full_model_id)
        if args.reasoning:
            is_reasoning = True
        elif args.no_reasoning:
            is_reasoning = False
        elif model_config is not None:
            is_reasoning = model_config.get("reasoning", False)
            logger.info(
                f"Auto-detected from config: {model_config['name']} → "
                f"{'reasoning' if is_reasoning else 'non-reasoning'}"
            )
        else:
            is_reasoning = False
            logger.warning(
                f"Model '{full_model_id}' not found in eval_configs/models.toml. "
                f"Defaulting to non-reasoning. Use --reasoning to override."
            )

        # Resolve temperature, max_tokens, reasoning_effort
        # Priority: CLI flag > per-model config > global defaults
        cfg = model_config or {}

        if args.temperature is not None:
            temperature = args.temperature
        elif "temperature" in cfg:
            temperature = float(cfg["temperature"])
        else:
            temperature = (
                TEMPERATURE_REASONING if is_reasoning else TEMPERATURE_NON_REASONING
            )

        if args.max_tokens is not None:
            max_tokens = args.max_tokens
        elif "max_tokens" in cfg:
            max_tokens = int(cfg["max_tokens"])
        else:
            max_tokens = REASONING_MAX_TOKENS if is_reasoning else DEFAULT_MAX_TOKENS

        if args.top_p is not None:
            top_p: float | None = args.top_p
        elif "top_p" in cfg:
            top_p = float(cfg["top_p"])
        else:
            top_p = None  # let server use its default

        if args.reasoning_effort is not None:
            reasoning_effort = args.reasoning_effort
        elif "reasoning_effort" in cfg:
            reasoning_effort = str(cfg["reasoning_effort"])
        else:
            reasoning_effort = "high" if is_reasoning else None
        base_url = f"http://{args.host}:{args.port}"

        logger.info(f"Model: {full_model_id}")
        logger.info(
            f"Settings: temperature={temperature}, max_tokens={max_tokens}, "
            + (f"top_p={top_p}, " if top_p is not None else "")
            + f"reasoning={'yes' if is_reasoning else 'no'}"
            + (f", reasoning_effort={reasoning_effort}" if reasoning_effort else "")
        )

        def run_benchmark(task_name: str, concurrency: int) -> list[QuestionResult]:
            # Single call site for evaluate_benchmark so both modes below
            # always pass the same resolved settings.
            return asyncio.run(
                evaluate_benchmark(
                    task_name,
                    base_url,
                    full_model_id,
                    temperature,
                    max_tokens,
                    concurrency=concurrency,
                    limit=args.limit,
                    timeout=args.request_timeout,
                    reasoning_effort=reasoning_effort,
                    top_p=top_p,
                    difficulty=args.difficulty,
                )
            )

        if concurrency_levels is not None:
            # Sweep mode: run each task at every level, then compare them.
            for task_name in task_names:
                results_by_c: dict[int, list[QuestionResult]] = {}
                for c in concurrency_levels:
                    logger.info(f"\n{'=' * 50}")
                    logger.info(f"Running {task_name} at concurrency={c}")
                    results = run_benchmark(task_name, c)
                    if results:
                        scores = print_results(task_name, results, concurrency=c)
                        save_results(
                            args.results_dir,
                            task_name,
                            full_model_id,
                            c,
                            results,
                            scores,
                            cluster=cluster_snapshot,
                        )
                        results_by_c[c] = results
                # A comparison needs at least two levels that produced results.
                if len(results_by_c) >= 2:
                    print_comparison(task_name, results_by_c)
        else:
            for task_name in task_names:
                results = run_benchmark(task_name, args.num_concurrent)
                if results:
                    scores = print_results(task_name, results)
                    save_results(
                        args.results_dir,
                        task_name,
                        full_model_id,
                        args.num_concurrent,
                        results,
                        scores,
                        cluster=cluster_snapshot,
                    )
    finally:
        # Only tear down an instance this run created.
        if instance_id is not None:
            try:
                client.request_json("DELETE", f"/instance/{instance_id}")
            except ExoHttpError as e:
                # 404 means the instance is already gone; anything else is a
                # real teardown failure and is propagated.
                if e.status != 404:
                    raise
            wait_for_instance_gone(client, instance_id)

    return 0
|
|
|
|
|
|
# Script entry point: the integer returned by main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())
|