From f2a0db4e23717196032fb9035cd487385e8d0661 Mon Sep 17 00:00:00 2001 From: ciaranbor <81697641+ciaranbor@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:53:43 +0100 Subject: [PATCH] Extend bench/eval tooling (#1905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Extend bench/eval tooling with robustness features, streaming support, and align model configs with vllm eval for reproducible comparisons. ## Changes - **exo_eval**: Checkpoint/resume (JSONL), instance health monitoring + early abort, `top_k`/`min_p`/`enable_thinking` params, LCB `--release-version`/`--offset` - **exo_bench**: Streaming SSE (`--stream`), Kimi tokenizer fix for transformers 5.x - **Both tools**: Auto-detect running instances instead of requiring `--skip-instance-setup`; `--fresh-instance` to override - **harness**: SSE streaming client, `find_existing_instance()` shared helper, removed download timeout, settle-timeout default 0→7200s - **models.toml**: Added `enable_thinking`, aligned `max_tokens`/temps with vllm, added new models - **API**: Streaming SSE for `/bench/chat/completions` ## Why It Works - Checkpoint/resume uses append-only JSONL + skip-on-load so interrupted evals resume without re-running completed questions - Health monitoring races an `asyncio.Event` against API calls for fast abort when the instance dies - Auto-detection queries `/state` for existing instances matching the model ID before attempting placement - Streaming reuses the existing `generate_chat_stream` infrastructure from the regular chat endpoint --- bench/eval_configs/models.toml | 99 +++++-- bench/exo_bench.py | 285 ++++++++++++------- bench/exo_eval.py | 483 +++++++++++++++++++++++++++++---- bench/harness.py | 54 +++- src/exo/api/main.py | 22 +- 5 files changed, 762 insertions(+), 181 deletions(-) diff --git a/bench/eval_configs/models.toml b/bench/eval_configs/models.toml index 477e96cd9..d33698fe1 100644 --- a/bench/eval_configs/models.toml +++ b/bench/eval_configs/models.toml @@ -7,7 +7,7 @@ # name, patterns, reasoning # # Optional per-model overrides (CLI flags take priority over these): -# temperature, top_p, max_tokens, reasoning_effort +# temperature, top_p, max_tokens, reasoning_effort, enable_thinking # # Fallback defaults (when no per-model config): # reasoning: temperature=1.0, max_tokens=131072, reasoning_effort="high" @@ -18,10 +18,9 @@ # ─── Qwen3.5 (Feb 2026) ───────────────────────────────────────────── # Source: HuggingFace model cards (Qwen/Qwen3.5-*) -# 35B-A3B thinking general: temp=1.0, top_p=0.95, top_k=20 -# 397B thinking: temp=0.6, top_p=0.95, top_k=20 -# Non-thinking: temp=0.7, top_p=0.8, top_k=20 -# max_tokens: 32768 general, 81920 for complex math/code +# Model card recommends: temp=0.6, top_p=0.95, top_k=20 +# We omit top_k to match vllm eval (which doesn't set it). +# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin). [[model]] name = "Qwen3.5 2B" @@ -29,7 +28,8 @@ patterns = ["Qwen3.5-2B"] reasoning = true temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Qwen3.5 9B" @@ -37,7 +37,8 @@ patterns = ["Qwen3.5-9B"] reasoning = true temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Qwen3.5 27B" @@ -45,15 +46,17 @@ patterns = ["Qwen3.5-27B"] reasoning = true temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Qwen3.5 35B A3B" patterns = ["Qwen3.5-35B-A3B"] reasoning = true -temperature = 1.0 +temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Qwen3.5 122B A10B" @@ -61,7 +64,8 @@ patterns = ["Qwen3.5-122B-A10B"] reasoning = true temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Qwen3.5 397B A17B" @@ -69,12 +73,14 @@ patterns = ["Qwen3.5-397B-A17B"] reasoning = true temperature = 0.6 top_p = 0.95 -max_tokens = 81920 +enable_thinking = true +max_tokens = 121072 # ─── Qwen3 (Apr 2025) ─────────────────────────────────────────────── # Source: HuggingFace model cards (Qwen/Qwen3-*) -# Thinking: temp=0.6, top_p=0.95, top_k=20 -# Non-thinking: temp=0.7, top_p=0.8, top_k=20 +# Model card recommends: temp=0.6, top_p=0.95, top_k=20 +# We omit top_k to match vllm eval (which doesn't set it). +# Non-thinking: temp=0.7, top_p=0.8 # max_tokens: 32768 general, 38912 for complex math/code [[model]] @@ -83,6 +89,7 @@ patterns = ["Qwen3-0.6B"] reasoning = true temperature = 0.6 top_p = 0.95 +enable_thinking = true max_tokens = 38912 [[model]] @@ -91,6 +98,7 @@ patterns = ["Qwen3-30B-A3B"] reasoning = true temperature = 0.6 top_p = 0.95 +enable_thinking = true max_tokens = 38912 [[model]] @@ -99,6 +107,7 @@ patterns = ["Qwen3-235B-A22B"] reasoning = true temperature = 0.6 top_p = 0.95 +enable_thinking = true max_tokens = 38912 [[model]] @@ -107,6 +116,7 @@ patterns = ["Qwen3-Next-80B-A3B-Thinking"] reasoning = true temperature = 0.6 top_p = 0.95 +enable_thinking = true max_tokens = 38912 [[model]] @@ -129,9 +139,9 @@ max_tokens = 16384 name = "Qwen3 Coder Next" patterns = ["Qwen3-Coder-Next"] reasoning = false -temperature = 0.7 -top_p = 0.8 -max_tokens = 16384 +temperature = 1.0 +top_p = 0.95 +max_tokens = 121072 # ─── GPT-OSS (OpenAI) ─────────────────────────────────────────────── # Source: OpenAI GitHub README + HuggingFace discussion #21 @@ -165,10 +175,38 @@ patterns = ["DeepSeek-V3.1"] reasoning = true temperature = 0.0 +[[model]] +name = "DeepSeek V3.2" +patterns = ["DeepSeek-V3.2"] +reasoning = true +temperature = 1.0 +top_p = 0.95 +enable_thinking = true + +# ─── NVIDIA Nemotron ─────────────────────────────────────────────────── +# Source: HuggingFace model cards +# All variants: temp=1.0, top_p=0.95, enable_thinking=true + +[[model]] +name = "Nemotron Cascade 2 30B A3B" +patterns = ["Nemotron-Cascade-2-30B-A3B"] +reasoning = true +temperature = 1.0 +top_p = 0.95 +enable_thinking = true + +[[model]] +name = "Nemotron 3 Super 120B A12B" +patterns = ["Nemotron-3-Super-120B-A12B", "NVIDIA-Nemotron-3-Super-120B-A12B"] +reasoning = true +temperature = 1.0 +top_p = 0.95 +enable_thinking = true + # ─── GLM (ZhipuAI / THUDM) ────────────────────────────────────────── # Source: HuggingFace model cards + generation_config.json + docs.z.ai # GLM 4.5+: temp=1.0, top_p=0.95 -# Reasoning tasks: 131072 max_tokens; coding/SWE tasks: temp=0.7 +# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin) [[model]] name = "GLM-5" @@ -176,7 +214,8 @@ patterns = ["GLM-5"] reasoning = true temperature = 1.0 top_p = 0.95 -max_tokens = 131072 +enable_thinking = true +max_tokens = 121072 [[model]] name = "GLM 4.5 Air" @@ -191,7 +230,8 @@ patterns = ["GLM-4.7-"] reasoning = true temperature = 1.0 top_p = 0.95 -max_tokens = 131072 +enable_thinking = true +max_tokens = 121072 # Note: matches both GLM-4.7 and GLM-4.7-Flash # ─── Kimi (Moonshot AI) ───────────────────────────────────────────── @@ -213,7 +253,8 @@ patterns = ["Kimi-K2.5"] reasoning = true temperature = 1.0 top_p = 0.95 -max_tokens = 131072 +enable_thinking = true +max_tokens = 121072 [[model]] name = "Kimi K2 Instruct" @@ -223,7 +264,17 @@ temperature = 0.6 # ─── MiniMax ───────────────────────────────────────────────────────── # Source: HuggingFace model cards + generation_config.json -# All models: temp=1.0, top_p=0.95, top_k=40 +# All models: temp=1.0, top_p=0.95 +# max_tokens=90000 to match vllm eval (100000 context - 10000 safety margin) + +[[model]] +name = "MiniMax M2.7" +patterns = ["MiniMax-M2.7"] +reasoning = true +temperature = 1.0 +top_p = 0.95 +enable_thinking = true +max_tokens = 90000 [[model]] name = "MiniMax M2.5" @@ -231,6 +282,8 @@ patterns = ["MiniMax-M2.5"] reasoning = true temperature = 1.0 top_p = 0.95 +enable_thinking = true +max_tokens = 90000 [[model]] name = "MiniMax M2.1" @@ -251,6 +304,8 @@ patterns = ["Step-3.5-Flash"] reasoning = true temperature = 1.0 top_p = 0.95 +enable_thinking = true +max_tokens = 121072 # ─── Llama (Meta) ─────────────────────────────────────────────────── # Source: generation_config.json + meta-llama/llama-models generation.py diff --git a/bench/exo_bench.py b/bench/exo_bench.py index e2248f721..50d835a29 100644 --- a/bench/exo_bench.py +++ b/bench/exo_bench.py @@ -35,6 +35,7 @@ from harness import ( ExoHttpError, add_common_instance_args, capture_cluster_snapshot, + find_existing_instance, instance_id_from_instance, node_ids_from_instance, nodes_used_in_instance, @@ -79,7 +80,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any: model_path = Path( snapshot_download( model_id, - allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"], + allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model", "*.jinja"], ) ) @@ -277,28 +278,72 @@ def run_one_completion( prompt_sizer: PromptSizer, *, use_prefix_cache: bool = False, + stream: bool = False, ) -> tuple[dict[str, Any], int]: content, pp_tokens = prompt_sizer.build(pp_hint) payload: dict[str, Any] = { "model": model_id, "messages": [{"role": "user", "content": content}], - "stream": False, "max_tokens": tg, "logprobs": False, "use_prefix_cache": use_prefix_cache, } - t0 = time.perf_counter() - out = client.post_bench_chat_completions(payload) - elapsed = time.perf_counter() - t0 + if not stream: + payload["stream"] = False + t0 = time.perf_counter() + out = client.post_bench_chat_completions(payload) + elapsed = time.perf_counter() - t0 - stats = out.get("generation_stats") + stats = out.get("generation_stats") + choices = out.get("choices") or [{}] + message = choices[0].get("message", {}) if choices else {} + content = message.get("content") or "" + preview = content[:200] if content else "" + else: + tokens = 0 + first_token_time = None + t0 = time.perf_counter() + text_parts: list[str] = [] + stats = None - # Extract preview, handling None content (common for thinking models) - choices = out.get("choices") or [{}] - message = choices[0].get("message", {}) if choices else {} - content = message.get("content") or "" - preview = content[:200] if content else "" + for raw_line in client.stream_bench_chat_completions(payload): + line = raw_line.strip() + if line.startswith(": generation_stats "): + with contextlib.suppress(json.JSONDecodeError): + stats = json.loads(line[len(": generation_stats ") :]) + continue + if not line.startswith("data: "): + continue + data = line[6:] + if data == "[DONE]": + break + try: + chunk = json.loads(data) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if delta.get("content"): + if first_token_time is None: + first_token_time = time.perf_counter() + tokens += 1 + text_parts.append(delta["content"]) + except json.JSONDecodeError: + pass + + elapsed = time.perf_counter() - t0 + preview = "".join(text_parts)[:200] + + if not stats: + ttft = (first_token_time - t0) if first_token_time else elapsed + gen_time = elapsed - ttft if tokens > 1 else elapsed + gen_tps = (tokens - 1) / gen_time if tokens > 1 and gen_time > 0 else 0.0 + prompt_tps = pp_tokens / ttft if ttft > 0 else 0.0 + stats = { + "prompt_tokens": pp_tokens, + "generation_tokens": tokens, + "prompt_tps": round(prompt_tps, 2), + "generation_tps": round(gen_tps, 2), + "peak_memory_usage": {"inBytes": 0}, + } return { "elapsed_s": elapsed, @@ -425,6 +470,11 @@ def main() -> int: action="store_true", help="Force all pp×tg combinations (cartesian product) even when lists have equal length.", ) + ap.add_argument( + "--stream", + action="store_true", + help="Use /bench/chat/completions with streaming SSE response (bench=True still applies: no EOS detection, no KV cache).", + ) ap.add_argument( "--no-system-metrics", action="store_true", @@ -490,81 +540,124 @@ def main() -> int: logger.error("[exo-bench] tokenizer usable but prompt sizing failed") raise - selected = settle_and_fetch_placements( - client, full_model_id, args, settle_timeout=args.settle_timeout - ) + # Optionally reuse a running instance for this model + reused_instance_id: str | None = None + if args.reuse_instance: + existing = find_existing_instance(client, full_model_id) + if existing: + reused_instance_id = existing + logger.info(f"Reusing existing instance {reused_instance_id}") + else: + logger.warning( + "--reuse-instance: no existing instance found, creating a new one" + ) - if not selected: - logger.error("No valid placements matched your filters.") - return 1 - - selected.sort( - key=lambda p: ( - str(p.get("instance_meta", "")), - str(p.get("sharding", "")), - -nodes_used_in_instance(p["instance"]), - ), - reverse=True, - ) - - logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}") - logger.info(f"placements: {len(selected)}") - for p in selected: - logger.info( - f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}" + if reused_instance_id is not None: + # Use the existing instance directly — skip placement iteration + selected = [] + download_duration_s = None + else: + selected = settle_and_fetch_placements( + client, full_model_id, args, settle_timeout=args.settle_timeout ) - if args.dry_run: - return 0 + if not selected: + logger.error("No valid placements matched your filters.") + return 1 - settle_deadline = ( - time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None - ) + selected.sort( + key=lambda p: ( + str(p.get("instance_meta", "")), + str(p.get("sharding", "")), + nodes_used_in_instance(p["instance"]), + ), + reverse=True, + ) - logger.info("Planning phase: checking downloads...") - download_duration_s = run_planning_phase( - client, - full_model_id, - selected[0], - args.danger_delete_downloads, - args.timeout, - settle_deadline, - ) - if download_duration_s is not None: - logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)") - else: - logger.info("Download: model already cached") + logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}") + logger.info(f"placements: {len(selected)}") + for p in selected: + logger.info( + f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}" + ) + + if args.dry_run: + return 0 + + settle_deadline = ( + time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None + ) + + logger.info("Planning phase: checking downloads...") + download_duration_s = run_planning_phase( + client, + full_model_id, + selected[0], + args.danger_delete_downloads, + args.timeout, + settle_deadline, + ) + if download_duration_s is not None: + logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)") + else: + logger.info("Download: model already cached") cluster_snapshot = capture_cluster_snapshot(client) all_rows: list[dict[str, Any]] = [] all_system_metrics: dict[str, dict[str, dict[str, float]]] = {} + # If reusing an existing instance, run a single benchmark pass against it + if reused_instance_id is not None: + selected = [None] + for preview in selected: - instance = preview["instance"] - instance_id = instance_id_from_instance(instance) + created_instance = False + if preview is not None: + instance = preview["instance"] + instance_id = instance_id_from_instance(instance) - sharding = str(preview["sharding"]) - instance_meta = str(preview["instance_meta"]) - n_nodes = nodes_used_in_instance(instance) + sharding = str(preview["sharding"]) + instance_meta = str(preview["instance_meta"]) + n_nodes = nodes_used_in_instance(instance) - logger.info("=" * 80) - logger.info( - f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}" - ) + logger.info("=" * 80) + logger.info( + f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}" + ) - client.request_json("POST", "/instance", body={"instance": instance}) - try: - wait_for_instance_ready(client, instance_id) - except (RuntimeError, TimeoutError) as e: - logger.error(f"Failed to initialize placement: {e}") - with contextlib.suppress(ExoHttpError): - client.request_json("DELETE", f"/instance/{instance_id}") - continue + # Delete any existing instances to free resources before placing + try: + state = client.request_json("GET", "/state") + for old_id in list(state.get("instances", {}).keys()): + logger.info(f"Deleting stale instance {old_id}") + with contextlib.suppress(ExoHttpError): + client.request_json("DELETE", f"/instance/{old_id}") + if state.get("instances"): + time.sleep(2) + except Exception as e: + logger.warning(f"Failed to clean up stale instances: {e}") - time.sleep(1) + client.request_json("POST", "/instance", body={"instance": instance}) + try: + wait_for_instance_ready(client, instance_id) + except (RuntimeError, TimeoutError) as e: + logger.error(f"Failed to initialize placement: {e}") + with contextlib.suppress(ExoHttpError): + client.request_json("DELETE", f"/instance/{instance_id}") + continue + + time.sleep(1) + created_instance = True + else: + instance_id = reused_instance_id + sharding = "reused" + instance_meta = "reused" + n_nodes = 0 + logger.info("=" * 80) + logger.info(f"Using existing instance {instance_id}") sampler: SystemMetricsSampler | None = None - if not args.no_system_metrics: + if not args.no_system_metrics and preview is not None: nids = node_ids_from_instance(instance) sampler = SystemMetricsSampler( ExoClient(args.host, args.port, timeout_s=30), @@ -573,16 +666,20 @@ def main() -> int: ) sampler.start() + def _do_one(c: ExoClient, pp: int, tg: int) -> tuple[dict[str, Any], int]: + return run_one_completion( + c, + full_model_id, + pp, + tg, + prompt_sizer, + use_prefix_cache=args.use_prefix_cache, + stream=args.stream, + ) + try: for i in range(args.warmup): - run_one_completion( - client, - full_model_id, - pp_list[0], - tg_list[0], - prompt_sizer, - use_prefix_cache=args.use_prefix_cache, - ) + _do_one(client, pp_list[0], tg_list[0]) logger.debug(f" warmup {i + 1}/{args.warmup} done") # If pp and tg lists have same length, run in tandem (zip) @@ -604,14 +701,7 @@ def main() -> int: # Sequential: single request try: inf_t0 = time.monotonic() - row, actual_pp_tokens = run_one_completion( - client, - full_model_id, - pp, - tg, - prompt_sizer, - use_prefix_cache=args.use_prefix_cache, - ) + row, actual_pp_tokens = _do_one(client, pp, tg) inference_windows.append((inf_t0, time.monotonic())) except Exception as e: logger.error(e) @@ -760,10 +850,12 @@ def main() -> int: gen_tps = per_req_tps * concurrency ptok = mean(x["stats"]["prompt_tokens"] for x in runs) gtok = mean(x["stats"]["generation_tokens"] for x in runs) - peak = mean( - x["stats"]["peak_memory_usage"]["inBytes"] for x in runs - ) + def _peak_bytes(s: dict[str, Any]) -> float: + pm = s["peak_memory_usage"] + return pm.get("inBytes") or pm.get("in_bytes", 0) + + peak = mean(_peak_bytes(x["stats"]) for x in runs) summary = ( f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} " f"prompt_tokens={ptok} gen_tokens={gtok} " @@ -788,15 +880,16 @@ def main() -> int: if placement_metrics: all_system_metrics.update(placement_metrics) - try: - client.request_json("DELETE", f"/instance/{instance_id}") - except ExoHttpError as e: - if e.status != 404: - raise - wait_for_instance_gone(client, instance_id) - logger.debug(f"Deleted instance {instance_id}") + if created_instance and instance_id is not None: + try: + client.request_json("DELETE", f"/instance/{instance_id}") + except ExoHttpError as e: + if e.status != 404: + raise + wait_for_instance_gone(client, instance_id) + logger.debug(f"Deleted instance {instance_id}") - time.sleep(5) + time.sleep(5) output: dict[str, Any] = {"runs": all_rows} if cluster_snapshot: diff --git a/bench/exo_eval.py b/bench/exo_eval.py index fb4d55f3b..6e0c1b403 100644 --- a/bench/exo_eval.py +++ b/bench/exo_eval.py @@ -47,6 +47,7 @@ from harness import ( ExoHttpError, add_common_instance_args, capture_cluster_snapshot, + find_existing_instance, instance_id_from_instance, nodes_used_in_instance, resolve_model_short_id, @@ -62,6 +63,15 @@ from loguru import logger # --------------------------------------------------------------------------- MAX_RETRIES = 30 +INSTANCE_HEALTH_CHECK_AFTER = ( + 3 # Check instance health after this many consecutive failures +) + + +class InstanceFailedError(RuntimeError): + """Raised when the exo instance is detected as failed/gone.""" + + DEFAULT_MAX_TOKENS = 16_384 REASONING_MAX_TOKENS = 131_072 TEMPERATURE_NON_REASONING = 0.0 @@ -271,7 +281,7 @@ def run_humaneval_test( @dataclass class QuestionResult: - question_id: int + question_id: int | str prompt: str response: str extracted_answer: str | None @@ -281,7 +291,11 @@ class QuestionResult: prompt_tokens: int = 0 completion_tokens: int = 0 reasoning_tokens: int = 0 + reasoning_content: str = "" + finish_reason: str = "" elapsed_s: float = 0.0 + power_watts: float = 0.0 + energy_joules: float = 0.0 @dataclass @@ -517,6 +531,10 @@ class ApiResult: prompt_tokens: int completion_tokens: int reasoning_tokens: int + reasoning_content: str = "" + finish_reason: str = "" + power_watts: float = 0.0 + energy_joules: float = 0.0 async def _call_api( @@ -530,6 +548,9 @@ async def _call_api( system_message: str | None = None, reasoning_effort: str | None = None, top_p: float | None = None, + top_k: int | None = None, + min_p: float | None = None, + enable_thinking: bool | None = None, ) -> ApiResult: messages = [] if system_message: @@ -546,6 +567,12 @@ async def _call_api( body["reasoning_effort"] = reasoning_effort if top_p is not None: body["top_p"] = top_p + if top_k is not None: + body["top_k"] = top_k + if min_p is not None: + body["min_p"] = min_p + if enable_thinking is not None: + body["enable_thinking"] = enable_thinking resp = await client.post( f"{base_url}/v1/chat/completions", @@ -554,19 +581,40 @@ async def _call_api( ) resp.raise_for_status() data = resp.json() - content = data["choices"][0]["message"]["content"] - if not content or not content.strip(): + choice = data["choices"][0] + message = choice["message"] + content = message.get("content") or "" + reasoning_content = message.get("reasoning_content") or "" + finish_reason = choice.get("finish_reason") or "" + + # For thinking models, empty content is expected when finish_reason is "length" + if not content.strip() and finish_reason != "length" and not reasoning_content: raise ValueError("Empty response from model") usage = data.get("usage", {}) details = usage.get("completion_tokens_details", {}) + power = data.get("power_usage") or {} return ApiResult( content=content, prompt_tokens=usage.get("prompt_tokens", 0), completion_tokens=usage.get("completion_tokens", 0), reasoning_tokens=details.get("reasoning_tokens", 0) if details else 0, + reasoning_content=reasoning_content, + finish_reason=finish_reason, + power_watts=power.get("total_avg_sys_power_watts", 0.0), + energy_joules=power.get("total_energy_joules", 0.0), ) +async def _check_instance_health(base_url: str) -> bool: + """Return True if the exo instance is still reachable.""" + try: + async with httpx.AsyncClient() as c: + resp = await c.get(f"{base_url}/models", timeout=5.0) + return resp.status_code == 200 + except Exception: + return False + + async def call_with_retries( client: httpx.AsyncClient, base_url: str, @@ -578,8 +626,14 @@ async def call_with_retries( system_message: str | None = None, reasoning_effort: str | None = None, top_p: float | None = None, + top_k: int | None = None, + min_p: float | None = None, + enable_thinking: bool | None = None, + instance_failed: asyncio.Event | None = None, ) -> ApiResult | None: for attempt in range(MAX_RETRIES): + if instance_failed and instance_failed.is_set(): + raise InstanceFailedError("Instance already marked as failed") try: return await _call_api( client, @@ -592,8 +646,30 @@ async def call_with_retries( system_message, reasoning_effort, top_p, + top_k, + min_p, + enable_thinking, ) except Exception as e: + is_conn_error = isinstance( + e, + ( + httpx.ConnectError, + httpx.RemoteProtocolError, + ConnectionRefusedError, + OSError, + ), + ) + if ( + is_conn_error + and attempt >= INSTANCE_HEALTH_CHECK_AFTER + and not await _check_instance_health(base_url) + ): + if instance_failed: + instance_failed.set() + raise InstanceFailedError( + f"Instance is down after {attempt + 1} failures: {e}" + ) from e if attempt < MAX_RETRIES - 1: wait = min(2**attempt, 60) logger.warning( @@ -618,10 +694,16 @@ async def evaluate_benchmark( max_tokens: int, concurrency: int = 1, limit: int | None = None, + offset: int = 0, timeout: float | None = None, reasoning_effort: str | None = None, top_p: float | None = None, + top_k: int | None = None, + min_p: float | None = None, + enable_thinking: bool | None = None, difficulty: str | None = None, + checkpoint_path: Path | None = None, + release_version: str | None = None, ) -> list[QuestionResult]: """Run a benchmark. Returns per-question results.""" import datasets @@ -652,7 +734,21 @@ async def evaluate_benchmark( ds = ds.filter(lambda x: x["difficulty"] == difficulty) logger.info(f"Filtered to {len(ds)} {difficulty} problems") + if release_version and "release_version" in ds.column_names: + ds = ds.filter(lambda x: x["release_version"] == release_version) + logger.info( + f"Filtered to {len(ds)} problems with release_version={release_version}" + ) + + # Sort by question_id to match LCB runner ordering (scenario_router.py:60). + # This ensures [offset:offset+limit] slices select the same problems as vllm. + if "question_id" in ds.column_names: + ds = ds.sort("question_id") + total = len(ds) + if offset > 0: + ds = ds.select(range(min(offset, total), total)) + total = len(ds) if limit and limit < total: ds = ds.select(range(limit)) total = limit @@ -660,6 +756,13 @@ async def evaluate_benchmark( logger.info( f"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, " f"temperature={temperature}, max_tokens={max_tokens}" + + (f", top_k={top_k}" if top_k is not None else "") + + (f", min_p={min_p}" if min_p is not None else "") + + ( + f", enable_thinking={enable_thinking}" + if enable_thinking is not None + else "" + ) ) if config.kind == "code": @@ -667,16 +770,64 @@ async def evaluate_benchmark( "Code benchmarks execute model-generated code. Use a sandboxed environment." ) + # Load checkpoint for resume + checkpoint_data: dict[str | int, dict[str, Any]] = {} + if checkpoint_path and checkpoint_path.exists(): + with open(checkpoint_path) as f: + for line in f: + entry = json.loads(line) + checkpoint_data[entry["question_id"]] = entry + logger.info(f"Loaded {len(checkpoint_data)} checkpointed results") + semaphore = asyncio.Semaphore(concurrency) + instance_failed = asyncio.Event() results: list[QuestionResult | None] = [None] * total completed = 0 lock = asyncio.Lock() + def _get_question_id(idx: int, doc: dict) -> str | int: + """Get a stable question ID for checkpointing.""" + if benchmark_name == "livecodebench": + return doc.get("question_id", idx) + elif benchmark_name == "humaneval": + return doc.get("task_id", idx) + return idx + async def process_question( idx: int, doc: dict, http_client: httpx.AsyncClient ) -> None: nonlocal completed system_msg = None + question_id = _get_question_id(idx, doc) + + # Bail out early if instance is already dead + if instance_failed.is_set(): + return + + # Check checkpoint + if question_id in checkpoint_data: + cached = checkpoint_data[question_id] + results[idx] = QuestionResult( + question_id=question_id, + prompt=cached.get("prompt", ""), + response=cached.get("response", ""), + extracted_answer=cached.get("extracted_answer"), + gold_answer=cached.get("gold_answer", ""), + correct=cached.get("correct", False), + error=cached.get("error"), + prompt_tokens=cached.get("prompt_tokens", 0), + completion_tokens=cached.get("completion_tokens", 0), + reasoning_tokens=cached.get("reasoning_tokens", 0), + reasoning_content=cached.get("reasoning_content", ""), + finish_reason=cached.get("finish_reason", ""), + elapsed_s=cached.get("elapsed_s", 0.0), + power_watts=cached.get("power_watts", 0.0), + energy_joules=cached.get("energy_joules", 0.0), + ) + async with lock: + completed += 1 + logger.info(f" [{completed}/{total}] {question_id} (cached)") + return if benchmark_name == "gpqa_diamond": prompt, gold = format_gpqa_question(doc, idx) @@ -697,24 +848,50 @@ async def evaluate_benchmark( raise ValueError(f"Unknown benchmark: {benchmark_name}") async with semaphore: + if instance_failed.is_set(): + return t0 = time.monotonic() - api_result = await call_with_retries( - http_client, - base_url, - model, - prompt, - temperature, - max_tokens, - timeout, - system_message=system_msg, - reasoning_effort=reasoning_effort, - top_p=top_p, - ) + try: + # Race the API call against the instance_failed event + api_task = asyncio.create_task( + call_with_retries( + http_client, + base_url, + model, + prompt, + temperature, + max_tokens, + timeout, + system_message=system_msg, + reasoning_effort=reasoning_effort, + top_p=top_p, + top_k=top_k, + min_p=min_p, + enable_thinking=enable_thinking, + instance_failed=instance_failed, + ) + ) + failed_waiter = asyncio.create_task(instance_failed.wait()) + done, pending = await asyncio.wait( + [api_task, failed_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + for p in pending: + p.cancel() + with contextlib.suppress(asyncio.CancelledError): + await p + if instance_failed.is_set() and api_task not in done: + logger.error(f"Instance failed, aborting {question_id}") + return + api_result = api_task.result() + except InstanceFailedError: + logger.error(f"Instance failed, skipping {question_id}") + return elapsed = time.monotonic() - t0 if api_result is None: result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response="", extracted_answer=None, @@ -729,13 +906,17 @@ async def evaluate_benchmark( "prompt_tokens": api_result.prompt_tokens, "completion_tokens": api_result.completion_tokens, "reasoning_tokens": api_result.reasoning_tokens, + "reasoning_content": api_result.reasoning_content, + "finish_reason": api_result.finish_reason, "elapsed_s": elapsed, + "power_watts": api_result.power_watts, + "energy_joules": api_result.energy_joules, } if config.kind == "mc": extracted = extract_mc_answer(response, valid_letters) result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer=extracted, @@ -749,7 +930,7 @@ async def evaluate_benchmark( check_aime_answer(extracted, int(gold)) if extracted else False ) result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer=extracted, @@ -763,7 +944,7 @@ async def evaluate_benchmark( code = extract_code_block(response, preserve_indent=keep_indent) if code is None: result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer=None, @@ -778,7 +959,7 @@ async def evaluate_benchmark( code, ) result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer="pass" if passed else "fail", @@ -793,7 +974,7 @@ async def evaluate_benchmark( exec_meta["sample"], ) result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer="pass" if passed else "fail", @@ -804,7 +985,7 @@ async def evaluate_benchmark( ) else: result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer=None, @@ -815,7 +996,7 @@ async def evaluate_benchmark( ) else: result = QuestionResult( - question_id=idx, + question_id=question_id, prompt=prompt, response=response, extracted_answer=None, @@ -827,24 +1008,82 @@ async def evaluate_benchmark( results[idx] = result + # Write checkpoint (skip infra failures so they get retried on resume, + # but keep wrong answers — they are legitimate results) + if checkpoint_path is not None and result.response: + _write_checkpoint(checkpoint_path, result) + async with lock: completed += 1 n = completed - if n % max(1, total // 20) == 0 or n == total: - correct_so_far = sum(1 for r in results if r is not None and r.correct) - answered = sum(1 for r in results if r is not None) - logger.info( - f" [{n}/{total}] {correct_so_far}/{answered} correct " - f"({correct_so_far / max(answered, 1):.1%})" - ) + + # Log progress + thinking_info = "" + if result.reasoning_content: + thinking_info = f", {len(result.reasoning_content)} chars thinking" + logger.info( + f" [{n}/{total}] {question_id}: {len(result.response)} chars{thinking_info}, " + f"tokens: {result.prompt_tokens}+{result.completion_tokens} " + f"[{result.finish_reason}]" + + (f" {result.extracted_answer}" if result.extracted_answer else "") + ) + + async def _health_monitor() -> None: + """Periodically check if the instance is still alive.""" + # Wait a bit before first check to let things start + await asyncio.sleep(10) + while not instance_failed.is_set(): + if not await _check_instance_health(base_url): + # Double-check to avoid false positives + await asyncio.sleep(2) + if not await _check_instance_health(base_url): + logger.error("Health monitor: instance is down!") + instance_failed.set() + return + await asyncio.sleep(5) async with httpx.AsyncClient() as http_client: + monitor = asyncio.create_task(_health_monitor()) tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)] await asyncio.gather(*tasks) + monitor.cancel() + with contextlib.suppress(asyncio.CancelledError): + await monitor + + if instance_failed.is_set(): + completed_count = sum(1 for r in results if r is not None) + logger.error( + f"Instance failed! Completed {completed_count}/{total} problems. " + f"Checkpoint saved — restart to resume remaining problems." + ) + raise InstanceFailedError("Instance failed during evaluation") return [r for r in results if r is not None] +def _write_checkpoint(path: Path, result: QuestionResult) -> None: + """Append a single result to the JSONL checkpoint file.""" + entry = { + "question_id": result.question_id, + "prompt": result.prompt, + "response": result.response, + "extracted_answer": result.extracted_answer, + "gold_answer": result.gold_answer, + "correct": result.correct, + "error": result.error, + "prompt_tokens": result.prompt_tokens, + "completion_tokens": result.completion_tokens, + "reasoning_tokens": result.reasoning_tokens, + "reasoning_content": result.reasoning_content, + "finish_reason": result.finish_reason, + "elapsed_s": round(result.elapsed_s, 2), + "power_watts": round(result.power_watts, 2), + "energy_joules": round(result.energy_joules, 2), + } + with open(path, "a") as f: + f.write(json.dumps(entry) + "\n") + + # --------------------------------------------------------------------------- # Results display # --------------------------------------------------------------------------- @@ -867,6 +1106,8 @@ def print_results( total_elapsed = sum(r.elapsed_s for r in results) wall_clock = max(r.elapsed_s for r in results) if results else 0.0 avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0 + total_energy = sum(r.energy_joules for r in results) + avg_power = sum(r.power_watts for r in results) / max(total, 1) label = f"[c={concurrency}] " if concurrency is not None else "" print(f"\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})") @@ -878,6 +1119,10 @@ def print_results( f" | total time: {total_elapsed:.1f}s wall clock: {wall_clock:.1f}s" ) print(tok_line) + if total_energy > 0: + print( + f" power: avg {avg_power:.1f}W | total energy: {total_energy:.1f}J ({total_energy / 3600:.2f}Wh)" + ) if errors: print(f" API errors: {errors}") if no_extract: @@ -896,6 +1141,8 @@ def print_results( "total_elapsed_s": total_elapsed, "wall_clock_s": wall_clock, "avg_gen_tps": avg_gen_tps, + "avg_power_watts": avg_power, + "total_energy_joules": total_energy, } @@ -1053,7 +1300,11 @@ def save_results( "prompt_tokens": r.prompt_tokens, "completion_tokens": r.completion_tokens, "reasoning_tokens": r.reasoning_tokens, + "reasoning_content": r.reasoning_content, + "finish_reason": r.finish_reason, "elapsed_s": round(r.elapsed_s, 2), + "power_watts": round(r.power_watts, 2), + "energy_joules": round(r.energy_joules, 2), } for r in results ], @@ -1069,6 +1320,15 @@ def save_results( # --------------------------------------------------------------------------- +def _checkpoint_path( + results_dir: str, benchmark: str, model: str, concurrency: int +) -> Path: + """Return the JSONL checkpoint path for a benchmark run.""" + out_dir = Path(results_dir) / model.replace("/", "_") / benchmark + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir / f"c{concurrency}.checkpoint.jsonl" + + def parse_int_list(values: list[str]) -> list[int]: items: list[int] = [] for v in values: @@ -1096,6 +1356,12 @@ def main() -> int: default=None, help="Max questions per benchmark (for fast iteration).", ) + ap.add_argument( + "--offset", + type=int, + default=0, + help="Skip first N questions (0-based).", + ) reasoning_group = ap.add_mutually_exclusive_group() reasoning_group.add_argument( @@ -1115,6 +1381,8 @@ def main() -> int: "--temperature", type=float, default=None, help="Override temperature." ) ap.add_argument("--top-p", type=float, default=None, help="Override top_p.") + ap.add_argument("--top-k", type=int, default=None, help="Override top_k.") + ap.add_argument("--min-p", type=float, default=None, help="Override min_p.") ap.add_argument( "--max-tokens", type=int, default=None, help="Override max output tokens." ) @@ -1148,15 +1416,31 @@ def main() -> int: choices=["easy", "medium", "hard"], help="Filter by difficulty (livecodebench only). E.g. --difficulty hard", ) + ap.add_argument( + "--release-version", + default=None, + help="LCB dataset release version (livecodebench only). E.g. release_v5", + ) ap.add_argument( "--results-dir", default="eval_results", help="Directory for result JSON files (default: eval_results).", ) ap.add_argument( - "--skip-instance-setup", + "--enable-thinking", + type=lambda v: v.lower() in ("true", "1", "yes"), + default=None, + help="Enable thinking mode for models that support it.", + ) + ap.add_argument( + "--force", action="store_true", - help="Skip exo instance management (assumes model is already running).", + help="Discard any existing checkpoint and run from scratch.", + ) + ap.add_argument( + "--keep-instance", + action="store_true", + help="Skip deleting the instance after eval (for chaining runs).", ) args, _ = ap.parse_known_args() @@ -1177,13 +1461,26 @@ def main() -> int: # Instance management client = ExoClient(args.host, args.port, timeout_s=args.timeout) instance_id: str | None = None + created_instance = False - if not args.skip_instance_setup: - short_id, full_model_id = resolve_model_short_id( - client, - args.model, - force_download=args.force_download, - ) + _short_id, full_model_id = resolve_model_short_id( + client, + args.model, + force_download=args.force_download, + ) + + # Optionally reuse a running instance for this model + if args.reuse_instance: + existing = find_existing_instance(client, full_model_id) + if existing: + instance_id = existing + logger.info(f"Reusing existing instance {instance_id}") + else: + logger.warning( + "--reuse-instance: no existing instance found, creating a new one" + ) + + if instance_id is None: selected = settle_and_fetch_placements( client, full_model_id, @@ -1198,7 +1495,7 @@ def main() -> int: key=lambda p: ( str(p.get("instance_meta", "")), str(p.get("sharding", "")), - -nodes_used_in_instance(p["instance"]), + nodes_used_in_instance(p["instance"]), ), reverse=True, ) @@ -1225,6 +1522,18 @@ def main() -> int: if download_duration is not None: logger.info(f"Download: {download_duration:.1f}s") + # Delete any existing instances to free resources before placing + try: + state = client.request_json("GET", "/state") + for old_id in list(state.get("instances", {}).keys()): + logger.info(f"Deleting stale instance {old_id}") + with contextlib.suppress(ExoHttpError): + client.request_json("DELETE", f"/instance/{old_id}") + if state.get("instances"): + time.sleep(2) + except Exception as e: + logger.warning(f"Failed to clean up stale instances: {e}") + client.request_json("POST", "/instance", body={"instance": instance}) try: wait_for_instance_ready(client, instance_id) @@ -1234,10 +1543,9 @@ def main() -> int: client.request_json("DELETE", f"/instance/{instance_id}") return 1 time.sleep(1) - cluster_snapshot = capture_cluster_snapshot(client) - else: - full_model_id = args.model - cluster_snapshot = None + created_instance = True + + cluster_snapshot = capture_cluster_snapshot(client) # Auto-detect reasoning from model config model_config = load_model_config(full_model_id) @@ -1291,16 +1599,57 @@ def main() -> int: reasoning_effort = str(cfg["reasoning_effort"]) else: reasoning_effort = "high" if is_reasoning else None + + if args.top_k is not None: + top_k: int | None = args.top_k + elif "top_k" in cfg: + top_k = int(cfg["top_k"]) + else: + top_k = None + + if args.min_p is not None: + min_p: float | None = args.min_p + elif "min_p" in cfg: + min_p = float(cfg["min_p"]) + else: + min_p = None + + if args.enable_thinking is not None: + enable_thinking: bool | None = args.enable_thinking + elif "enable_thinking" in cfg: + enable_thinking = bool(cfg["enable_thinking"]) + else: + enable_thinking = None + base_url = f"http://{args.host}:{args.port}" logger.info(f"Model: {full_model_id}") logger.info( f"Settings: temperature={temperature}, max_tokens={max_tokens}, " + (f"top_p={top_p}, " if top_p is not None else "") + + (f"top_k={top_k}, " if top_k is not None else "") + + (f"min_p={min_p}, " if min_p is not None else "") + f"reasoning={'yes' if is_reasoning else 'no'}" + (f", reasoning_effort={reasoning_effort}" if reasoning_effort else "") + + ( + f", enable_thinking={enable_thinking}" + if enable_thinking is not None + else "" + ) ) + # Common kwargs for evaluate_benchmark + eval_kwargs: dict[str, Any] = { + "reasoning_effort": reasoning_effort, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "enable_thinking": enable_thinking, + "difficulty": args.difficulty, + "offset": args.offset, + "release_version": args.release_version, + } + try: if args.compare_concurrency: concurrency_levels = parse_int_list(args.compare_concurrency) @@ -1309,6 +1658,11 @@ def main() -> int: for c in concurrency_levels: logger.info(f"\n{'=' * 50}") logger.info(f"Running {task_name} at concurrency={c}") + checkpoint_path = _checkpoint_path( + args.results_dir, task_name, full_model_id, c + ) + if args.force and checkpoint_path.exists(): + checkpoint_path.unlink() results = asyncio.run( evaluate_benchmark( task_name, @@ -1319,9 +1673,8 @@ def main() -> int: concurrency=c, limit=args.limit, timeout=args.request_timeout, - reasoning_effort=reasoning_effort, - top_p=top_p, - difficulty=args.difficulty, + checkpoint_path=checkpoint_path, + **eval_kwargs, ) ) if results: @@ -1336,10 +1689,18 @@ def main() -> int: cluster=cluster_snapshot, ) results_by_c[c] = results + # Clean up checkpoint on success + if checkpoint_path.exists(): + checkpoint_path.unlink() if len(results_by_c) >= 2: print_comparison(task_name, results_by_c) else: for task_name in task_names: + checkpoint_path = _checkpoint_path( + args.results_dir, task_name, full_model_id, args.num_concurrent + ) + if args.force and checkpoint_path.exists(): + checkpoint_path.unlink() results = asyncio.run( evaluate_benchmark( task_name, @@ -1350,9 +1711,8 @@ def main() -> int: concurrency=args.num_concurrent, limit=args.limit, timeout=args.request_timeout, - reasoning_effort=reasoning_effort, - top_p=top_p, - difficulty=args.difficulty, + checkpoint_path=checkpoint_path, + **eval_kwargs, ) ) if results: @@ -1366,14 +1726,25 @@ def main() -> int: scores, cluster=cluster_snapshot, ) + # Clean up checkpoint on success + if checkpoint_path.exists(): + checkpoint_path.unlink() finally: - if instance_id is not None: - try: - client.request_json("DELETE", f"/instance/{instance_id}") - except ExoHttpError as e: - if e.status != 404: - raise - wait_for_instance_gone(client, instance_id) + if created_instance and instance_id is not None: + if args.keep_instance: + logger.info(f"Keeping instance {instance_id} (--keep-instance)") + else: + try: + client.request_json("DELETE", f"/instance/{instance_id}") + except ExoHttpError as e: + if e.status != 404: + raise + try: + wait_for_instance_gone(client, instance_id) + except TimeoutError: + logger.warning( + f"Timed out waiting for instance {instance_id} to be deleted" + ) return 0 diff --git a/bench/harness.py b/bench/harness.py index 9c31fe215..3285becf1 100644 --- a/bench/harness.py +++ b/bench/harness.py @@ -6,6 +6,7 @@ import http.client import json import os import time +from collections.abc import Iterator from typing import Any from urllib.parse import urlencode @@ -69,6 +70,30 @@ class ExoClient: def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]: return self.request_json("POST", "/bench/chat/completions", body=payload) + def stream_bench_chat_completions(self, payload: dict[str, Any]) -> Iterator[str]: + """POST /bench/chat/completions with stream=True, yielding raw SSE lines.""" + payload = {**payload, "stream": True} + data = json.dumps(payload).encode("utf-8") + conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s) + try: + conn.request( + "POST", + "/bench/chat/completions", + body=data, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + }, + ) + resp = conn.getresponse() + if resp.status >= 400: + raw = resp.read().decode("utf-8", errors="replace") + raise ExoHttpError(resp.status, resp.reason, raw[:300]) + for line in resp: + yield line.decode("utf-8", errors="replace") + finally: + conn.close() + def get_state_path(self, path: str) -> Any: try: return self.request_json("GET", f"/state/{path}") @@ -462,9 +487,8 @@ def run_planning_phase( ) logger.info(f"Started download on {node_id}") - # Wait for downloads - start = time.time() - while time.time() - start < timeout: + # Wait for downloads (no timeout — poll until complete or failed) + while True: all_done = True for node_id in node_ids: node_downloads = client.get_node_downloads(node_id) or [] @@ -514,9 +538,24 @@ def run_planning_phase( if download_t0 is not None: return time.perf_counter() - download_t0 return None - time.sleep(1) + time.sleep(10) - raise TimeoutError("Downloads did not complete in time") + +def find_existing_instance(client: ExoClient, model_id: str) -> str | None: + """Find an existing running instance for the given model.""" + try: + state = client.request_json("GET", "/state") + except Exception: + return None + for inst_id, inst in state.get("instances", {}).items(): + # Instance structure is nested: {"MlxJacclInstance": {"shardAssignments": {"modelId": ...}}} + for _inst_type, inner in inst.items(): + if not isinstance(inner, dict): + continue + sa = inner.get("shardAssignments", {}) + if sa.get("modelId") == model_id: + return inst_id + return None def add_common_instance_args(ap: argparse.ArgumentParser) -> None: @@ -572,3 +611,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None: action="store_true", help="Delete existing models from smallest to largest to make room for benchmark model.", ) + ap.add_argument( + "--reuse-instance", + action="store_true", + help="Reuse an existing running instance for this model instead of creating a new one.", + ) diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 0b96e924a..565154ea3 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -336,7 +336,9 @@ class API: self.app.post("/v1/chat/completions", response_model=None)( self.chat_completions ) - self.app.post("/bench/chat/completions")(self.bench_chat_completions) + self.app.post("/bench/chat/completions", response_model=None)( + self.bench_chat_completions + ) self.app.post("/v1/images/generations", response_model=None)( self.image_generations ) @@ -829,7 +831,7 @@ class API: async def bench_chat_completions( self, payload: BenchChatCompletionRequest - ) -> BenchChatCompletionResponse: + ) -> BenchChatCompletionResponse | StreamingResponse: task_params = await chat_request_to_text_generation(payload) resolved_model = await self._resolve_and_validate_text_model( ModelId(task_params.model) @@ -846,6 +848,22 @@ class API: command = await self._send_text_generation_with_images(task_params) + if payload.stream: + return StreamingResponse( + with_sse_keepalive( + generate_chat_stream( + command.command_id, + self._token_chunk_stream(command.command_id), + ), + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "close", + "X-Accel-Buffering": "no", + }, + ) + return await self._collect_text_generation_with_stats(command.command_id) async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId: