mirror of
https://github.com/exo-explore/exo.git
synced 2026-04-17 20:40:35 -04:00
## Motivation When using Exo-Bench extensively, there are many cases where we could use prefix caching to speed up the benchmarks, especially when the focus is on token generation. At the same time, it is clear that prefix-caching decode tokens is not very useful in most current scenarios. Surprisingly, even for non-thinking models, the chat template means that a continued conversation will be formatted such that the existing cache is not effective. We already (slightly accidentally) do this for the batch generator — we should do it for the sequential generator too. ## Changes We can now speed up exo-bench with a use-prefix-caching flag. Of course, for the most accurate pp results it is better not to enable it, but it speeds up tg and large benchmarking runs significantly. The methodology has been updated to match. ## Test Plan ### Manual Testing Verified on many configurations that the difference in results is negligible, even with multiple --pp options.
769 lines
30 KiB
Python
769 lines
30 KiB
Python
# type: ignore
#!/usr/bin/env python3
"""exo-bench: throughput benchmark for exo's OpenAI-compatible API.

Benchmarks a model across placement previews, measuring:
- prompt processing (pp) and token generation (tg) throughput,
- optional concurrent request batches (--concurrency),
- per-node system metrics (GPU usage, temperature, power) and energy.

Start an exo cluster first, then run (examples):
    uv run python exo_bench.py --model <model-id> --pp 128 --tg 128
    uv run python exo_bench.py --model <model-id> --pp 128,512 --tg 64 --repeat 3
    uv run python exo_bench.py --model <model-id> --pp 128 --tg 64 --concurrency 1,2,4
    uv run python exo_bench.py --model <model-id> --pp 128 --tg 64 --use-prefix-cache
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import contextlib
|
||
import itertools
|
||
import json
|
||
import sys
|
||
import threading
|
||
import time
|
||
from collections.abc import Callable
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from pathlib import Path
|
||
from statistics import mean
|
||
from typing import Any
|
||
|
||
from harness import (
|
||
ExoClient,
|
||
ExoHttpError,
|
||
add_common_instance_args,
|
||
capture_cluster_snapshot,
|
||
instance_id_from_instance,
|
||
node_ids_from_instance,
|
||
nodes_used_in_instance,
|
||
resolve_model_short_id,
|
||
run_planning_phase,
|
||
settle_and_fetch_placements,
|
||
wait_for_instance_gone,
|
||
wait_for_instance_ready,
|
||
)
|
||
from loguru import logger
|
||
from transformers import AutoTokenizer
|
||
|
||
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
    import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
    from transformers.convert_slow_tokenizer import bytes_to_unicode

    # Re-export the helper at its legacy path only when missing, so remote
    # tokenizer code importing from the old location keeps working on
    # transformers >= 5.0.0rc2 without clobbering older versions.
    if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
        gpt2_tokenization.bytes_to_unicode = bytes_to_unicode  # type: ignore[attr-defined]
except ImportError:
    pass  # transformers < 5.0 or bytes_to_unicode not available
|
||
|
||
|
||
def load_tokenizer_for_bench(model_id: str) -> Any:
    """
    Load tokenizer for benchmarking, with special handling for Kimi models.

    Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
    This function replicates the logic from utils_mlx.py for bench compatibility.

    Args:
        model_id: Hugging Face model id (or local id resolvable by AutoTokenizer).

    Returns:
        A tokenizer object. For Kimi-K2 models this is a dynamically loaded
        TikTokenTokenizer with a patched ``encode``; otherwise an AutoTokenizer.
    """
    model_id_lower = model_id.lower()

    if "kimi-k2" in model_id_lower:
        import importlib.util
        import types

        from huggingface_hub import snapshot_download

        # Download/get the model path (tokenizer assets only, not weights).
        model_path = Path(
            snapshot_download(
                model_id,
                allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
            )
        )

        # Make the repo's custom modules importable by absolute name.
        sys.path.insert(0, str(model_path))

        # Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
        tool_decl_path = model_path / "tool_declaration_ts.py"
        if tool_decl_path.exists():
            spec = importlib.util.spec_from_file_location(
                "tool_declaration_ts", tool_decl_path
            )
            if spec and spec.loader:
                tool_decl_module = importlib.util.module_from_spec(spec)
                # Register before exec so intra-module imports resolve.
                sys.modules["tool_declaration_ts"] = tool_decl_module
                spec.loader.exec_module(tool_decl_module)

        # Load tokenization_kimi with patched source (convert relative to absolute import)
        tok_path = model_path / "tokenization_kimi.py"
        source = tok_path.read_text()
        source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
        spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
        if spec:
            # Exec the patched source into a fresh module object so the
            # relative-import rewrite above takes effect.
            tok_module = types.ModuleType("tokenization_kimi")
            tok_module.__file__ = str(tok_path)
            sys.modules["tokenization_kimi"] = tok_module
            exec(compile(source, tok_path, "exec"), tok_module.__dict__)  # noqa: S102
            TikTokenTokenizer = tok_module.TikTokenTokenizer  # noqa: N806
        else:
            # Fallback: plain import from the sys.path entry added above.
            from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found] # noqa: I001

        hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)

        # Patch encode to use internal tiktoken model directly
        # transformers 5.x has a bug in the encode->pad path for slow tokenizers
        def _patched_encode(text: str, **kwargs: object) -> list[int]:
            # Pass allowed_special="all" to handle special tokens like <|im_user|>
            return list(hf_tokenizer.model.encode(text, allowed_special="all"))

        hf_tokenizer.encode = _patched_encode

        return hf_tokenizer

    # Default: use AutoTokenizer
    return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||
|
||
|
||
def format_peak_memory(b: float) -> str:
    """Render a byte count as a human-readable string, e.g. '1.50GB'.

    Raises ValueError for values of a petabyte or more (assumed to be a bug).
    """
    units = ("B", "KB", "MB", "GB", "TB")
    i = 0
    while i < len(units):
        if b < 1024.0:
            return f"{b:.2f}{units[i]}"
        b /= 1024.0
        i += 1
    raise ValueError("You're using petabytes of memory. Something went wrong...")
|
||
|
||
|
||
# System-metric keys polled from each node; "sysPower" also feeds energy accounting.
_SAMPLER_METRICS = ("gpuUsage", "temp", "sysPower", "pcpuUsage", "ecpuUsage")
|
||
|
||
|
||
class SystemMetricsSampler:
|
||
def __init__(self, client: ExoClient, node_ids: list[str], interval_s: float = 1.0):
|
||
self._client = client
|
||
self._node_ids = node_ids
|
||
self._interval_s = interval_s
|
||
self._samples: dict[str, list[tuple[float, dict[str, float]]]] = {
|
||
nid: [] for nid in node_ids
|
||
}
|
||
self._stop = threading.Event()
|
||
self._thread: threading.Thread | None = None
|
||
|
||
def start(self) -> None:
|
||
self._stop.clear()
|
||
self._thread = threading.Thread(target=self._poll_loop, daemon=True)
|
||
self._thread.start()
|
||
|
||
def stop(self) -> None:
|
||
self._stop.set()
|
||
if self._thread:
|
||
self._thread.join(timeout=5)
|
||
|
||
def _poll_loop(self) -> None:
|
||
while not self._stop.is_set():
|
||
t = time.monotonic()
|
||
for nid in self._node_ids:
|
||
try:
|
||
data = self._client.get_node_system(nid)
|
||
if data:
|
||
self._samples[nid].append(
|
||
(t, {k: data.get(k, 0.0) for k in _SAMPLER_METRICS})
|
||
)
|
||
except Exception:
|
||
pass
|
||
self._stop.wait(self._interval_s)
|
||
|
||
def energy_between(self, t0: float, t1: float) -> float:
|
||
total_joules = 0.0
|
||
for _nid, samples in self._samples.items():
|
||
window = [(t, s["sysPower"]) for t, s in samples if t0 <= t <= t1]
|
||
if len(window) >= 2:
|
||
for i in range(1, len(window)):
|
||
dt = window[i][0] - window[i - 1][0]
|
||
avg_power = (window[i][1] + window[i - 1][1]) / 2
|
||
total_joules += avg_power * dt
|
||
elif len(window) == 1:
|
||
total_joules += window[0][1] * (t1 - t0)
|
||
return total_joules
|
||
|
||
def summarize(self) -> dict[str, dict[str, dict[str, float]]]:
|
||
result: dict[str, dict[str, dict[str, float]]] = {}
|
||
for nid, samples in self._samples.items():
|
||
if not samples:
|
||
continue
|
||
metrics: dict[str, dict[str, float]] = {}
|
||
for key in _SAMPLER_METRICS:
|
||
values = [s[key] for t, s in samples]
|
||
metrics[key] = {
|
||
"min": round(min(values), 2),
|
||
"max": round(max(values), 2),
|
||
"mean": round(mean(values), 2),
|
||
"samples": len(values),
|
||
}
|
||
result[nid] = metrics
|
||
return result
|
||
|
||
def print_summary(self, placement_label: str) -> None:
|
||
summary = self.summarize()
|
||
if not summary:
|
||
return
|
||
logger.info(f"--- System Metrics ({placement_label}) ---")
|
||
for nid, metrics in summary.items():
|
||
gpu = metrics.get("gpuUsage", {})
|
||
temp = metrics.get("temp", {})
|
||
power = metrics.get("sysPower", {})
|
||
logger.info(
|
||
f" {nid}: "
|
||
f"GPU {gpu.get('mean', 0) * 100:.0f}% avg ({gpu.get('min', 0) * 100:.0f}–{gpu.get('max', 0) * 100:.0f}%) | "
|
||
f"{temp.get('mean', 0):.1f}°C avg | "
|
||
f"{power.get('mean', 0):.1f}W avg"
|
||
)
|
||
|
||
|
||
def parse_int_list(values: list[str]) -> list[int]:
    """Flatten CLI integer arguments, splitting comma-separated entries.

    Empty fragments (e.g. from trailing commas) are ignored.
    """
    return [
        int(piece)
        for raw in values
        for piece in (chunk.strip() for chunk in raw.split(","))
        if piece
    ]
|
||
|
||
|
||
def run_one_completion(
    client: ExoClient,
    model_id: str,
    pp_hint: int,
    tg: int,
    prompt_sizer: PromptSizer,
    *,
    use_prefix_cache: bool = False,
) -> tuple[dict[str, Any], int]:
    """Send one non-streaming bench completion and time it.

    The prompt is sized to ``pp_hint`` tokens via ``prompt_sizer``.

    Returns:
        ``(row, pp_tokens)`` where ``row`` holds wall-clock elapsed seconds,
        a 200-char preview of the reply (empty for None content, as produced
        by thinking models), and the server-side ``generation_stats`` blob.
    """
    prompt_text, pp_tokens = prompt_sizer.build(pp_hint)
    request_body: dict[str, Any] = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt_text}],
        "stream": False,
        "max_tokens": tg,
        "logprobs": False,
        "use_prefix_cache": use_prefix_cache,
    }

    started = time.perf_counter()
    response = client.post_bench_chat_completions(request_body)
    elapsed_s = time.perf_counter() - started

    # Thinking models commonly return None content; normalize before slicing.
    choices = response.get("choices") or [{}]
    message = choices[0].get("message", {})
    reply_text = message.get("content") or ""

    return {
        "elapsed_s": elapsed_s,
        "output_text_preview": reply_text[:200] if reply_text else "",
        "stats": response.get("generation_stats"),
    }, pp_tokens
|
||
|
||
|
||
class PromptSizer:
    """Builds user prompts that tokenize to an exact target length.

    Token counts are measured through the tokenizer's chat template, so the
    template's fixed overhead is accounted for. ``build`` binary-searches over
    repetitions of a small "atom" string until the rendered prompt hits the
    requested token count exactly.
    """

    def __init__(self, tokenizer: Any, atom: str = "a "):
        self.tokenizer = tokenizer
        self.atom = atom
        self.count_fn = PromptSizer._make_counter(tokenizer)
        # Token cost of the chat template wrapping an empty user message.
        self.base_tokens = self.count_fn("")

    @staticmethod
    def _make_counter(tokenizer: Any) -> Callable[[str], int]:
        """Return a function mapping user-message text to total prompt tokens."""

        def count_fn(user_content: str) -> int:
            rendered = tokenizer.apply_chat_template(
                [{"role": "user", "content": user_content}],
                tokenize=True,
                add_generation_prompt=True,
            )
            # Fix for transformers 5.x: unwrap BatchEncoding-style results.
            if hasattr(rendered, "input_ids"):
                rendered = rendered.input_ids
            return int(len(rendered))

        return count_fn

    def build(self, target_prompt_tokens: int) -> tuple[str, int]:
        """Return ``(content, token_count)`` with an exact token count.

        Raises RuntimeError when the target is below the template overhead or
        when no atom multiple lands exactly on the target.
        """
        target = int(target_prompt_tokens)
        if target < self.base_tokens:
            raise RuntimeError(
                f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
            )

        # Estimate tokens per atom from a fixed-size probe.
        probe_atoms = 100
        probe_tokens = self.count_fn(self.atom * probe_atoms) - self.base_tokens
        per_atom = probe_tokens / probe_atoms

        # Initial guess for the number of atoms needed.
        guess = int((target - self.base_tokens) / per_atom)

        # Binary search for the smallest atom count reaching the target.
        lo, hi = 0, guess * 2 + 100
        while lo < hi:
            mid = (lo + hi) // 2
            if self.count_fn(self.atom * mid) < target:
                lo = mid + 1
            else:
                hi = mid

        content = self.atom * lo
        tok = self.count_fn(content)
        logger.info(f"{tok=}")

        if tok != target:
            raise RuntimeError(
                f"Overshot: got {tok} tokens (target {target}). "
                f"Pick a different atom (try ' a' or '\\n' or '0 ')."
            )

        return content, tok
|
||
|
||
|
||
def main() -> int:
    """CLI entry point for exo-bench.

    Parses arguments, resolves the model and its placements, then for each
    placement runs warmups plus all requested (pp, tg) x concurrency
    combinations, collecting throughput rows and optional system metrics.

    Returns:
        Process exit code: 0 on success, 1 when no placement matched the
        filters, 2 on invalid arguments.
    """
    ap = argparse.ArgumentParser(
        prog="exo-bench",
        description="Benchmark exo model throughput across placement previews.",
    )
    add_common_instance_args(ap)
    ap.add_argument(
        "--pp",
        nargs="+",
        required=True,
        help="Prompt-size hints (ints). Accepts commas.",
    )
    ap.add_argument(
        "--tg",
        nargs="+",
        required=True,
        help="Generation lengths (ints). Accepts commas.",
    )
    ap.add_argument(
        "--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
    )
    ap.add_argument(
        "--concurrency",
        nargs="+",
        default=["1"],
        help="Concurrency levels (ints). Accepts commas. E.g. --concurrency 1,2,4,8. Default 1.",
    )
    ap.add_argument(
        "--warmup",
        type=int,
        default=0,
        help="Warmup runs per placement (uses first pp/tg).",
    )
    ap.add_argument(
        "--json-out",
        default="bench/results.json",
        help="Write raw per-run results JSON to this path.",
    )
    ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
    ap.add_argument(
        "--dry-run", action="store_true", help="List selected placements and exit."
    )
    ap.add_argument(
        "--all-combinations",
        action="store_true",
        help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
    )
    ap.add_argument(
        "--no-system-metrics",
        action="store_true",
        help="Disable GPU utilization, temperature, and power collection during inference.",
    )
    ap.add_argument(
        "--metrics-interval",
        type=float,
        default=1.0,
        help="System metrics polling interval in seconds (default: 1.0).",
    )
    ap.add_argument(
        "--use-prefix-cache",
        action="store_true",
        help="Enable KV prefix cache during bench (default: disabled for cold-cache measurements).",
    )
    args = ap.parse_args()

    # --- Validate numeric argument lists -------------------------------
    pp_list = parse_int_list(args.pp)
    tg_list = parse_int_list(args.tg)
    if not pp_list or not tg_list:
        logger.error("pp and tg lists must be non-empty")
        return 2
    if args.repeat <= 0:
        logger.error("--repeat must be >= 1")
        return 2
    concurrency_list = parse_int_list(args.concurrency)
    if not concurrency_list or any(c <= 0 for c in concurrency_list):
        logger.error("--concurrency values must be >= 1")
        return 2

    if args.use_prefix_cache:
        logger.warning(
            "--use-prefix-cache: prompt TPS will be approximate. See METHODOLOGY.md for details."
        )
        # With prefix caching, ascending --pp lets each prompt extend the
        # cached prefix of the previous one, keeping timings comparable.
        if pp_list != sorted(pp_list):
            logger.warning(
                "--pp values are not in ascending order: prompt TPS will be less accurate. Use ascending --pp for best results."
            )

    # Log pairing mode
    use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
    if use_combinations:
        logger.info(
            f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
        )
    else:
        logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")

    # --- Resolve model + tokenizer ------------------------------------
    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    short_id, full_model_id = resolve_model_short_id(
        client, args.model, force_download=args.force_download
    )

    tokenizer = load_tokenizer_for_bench(full_model_id)
    if tokenizer is None:
        raise RuntimeError("[exo-bench] tokenizer load failed")

    try:
        prompt_sizer = PromptSizer(tokenizer)
        logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
    except Exception:
        logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
        raise

    # --- Select and order placements ----------------------------------
    selected = settle_and_fetch_placements(
        client, full_model_id, args, settle_timeout=args.settle_timeout
    )

    if not selected:
        logger.error("No valid placements matched your filters.")
        return 1

    selected.sort(
        key=lambda p: (
            str(p.get("instance_meta", "")),
            str(p.get("sharding", "")),
            -nodes_used_in_instance(p["instance"]),
        ),
        reverse=True,
    )

    logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
    logger.info(f"placements: {len(selected)}")
    for p in selected:
        logger.info(
            f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
        )

    if args.dry_run:
        return 0

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    # --- Ensure the model is downloaded before timing anything --------
    logger.info("Planning phase: checking downloads...")
    download_duration_s = run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )
    if download_duration_s is not None:
        logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
    else:
        logger.info("Download: model already cached")

    cluster_snapshot = capture_cluster_snapshot(client)
    all_rows: list[dict[str, Any]] = []
    all_system_metrics: dict[str, dict[str, dict[str, float]]] = {}

    # --- Per-placement benchmark loop ---------------------------------
    for preview in selected:
        instance = preview["instance"]
        instance_id = instance_id_from_instance(instance)

        sharding = str(preview["sharding"])
        instance_meta = str(preview["instance_meta"])
        n_nodes = nodes_used_in_instance(instance)

        logger.info("=" * 80)
        logger.info(
            f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
        )

        # Spin up the instance for this placement; skip it on failure.
        client.request_json("POST", "/instance", body={"instance": instance})
        try:
            wait_for_instance_ready(client, instance_id)
        except (RuntimeError, TimeoutError) as e:
            logger.error(f"Failed to initialize placement: {e}")
            with contextlib.suppress(ExoHttpError):
                client.request_json("DELETE", f"/instance/{instance_id}")
            continue

        time.sleep(1)

        # Optional background sampling of node system metrics.
        sampler: SystemMetricsSampler | None = None
        if not args.no_system_metrics:
            nids = node_ids_from_instance(instance)
            sampler = SystemMetricsSampler(
                ExoClient(args.host, args.port, timeout_s=30),
                nids,
                interval_s=args.metrics_interval,
            )
            sampler.start()

        try:
            for i in range(args.warmup):
                run_one_completion(
                    client,
                    full_model_id,
                    pp_list[0],
                    tg_list[0],
                    prompt_sizer,
                    use_prefix_cache=args.use_prefix_cache,
                )
                logger.debug(f" warmup {i + 1}/{args.warmup} done")

            # If pp and tg lists have same length, run in tandem (zip)
            # Otherwise (or if --all-combinations), run all combinations (cartesian product)
            if use_combinations:
                pp_tg_pairs = list(itertools.product(pp_list, tg_list))
            else:
                pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))

            for pp, tg in pp_tg_pairs:
                for concurrency in concurrency_list:
                    logger.info(f"--- pp={pp} tg={tg} concurrency={concurrency} ---")
                    runs: list[dict[str, Any]] = []
                    # (start, end) monotonic windows of actual inference,
                    # used later for energy integration.
                    inference_windows: list[tuple[float, float]] = []
                    for r in range(args.repeat):
                        time.sleep(3)

                        if concurrency <= 1:
                            # Sequential: single request
                            try:
                                inf_t0 = time.monotonic()
                                row, actual_pp_tokens = run_one_completion(
                                    client,
                                    full_model_id,
                                    pp,
                                    tg,
                                    prompt_sizer,
                                    use_prefix_cache=args.use_prefix_cache,
                                )
                                inference_windows.append((inf_t0, time.monotonic()))
                            except Exception as e:
                                logger.error(e)
                                continue
                            row.update(
                                {
                                    "model_short_id": short_id,
                                    "model_id": full_model_id,
                                    "placement_sharding": sharding,
                                    "placement_instance_meta": instance_meta,
                                    "placement_nodes": n_nodes,
                                    "instance_id": instance_id,
                                    "pp_tokens": actual_pp_tokens,
                                    "tg": tg,
                                    "repeat_index": r,
                                    "concurrency": 1,
                                    **(
                                        {"download_duration_s": download_duration_s}
                                        if download_duration_s is not None
                                        else {}
                                    ),
                                }
                            )
                            runs.append(row)
                            all_rows.append(row)
                        else:
                            # Concurrent: fire N requests in parallel
                            # Pre-build prompt once, barrier ensures simultaneous dispatch
                            content, actual_pp = prompt_sizer.build(pp)
                            pre_built_payload: dict[str, Any] = {
                                "model": full_model_id,
                                "messages": [{"role": "user", "content": content}],
                                "stream": False,
                                "max_tokens": tg,
                                "logprobs": False,
                                "use_prefix_cache": args.use_prefix_cache,
                            }
                            barrier = threading.Barrier(concurrency)
                            batch_start = threading.Event()
                            batch_t0: float = 0.0
                            batch_results: list[tuple[dict[str, Any], int]] = []
                            batch_errors = 0

                            # Defaults bind loop state per batch so the closure
                            # cannot see values from a later iteration.
                            def _run_concurrent(
                                idx: int,
                                _barrier: threading.Barrier = barrier,
                                _batch_start: threading.Event = batch_start,
                                _payload: dict[str, Any] = pre_built_payload,
                                _actual_pp: int = actual_pp,
                            ) -> tuple[dict[str, Any], int]:
                                nonlocal batch_t0
                                # Each worker uses its own client connection.
                                c = ExoClient(
                                    args.host, args.port, timeout_s=args.timeout
                                )
                                # The barrier elects one thread (wait() == 0) to
                                # record the shared start time for all workers.
                                if _barrier.wait() == 0:
                                    batch_t0 = time.perf_counter()
                                    _batch_start.set()
                                else:
                                    _batch_start.wait()
                                t0 = batch_t0
                                out = c.post_bench_chat_completions(_payload)
                                elapsed = time.perf_counter() - t0
                                stats = out.get("generation_stats")
                                choices = out.get("choices") or [{}]
                                message = (
                                    choices[0].get("message", {}) if choices else {}
                                )
                                text = message.get("content") or ""
                                return {
                                    "elapsed_s": elapsed,
                                    "output_text_preview": text[:200],
                                    "stats": stats,
                                }, _actual_pp

                            inf_t0 = time.monotonic()
                            with ThreadPoolExecutor(max_workers=concurrency) as pool:
                                futures = {
                                    pool.submit(_run_concurrent, i): i
                                    for i in range(concurrency)
                                }
                                for fut in as_completed(futures):
                                    try:
                                        batch_results.append(fut.result())
                                    except Exception as e:
                                        logger.error(f"Concurrent request failed: {e}")
                                        batch_errors += 1
                            # Batch wall time: slowest request, all measured
                            # from the shared batch_t0.
                            batch_wall_s = (
                                max(x["elapsed_s"] for x, _ in batch_results)
                                if batch_results
                                else time.perf_counter() - batch_t0
                            )
                            inference_windows.append((inf_t0, time.monotonic()))

                            for idx, (row, actual_pp_tokens) in enumerate(
                                batch_results
                            ):
                                row.update(
                                    {
                                        "model_short_id": short_id,
                                        "model_id": full_model_id,
                                        "placement_sharding": sharding,
                                        "placement_instance_meta": instance_meta,
                                        "placement_nodes": n_nodes,
                                        "instance_id": instance_id,
                                        "pp_tokens": actual_pp_tokens,
                                        "tg": tg,
                                        "repeat_index": r,
                                        "concurrency": concurrency,
                                        "concurrent_index": idx,
                                        **(
                                            {"download_duration_s": download_duration_s}
                                            if download_duration_s is not None
                                            else {}
                                        ),
                                    }
                                )
                                runs.append(row)
                                all_rows.append(row)

                            if batch_results:
                                valid_gen_tps = [
                                    x["stats"]["generation_tps"]
                                    for x, _ in batch_results
                                    if x["stats"]["generation_tps"] > 0
                                ]
                                per_req_tps = (
                                    max(valid_gen_tps) if valid_gen_tps else 0.0
                                )
                                # NOTE(review): aggregate TPS assumes every
                                # request ran at the best per-request rate.
                                agg_gen_tps = per_req_tps * concurrency
                                logger.info(
                                    f"[concurrent {concurrency}x] "
                                    f"agg_gen_tps={agg_gen_tps:.2f} "
                                    f"per_req_tps={per_req_tps:.2f} "
                                    f"wall_s={batch_wall_s:.2f} "
                                    f"errors={batch_errors}"
                                )

                    # Per-(pp, tg, concurrency) summary across repeats.
                    if runs:
                        prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
                        valid_gen = [
                            x["stats"]["generation_tps"]
                            for x in runs
                            if x["stats"]["generation_tps"] > 0
                        ]
                        per_req_tps = max(valid_gen) if valid_gen else 0.0
                        gen_tps = per_req_tps * concurrency
                        ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
                        gtok = mean(x["stats"]["generation_tokens"] for x in runs)
                        peak = mean(
                            x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
                        )

                        summary = (
                            f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
                            f"prompt_tokens={ptok} gen_tokens={gtok} "
                            f"peak_memory={format_peak_memory(peak)}"
                        )
                        if sampler and inference_windows:
                            joules = sum(
                                sampler.energy_between(t0, t1)
                                for t0, t1 in inference_windows
                            )
                            inf_seconds = sum(t1 - t0 for t0, t1 in inference_windows)
                            avg_watts = joules / inf_seconds if inf_seconds > 0 else 0
                            summary += f" energy={joules:.1f}J ({avg_watts:.1f}W avg over {inf_seconds:.1f}s inference)"
                        logger.info(f"{summary}\n")
                    time.sleep(2)
        finally:
            # Always stop sampling and tear down the instance, even on errors.
            if sampler:
                sampler.stop()
                placement_label = f"{sharding}/{instance_meta}/{n_nodes} nodes"
                sampler.print_summary(placement_label)
                placement_metrics = sampler.summarize()
                if placement_metrics:
                    all_system_metrics.update(placement_metrics)

            try:
                client.request_json("DELETE", f"/instance/{instance_id}")
            except ExoHttpError as e:
                # 404 means the instance is already gone; anything else is real.
                if e.status != 404:
                    raise
            wait_for_instance_gone(client, instance_id)
            logger.debug(f"Deleted instance {instance_id}")

        time.sleep(5)

    # --- Emit results --------------------------------------------------
    output: dict[str, Any] = {"runs": all_rows}
    if cluster_snapshot:
        output["cluster"] = cluster_snapshot
    if all_system_metrics:
        output["system_metrics"] = all_system_metrics

    if args.stdout:
        json.dump(output, sys.stdout, indent=2, ensure_ascii=False)
    elif args.json_out:
        with open(args.json_out, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)
        logger.debug(f"\nWrote results JSON: {args.json_out}")

    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate main()'s integer exit code to the shell.
    raise SystemExit(main())
|