Extend bench/eval tooling (#1905)

## Motivation

Extend bench/eval tooling with robustness features, streaming support,
and align model configs with vllm eval for reproducible comparisons.

## Changes

- **exo_eval**: Checkpoint/resume (JSONL), instance health monitoring +
early abort, `top_k`/`min_p`/`enable_thinking` params, LCB
`--release-version`/`--offset`
- **exo_bench**: Streaming SSE (`--stream`), Kimi tokenizer fix for
transformers 5.x
- **Both tools**: Auto-detect running instances instead of requiring
`--skip-instance-setup`; `--fresh-instance` to override
- **harness**: SSE streaming client, `find_existing_instance()` shared
helper, removed download timeout, settle-timeout default 0→7200s
- **models.toml**: Added `enable_thinking`, aligned `max_tokens`/temps
with vllm, added new models
- **API**: Streaming SSE for `/bench/chat/completions`

## Why It Works

- Checkpoint/resume uses append-only JSONL + skip-on-load so interrupted
evals resume without re-running completed questions
- Health monitoring races an `asyncio.Event` against API calls for fast
abort when the instance dies
- Auto-detection queries `/state` for existing instances matching the
model ID before attempting placement
- Streaming reuses the existing `generate_chat_stream` infrastructure
from the regular chat endpoint
This commit is contained in:
ciaranbor
2026-04-27 16:53:43 +01:00
committed by GitHub
parent 37f6f4f6c2
commit f2a0db4e23
5 changed files with 762 additions and 181 deletions

View File

@@ -7,7 +7,7 @@
# name, patterns, reasoning
#
# Optional per-model overrides (CLI flags take priority over these):
# temperature, top_p, max_tokens, reasoning_effort
# temperature, top_p, max_tokens, reasoning_effort, enable_thinking
#
# Fallback defaults (when no per-model config):
# reasoning: temperature=1.0, max_tokens=131072, reasoning_effort="high"
@@ -18,10 +18,9 @@
# ─── Qwen3.5 (Feb 2026) ─────────────────────────────────────────────
# Source: HuggingFace model cards (Qwen/Qwen3.5-*)
# 35B-A3B thinking general: temp=1.0, top_p=0.95, top_k=20
# 397B thinking: temp=0.6, top_p=0.95, top_k=20
# Non-thinking: temp=0.7, top_p=0.8, top_k=20
# max_tokens: 32768 general, 81920 for complex math/code
# Model card recommends: temp=0.6, top_p=0.95, top_k=20
# We omit top_k to match vllm eval (which doesn't set it).
# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin).
[[model]]
name = "Qwen3.5 2B"
@@ -29,7 +28,8 @@ patterns = ["Qwen3.5-2B"]
reasoning = true
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Qwen3.5 9B"
@@ -37,7 +37,8 @@ patterns = ["Qwen3.5-9B"]
reasoning = true
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Qwen3.5 27B"
@@ -45,15 +46,17 @@ patterns = ["Qwen3.5-27B"]
reasoning = true
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Qwen3.5 35B A3B"
patterns = ["Qwen3.5-35B-A3B"]
reasoning = true
temperature = 1.0
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Qwen3.5 122B A10B"
@@ -61,7 +64,8 @@ patterns = ["Qwen3.5-122B-A10B"]
reasoning = true
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Qwen3.5 397B A17B"
@@ -69,12 +73,14 @@ patterns = ["Qwen3.5-397B-A17B"]
reasoning = true
temperature = 0.6
top_p = 0.95
max_tokens = 81920
enable_thinking = true
max_tokens = 121072
# ─── Qwen3 (Apr 2025) ───────────────────────────────────────────────
# Source: HuggingFace model cards (Qwen/Qwen3-*)
# Thinking: temp=0.6, top_p=0.95, top_k=20
# Non-thinking: temp=0.7, top_p=0.8, top_k=20
# Model card recommends: temp=0.6, top_p=0.95, top_k=20
# We omit top_k to match vllm eval (which doesn't set it).
# Non-thinking: temp=0.7, top_p=0.8
# max_tokens: 32768 general, 38912 for complex math/code
[[model]]
@@ -83,6 +89,7 @@ patterns = ["Qwen3-0.6B"]
reasoning = true
temperature = 0.6
top_p = 0.95
enable_thinking = true
max_tokens = 38912
[[model]]
@@ -91,6 +98,7 @@ patterns = ["Qwen3-30B-A3B"]
reasoning = true
temperature = 0.6
top_p = 0.95
enable_thinking = true
max_tokens = 38912
[[model]]
@@ -99,6 +107,7 @@ patterns = ["Qwen3-235B-A22B"]
reasoning = true
temperature = 0.6
top_p = 0.95
enable_thinking = true
max_tokens = 38912
[[model]]
@@ -107,6 +116,7 @@ patterns = ["Qwen3-Next-80B-A3B-Thinking"]
reasoning = true
temperature = 0.6
top_p = 0.95
enable_thinking = true
max_tokens = 38912
[[model]]
@@ -129,9 +139,9 @@ max_tokens = 16384
name = "Qwen3 Coder Next"
patterns = ["Qwen3-Coder-Next"]
reasoning = false
temperature = 0.7
top_p = 0.8
max_tokens = 16384
temperature = 1.0
top_p = 0.95
max_tokens = 121072
# ─── GPT-OSS (OpenAI) ───────────────────────────────────────────────
# Source: OpenAI GitHub README + HuggingFace discussion #21
@@ -165,10 +175,38 @@ patterns = ["DeepSeek-V3.1"]
reasoning = true
temperature = 0.0
[[model]]
name = "DeepSeek V3.2"
patterns = ["DeepSeek-V3.2"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
# ─── NVIDIA Nemotron ───────────────────────────────────────────────────
# Source: HuggingFace model cards
# All variants: temp=1.0, top_p=0.95, enable_thinking=true
[[model]]
name = "Nemotron Cascade 2 30B A3B"
patterns = ["Nemotron-Cascade-2-30B-A3B"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
[[model]]
name = "Nemotron 3 Super 120B A12B"
patterns = ["Nemotron-3-Super-120B-A12B", "NVIDIA-Nemotron-3-Super-120B-A12B"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
# ─── GLM (ZhipuAI / THUDM) ──────────────────────────────────────────
# Source: HuggingFace model cards + generation_config.json + docs.z.ai
# GLM 4.5+: temp=1.0, top_p=0.95
# Reasoning tasks: 131072 max_tokens; coding/SWE tasks: temp=0.7
# max_tokens=121072 to match vllm eval (131072 context - 10000 safety margin)
[[model]]
name = "GLM-5"
@@ -176,7 +214,8 @@ patterns = ["GLM-5"]
reasoning = true
temperature = 1.0
top_p = 0.95
max_tokens = 131072
enable_thinking = true
max_tokens = 121072
[[model]]
name = "GLM 4.5 Air"
@@ -191,7 +230,8 @@ patterns = ["GLM-4.7-"]
reasoning = true
temperature = 1.0
top_p = 0.95
max_tokens = 131072
enable_thinking = true
max_tokens = 121072
# Note: matches both GLM-4.7 and GLM-4.7-Flash
# ─── Kimi (Moonshot AI) ─────────────────────────────────────────────
@@ -213,7 +253,8 @@ patterns = ["Kimi-K2.5"]
reasoning = true
temperature = 1.0
top_p = 0.95
max_tokens = 131072
enable_thinking = true
max_tokens = 121072
[[model]]
name = "Kimi K2 Instruct"
@@ -223,7 +264,17 @@ temperature = 0.6
# ─── MiniMax ─────────────────────────────────────────────────────────
# Source: HuggingFace model cards + generation_config.json
# All models: temp=1.0, top_p=0.95, top_k=40
# All models: temp=1.0, top_p=0.95
# max_tokens=90000 to match vllm eval (100000 context - 10000 safety margin)
[[model]]
name = "MiniMax M2.7"
patterns = ["MiniMax-M2.7"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
max_tokens = 90000
[[model]]
name = "MiniMax M2.5"
@@ -231,6 +282,8 @@ patterns = ["MiniMax-M2.5"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
max_tokens = 90000
[[model]]
name = "MiniMax M2.1"
@@ -251,6 +304,8 @@ patterns = ["Step-3.5-Flash"]
reasoning = true
temperature = 1.0
top_p = 0.95
enable_thinking = true
max_tokens = 121072
# ─── Llama (Meta) ───────────────────────────────────────────────────
# Source: generation_config.json + meta-llama/llama-models generation.py

View File

@@ -35,6 +35,7 @@ from harness import (
ExoHttpError,
add_common_instance_args,
capture_cluster_snapshot,
find_existing_instance,
instance_id_from_instance,
node_ids_from_instance,
nodes_used_in_instance,
@@ -79,7 +80,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model", "*.jinja"],
)
)
@@ -277,28 +278,72 @@ def run_one_completion(
prompt_sizer: PromptSizer,
*,
use_prefix_cache: bool = False,
stream: bool = False,
) -> tuple[dict[str, Any], int]:
content, pp_tokens = prompt_sizer.build(pp_hint)
payload: dict[str, Any] = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": tg,
"logprobs": False,
"use_prefix_cache": use_prefix_cache,
}
t0 = time.perf_counter()
out = client.post_bench_chat_completions(payload)
elapsed = time.perf_counter() - t0
if not stream:
payload["stream"] = False
t0 = time.perf_counter()
out = client.post_bench_chat_completions(payload)
elapsed = time.perf_counter() - t0
stats = out.get("generation_stats")
stats = out.get("generation_stats")
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
else:
tokens = 0
first_token_time = None
t0 = time.perf_counter()
text_parts: list[str] = []
stats = None
# Extract preview, handling None content (common for thinking models)
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
for raw_line in client.stream_bench_chat_completions(payload):
line = raw_line.strip()
if line.startswith(": generation_stats "):
with contextlib.suppress(json.JSONDecodeError):
stats = json.loads(line[len(": generation_stats ") :])
continue
if not line.startswith("data: "):
continue
data = line[6:]
if data == "[DONE]":
break
try:
chunk = json.loads(data)
delta = chunk.get("choices", [{}])[0].get("delta", {})
if delta.get("content"):
if first_token_time is None:
first_token_time = time.perf_counter()
tokens += 1
text_parts.append(delta["content"])
except json.JSONDecodeError:
pass
elapsed = time.perf_counter() - t0
preview = "".join(text_parts)[:200]
if not stats:
ttft = (first_token_time - t0) if first_token_time else elapsed
gen_time = elapsed - ttft if tokens > 1 else elapsed
gen_tps = (tokens - 1) / gen_time if tokens > 1 and gen_time > 0 else 0.0
prompt_tps = pp_tokens / ttft if ttft > 0 else 0.0
stats = {
"prompt_tokens": pp_tokens,
"generation_tokens": tokens,
"prompt_tps": round(prompt_tps, 2),
"generation_tps": round(gen_tps, 2),
"peak_memory_usage": {"inBytes": 0},
}
return {
"elapsed_s": elapsed,
@@ -425,6 +470,11 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--stream",
action="store_true",
help="Use /bench/chat/completions with streaming SSE response (bench=True still applies: no EOS detection, no KV cache).",
)
ap.add_argument(
"--no-system-metrics",
action="store_true",
@@ -490,81 +540,124 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
# Optionally reuse a running instance for this model
reused_instance_id: str | None = None
if args.reuse_instance:
existing = find_existing_instance(client, full_model_id)
if existing:
reused_instance_id = existing
logger.info(f"Reusing existing instance {reused_instance_id}")
else:
logger.warning(
"--reuse-instance: no existing instance found, creating a new one"
)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
-nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
if reused_instance_id is not None:
# Use the existing instance directly — skip placement iteration
selected = []
download_duration_s = None
else:
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
if args.dry_run:
return 0
if not selected:
logger.error("No valid placements matched your filters.")
return 1
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.info("Planning phase: checking downloads...")
download_duration_s = run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
if download_duration_s is not None:
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
else:
logger.info("Download: model already cached")
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
)
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
download_duration_s = run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
if download_duration_s is not None:
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
else:
logger.info("Download: model already cached")
cluster_snapshot = capture_cluster_snapshot(client)
all_rows: list[dict[str, Any]] = []
all_system_metrics: dict[str, dict[str, dict[str, float]]] = {}
# If reusing an existing instance, run a single benchmark pass against it
if reused_instance_id is not None:
selected = [None]
for preview in selected:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
created_instance = False
if preview is not None:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
# Delete any existing instances to free resources before placing
try:
state = client.request_json("GET", "/state")
for old_id in list(state.get("instances", {}).keys()):
logger.info(f"Deleting stale instance {old_id}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{old_id}")
if state.get("instances"):
time.sleep(2)
except Exception as e:
logger.warning(f"Failed to clean up stale instances: {e}")
time.sleep(1)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)
created_instance = True
else:
instance_id = reused_instance_id
sharding = "reused"
instance_meta = "reused"
n_nodes = 0
logger.info("=" * 80)
logger.info(f"Using existing instance {instance_id}")
sampler: SystemMetricsSampler | None = None
if not args.no_system_metrics:
if not args.no_system_metrics and preview is not None:
nids = node_ids_from_instance(instance)
sampler = SystemMetricsSampler(
ExoClient(args.host, args.port, timeout_s=30),
@@ -573,16 +666,20 @@ def main() -> int:
)
sampler.start()
def _do_one(c: ExoClient, pp: int, tg: int) -> tuple[dict[str, Any], int]:
return run_one_completion(
c,
full_model_id,
pp,
tg,
prompt_sizer,
use_prefix_cache=args.use_prefix_cache,
stream=args.stream,
)
try:
for i in range(args.warmup):
run_one_completion(
client,
full_model_id,
pp_list[0],
tg_list[0],
prompt_sizer,
use_prefix_cache=args.use_prefix_cache,
)
_do_one(client, pp_list[0], tg_list[0])
logger.debug(f" warmup {i + 1}/{args.warmup} done")
# If pp and tg lists have same length, run in tandem (zip)
@@ -604,14 +701,7 @@ def main() -> int:
# Sequential: single request
try:
inf_t0 = time.monotonic()
row, actual_pp_tokens = run_one_completion(
client,
full_model_id,
pp,
tg,
prompt_sizer,
use_prefix_cache=args.use_prefix_cache,
)
row, actual_pp_tokens = _do_one(client, pp, tg)
inference_windows.append((inf_t0, time.monotonic()))
except Exception as e:
logger.error(e)
@@ -760,10 +850,12 @@ def main() -> int:
gen_tps = per_req_tps * concurrency
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
def _peak_bytes(s: dict[str, Any]) -> float:
pm = s["peak_memory_usage"]
return pm.get("inBytes") or pm.get("in_bytes", 0)
peak = mean(_peak_bytes(x["stats"]) for x in runs)
summary = (
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
@@ -788,15 +880,16 @@ def main() -> int:
if placement_metrics:
all_system_metrics.update(placement_metrics)
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
if created_instance and instance_id is not None:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
time.sleep(5)
time.sleep(5)
output: dict[str, Any] = {"runs": all_rows}
if cluster_snapshot:

View File

@@ -47,6 +47,7 @@ from harness import (
ExoHttpError,
add_common_instance_args,
capture_cluster_snapshot,
find_existing_instance,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
@@ -62,6 +63,15 @@ from loguru import logger
# ---------------------------------------------------------------------------
MAX_RETRIES = 30
INSTANCE_HEALTH_CHECK_AFTER = (
3 # Check instance health after this many consecutive failures
)
class InstanceFailedError(RuntimeError):
"""Raised when the exo instance is detected as failed/gone."""
DEFAULT_MAX_TOKENS = 16_384
REASONING_MAX_TOKENS = 131_072
TEMPERATURE_NON_REASONING = 0.0
@@ -271,7 +281,7 @@ def run_humaneval_test(
@dataclass
class QuestionResult:
question_id: int
question_id: int | str
prompt: str
response: str
extracted_answer: str | None
@@ -281,7 +291,11 @@ class QuestionResult:
prompt_tokens: int = 0
completion_tokens: int = 0
reasoning_tokens: int = 0
reasoning_content: str = ""
finish_reason: str = ""
elapsed_s: float = 0.0
power_watts: float = 0.0
energy_joules: float = 0.0
@dataclass
@@ -517,6 +531,10 @@ class ApiResult:
prompt_tokens: int
completion_tokens: int
reasoning_tokens: int
reasoning_content: str = ""
finish_reason: str = ""
power_watts: float = 0.0
energy_joules: float = 0.0
async def _call_api(
@@ -530,6 +548,9 @@ async def _call_api(
system_message: str | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
) -> ApiResult:
messages = []
if system_message:
@@ -546,6 +567,12 @@ async def _call_api(
body["reasoning_effort"] = reasoning_effort
if top_p is not None:
body["top_p"] = top_p
if top_k is not None:
body["top_k"] = top_k
if min_p is not None:
body["min_p"] = min_p
if enable_thinking is not None:
body["enable_thinking"] = enable_thinking
resp = await client.post(
f"{base_url}/v1/chat/completions",
@@ -554,19 +581,40 @@ async def _call_api(
)
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"]
if not content or not content.strip():
choice = data["choices"][0]
message = choice["message"]
content = message.get("content") or ""
reasoning_content = message.get("reasoning_content") or ""
finish_reason = choice.get("finish_reason") or ""
# For thinking models, empty content is expected when finish_reason is "length"
if not content.strip() and finish_reason != "length" and not reasoning_content:
raise ValueError("Empty response from model")
usage = data.get("usage", {})
details = usage.get("completion_tokens_details", {})
power = data.get("power_usage") or {}
return ApiResult(
content=content,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
reasoning_tokens=details.get("reasoning_tokens", 0) if details else 0,
reasoning_content=reasoning_content,
finish_reason=finish_reason,
power_watts=power.get("total_avg_sys_power_watts", 0.0),
energy_joules=power.get("total_energy_joules", 0.0),
)
async def _check_instance_health(base_url: str) -> bool:
"""Return True if the exo instance is still reachable."""
try:
async with httpx.AsyncClient() as c:
resp = await c.get(f"{base_url}/models", timeout=5.0)
return resp.status_code == 200
except Exception:
return False
async def call_with_retries(
client: httpx.AsyncClient,
base_url: str,
@@ -578,8 +626,14 @@ async def call_with_retries(
system_message: str | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
instance_failed: asyncio.Event | None = None,
) -> ApiResult | None:
for attempt in range(MAX_RETRIES):
if instance_failed and instance_failed.is_set():
raise InstanceFailedError("Instance already marked as failed")
try:
return await _call_api(
client,
@@ -592,8 +646,30 @@ async def call_with_retries(
system_message,
reasoning_effort,
top_p,
top_k,
min_p,
enable_thinking,
)
except Exception as e:
is_conn_error = isinstance(
e,
(
httpx.ConnectError,
httpx.RemoteProtocolError,
ConnectionRefusedError,
OSError,
),
)
if (
is_conn_error
and attempt >= INSTANCE_HEALTH_CHECK_AFTER
and not await _check_instance_health(base_url)
):
if instance_failed:
instance_failed.set()
raise InstanceFailedError(
f"Instance is down after {attempt + 1} failures: {e}"
) from e
if attempt < MAX_RETRIES - 1:
wait = min(2**attempt, 60)
logger.warning(
@@ -618,10 +694,16 @@ async def evaluate_benchmark(
max_tokens: int,
concurrency: int = 1,
limit: int | None = None,
offset: int = 0,
timeout: float | None = None,
reasoning_effort: str | None = None,
top_p: float | None = None,
top_k: int | None = None,
min_p: float | None = None,
enable_thinking: bool | None = None,
difficulty: str | None = None,
checkpoint_path: Path | None = None,
release_version: str | None = None,
) -> list[QuestionResult]:
"""Run a benchmark. Returns per-question results."""
import datasets
@@ -652,7 +734,21 @@ async def evaluate_benchmark(
ds = ds.filter(lambda x: x["difficulty"] == difficulty)
logger.info(f"Filtered to {len(ds)} {difficulty} problems")
if release_version and "release_version" in ds.column_names:
ds = ds.filter(lambda x: x["release_version"] == release_version)
logger.info(
f"Filtered to {len(ds)} problems with release_version={release_version}"
)
# Sort by question_id to match LCB runner ordering (scenario_router.py:60).
# This ensures [offset:offset+limit] slices select the same problems as vllm.
if "question_id" in ds.column_names:
ds = ds.sort("question_id")
total = len(ds)
if offset > 0:
ds = ds.select(range(min(offset, total), total))
total = len(ds)
if limit and limit < total:
ds = ds.select(range(limit))
total = limit
@@ -660,6 +756,13 @@ async def evaluate_benchmark(
logger.info(
f"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, "
f"temperature={temperature}, max_tokens={max_tokens}"
+ (f", top_k={top_k}" if top_k is not None else "")
+ (f", min_p={min_p}" if min_p is not None else "")
+ (
f", enable_thinking={enable_thinking}"
if enable_thinking is not None
else ""
)
)
if config.kind == "code":
@@ -667,16 +770,64 @@ async def evaluate_benchmark(
"Code benchmarks execute model-generated code. Use a sandboxed environment."
)
# Load checkpoint for resume
checkpoint_data: dict[str | int, dict[str, Any]] = {}
if checkpoint_path and checkpoint_path.exists():
with open(checkpoint_path) as f:
for line in f:
entry = json.loads(line)
checkpoint_data[entry["question_id"]] = entry
logger.info(f"Loaded {len(checkpoint_data)} checkpointed results")
semaphore = asyncio.Semaphore(concurrency)
instance_failed = asyncio.Event()
results: list[QuestionResult | None] = [None] * total
completed = 0
lock = asyncio.Lock()
def _get_question_id(idx: int, doc: dict) -> str | int:
"""Get a stable question ID for checkpointing."""
if benchmark_name == "livecodebench":
return doc.get("question_id", idx)
elif benchmark_name == "humaneval":
return doc.get("task_id", idx)
return idx
async def process_question(
idx: int, doc: dict, http_client: httpx.AsyncClient
) -> None:
nonlocal completed
system_msg = None
question_id = _get_question_id(idx, doc)
# Bail out early if instance is already dead
if instance_failed.is_set():
return
# Check checkpoint
if question_id in checkpoint_data:
cached = checkpoint_data[question_id]
results[idx] = QuestionResult(
question_id=question_id,
prompt=cached.get("prompt", ""),
response=cached.get("response", ""),
extracted_answer=cached.get("extracted_answer"),
gold_answer=cached.get("gold_answer", ""),
correct=cached.get("correct", False),
error=cached.get("error"),
prompt_tokens=cached.get("prompt_tokens", 0),
completion_tokens=cached.get("completion_tokens", 0),
reasoning_tokens=cached.get("reasoning_tokens", 0),
reasoning_content=cached.get("reasoning_content", ""),
finish_reason=cached.get("finish_reason", ""),
elapsed_s=cached.get("elapsed_s", 0.0),
power_watts=cached.get("power_watts", 0.0),
energy_joules=cached.get("energy_joules", 0.0),
)
async with lock:
completed += 1
logger.info(f" [{completed}/{total}] {question_id} (cached)")
return
if benchmark_name == "gpqa_diamond":
prompt, gold = format_gpqa_question(doc, idx)
@@ -697,24 +848,50 @@ async def evaluate_benchmark(
raise ValueError(f"Unknown benchmark: {benchmark_name}")
async with semaphore:
if instance_failed.is_set():
return
t0 = time.monotonic()
api_result = await call_with_retries(
http_client,
base_url,
model,
prompt,
temperature,
max_tokens,
timeout,
system_message=system_msg,
reasoning_effort=reasoning_effort,
top_p=top_p,
)
try:
# Race the API call against the instance_failed event
api_task = asyncio.create_task(
call_with_retries(
http_client,
base_url,
model,
prompt,
temperature,
max_tokens,
timeout,
system_message=system_msg,
reasoning_effort=reasoning_effort,
top_p=top_p,
top_k=top_k,
min_p=min_p,
enable_thinking=enable_thinking,
instance_failed=instance_failed,
)
)
failed_waiter = asyncio.create_task(instance_failed.wait())
done, pending = await asyncio.wait(
[api_task, failed_waiter],
return_when=asyncio.FIRST_COMPLETED,
)
for p in pending:
p.cancel()
with contextlib.suppress(asyncio.CancelledError):
await p
if instance_failed.is_set() and api_task not in done:
logger.error(f"Instance failed, aborting {question_id}")
return
api_result = api_task.result()
except InstanceFailedError:
logger.error(f"Instance failed, skipping {question_id}")
return
elapsed = time.monotonic() - t0
if api_result is None:
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response="",
extracted_answer=None,
@@ -729,13 +906,17 @@ async def evaluate_benchmark(
"prompt_tokens": api_result.prompt_tokens,
"completion_tokens": api_result.completion_tokens,
"reasoning_tokens": api_result.reasoning_tokens,
"reasoning_content": api_result.reasoning_content,
"finish_reason": api_result.finish_reason,
"elapsed_s": elapsed,
"power_watts": api_result.power_watts,
"energy_joules": api_result.energy_joules,
}
if config.kind == "mc":
extracted = extract_mc_answer(response, valid_letters)
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=extracted,
@@ -749,7 +930,7 @@ async def evaluate_benchmark(
check_aime_answer(extracted, int(gold)) if extracted else False
)
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=extracted,
@@ -763,7 +944,7 @@ async def evaluate_benchmark(
code = extract_code_block(response, preserve_indent=keep_indent)
if code is None:
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
@@ -778,7 +959,7 @@ async def evaluate_benchmark(
code,
)
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer="pass" if passed else "fail",
@@ -793,7 +974,7 @@ async def evaluate_benchmark(
exec_meta["sample"],
)
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer="pass" if passed else "fail",
@@ -804,7 +985,7 @@ async def evaluate_benchmark(
)
else:
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
@@ -815,7 +996,7 @@ async def evaluate_benchmark(
)
else:
result = QuestionResult(
question_id=idx,
question_id=question_id,
prompt=prompt,
response=response,
extracted_answer=None,
@@ -827,24 +1008,82 @@ async def evaluate_benchmark(
results[idx] = result
# Write checkpoint (skip infra failures so they get retried on resume,
# but keep wrong answers — they are legitimate results)
if checkpoint_path is not None and result.response:
_write_checkpoint(checkpoint_path, result)
async with lock:
completed += 1
n = completed
if n % max(1, total // 20) == 0 or n == total:
correct_so_far = sum(1 for r in results if r is not None and r.correct)
answered = sum(1 for r in results if r is not None)
logger.info(
f" [{n}/{total}] {correct_so_far}/{answered} correct "
f"({correct_so_far / max(answered, 1):.1%})"
)
# Log progress
thinking_info = ""
if result.reasoning_content:
thinking_info = f", {len(result.reasoning_content)} chars thinking"
logger.info(
f" [{n}/{total}] {question_id}: {len(result.response)} chars{thinking_info}, "
f"tokens: {result.prompt_tokens}+{result.completion_tokens} "
f"[{result.finish_reason}]"
+ (f" {result.extracted_answer}" if result.extracted_answer else "")
)
async def _health_monitor() -> None:
"""Periodically check if the instance is still alive."""
# Wait a bit before first check to let things start
await asyncio.sleep(10)
while not instance_failed.is_set():
if not await _check_instance_health(base_url):
# Double-check to avoid false positives
await asyncio.sleep(2)
if not await _check_instance_health(base_url):
logger.error("Health monitor: instance is down!")
instance_failed.set()
return
await asyncio.sleep(5)
async with httpx.AsyncClient() as http_client:
monitor = asyncio.create_task(_health_monitor())
tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)]
await asyncio.gather(*tasks)
monitor.cancel()
with contextlib.suppress(asyncio.CancelledError):
await monitor
if instance_failed.is_set():
completed_count = sum(1 for r in results if r is not None)
logger.error(
f"Instance failed! Completed {completed_count}/{total} problems. "
f"Checkpoint saved — restart to resume remaining problems."
)
raise InstanceFailedError("Instance failed during evaluation")
return [r for r in results if r is not None]
def _write_checkpoint(path: Path, result: QuestionResult) -> None:
"""Append a single result to the JSONL checkpoint file."""
entry = {
"question_id": result.question_id,
"prompt": result.prompt,
"response": result.response,
"extracted_answer": result.extracted_answer,
"gold_answer": result.gold_answer,
"correct": result.correct,
"error": result.error,
"prompt_tokens": result.prompt_tokens,
"completion_tokens": result.completion_tokens,
"reasoning_tokens": result.reasoning_tokens,
"reasoning_content": result.reasoning_content,
"finish_reason": result.finish_reason,
"elapsed_s": round(result.elapsed_s, 2),
"power_watts": round(result.power_watts, 2),
"energy_joules": round(result.energy_joules, 2),
}
with open(path, "a") as f:
f.write(json.dumps(entry) + "\n")
# ---------------------------------------------------------------------------
# Results display
# ---------------------------------------------------------------------------
@@ -867,6 +1106,8 @@ def print_results(
total_elapsed = sum(r.elapsed_s for r in results)
wall_clock = max(r.elapsed_s for r in results) if results else 0.0
avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0
total_energy = sum(r.energy_joules for r in results)
avg_power = sum(r.power_watts for r in results) / max(total, 1)
label = f"[c={concurrency}] " if concurrency is not None else ""
print(f"\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})")
@@ -878,6 +1119,10 @@ def print_results(
f" | total time: {total_elapsed:.1f}s wall clock: {wall_clock:.1f}s"
)
print(tok_line)
if total_energy > 0:
print(
f" power: avg {avg_power:.1f}W | total energy: {total_energy:.1f}J ({total_energy / 3600:.2f}Wh)"
)
if errors:
print(f" API errors: {errors}")
if no_extract:
@@ -896,6 +1141,8 @@ def print_results(
"total_elapsed_s": total_elapsed,
"wall_clock_s": wall_clock,
"avg_gen_tps": avg_gen_tps,
"avg_power_watts": avg_power,
"total_energy_joules": total_energy,
}
@@ -1053,7 +1300,11 @@ def save_results(
"prompt_tokens": r.prompt_tokens,
"completion_tokens": r.completion_tokens,
"reasoning_tokens": r.reasoning_tokens,
"reasoning_content": r.reasoning_content,
"finish_reason": r.finish_reason,
"elapsed_s": round(r.elapsed_s, 2),
"power_watts": round(r.power_watts, 2),
"energy_joules": round(r.energy_joules, 2),
}
for r in results
],
@@ -1069,6 +1320,15 @@ def save_results(
# ---------------------------------------------------------------------------
def _checkpoint_path(
results_dir: str, benchmark: str, model: str, concurrency: int
) -> Path:
"""Return the JSONL checkpoint path for a benchmark run."""
out_dir = Path(results_dir) / model.replace("/", "_") / benchmark
out_dir.mkdir(parents=True, exist_ok=True)
return out_dir / f"c{concurrency}.checkpoint.jsonl"
def parse_int_list(values: list[str]) -> list[int]:
items: list[int] = []
for v in values:
@@ -1096,6 +1356,12 @@ def main() -> int:
default=None,
help="Max questions per benchmark (for fast iteration).",
)
ap.add_argument(
"--offset",
type=int,
default=0,
help="Skip first N questions (0-based).",
)
reasoning_group = ap.add_mutually_exclusive_group()
reasoning_group.add_argument(
@@ -1115,6 +1381,8 @@ def main() -> int:
"--temperature", type=float, default=None, help="Override temperature."
)
ap.add_argument("--top-p", type=float, default=None, help="Override top_p.")
ap.add_argument("--top-k", type=int, default=None, help="Override top_k.")
ap.add_argument("--min-p", type=float, default=None, help="Override min_p.")
ap.add_argument(
"--max-tokens", type=int, default=None, help="Override max output tokens."
)
@@ -1148,15 +1416,31 @@ def main() -> int:
choices=["easy", "medium", "hard"],
help="Filter by difficulty (livecodebench only). E.g. --difficulty hard",
)
ap.add_argument(
"--release-version",
default=None,
help="LCB dataset release version (livecodebench only). E.g. release_v5",
)
ap.add_argument(
"--results-dir",
default="eval_results",
help="Directory for result JSON files (default: eval_results).",
)
ap.add_argument(
"--skip-instance-setup",
"--enable-thinking",
type=lambda v: v.lower() in ("true", "1", "yes"),
default=None,
help="Enable thinking mode for models that support it.",
)
ap.add_argument(
"--force",
action="store_true",
help="Skip exo instance management (assumes model is already running).",
help="Discard any existing checkpoint and run from scratch.",
)
ap.add_argument(
"--keep-instance",
action="store_true",
help="Skip deleting the instance after eval (for chaining runs).",
)
args, _ = ap.parse_known_args()
@@ -1177,13 +1461,26 @@ def main() -> int:
# Instance management
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
instance_id: str | None = None
created_instance = False
if not args.skip_instance_setup:
short_id, full_model_id = resolve_model_short_id(
client,
args.model,
force_download=args.force_download,
)
_short_id, full_model_id = resolve_model_short_id(
client,
args.model,
force_download=args.force_download,
)
# Optionally reuse a running instance for this model
if args.reuse_instance:
existing = find_existing_instance(client, full_model_id)
if existing:
instance_id = existing
logger.info(f"Reusing existing instance {instance_id}")
else:
logger.warning(
"--reuse-instance: no existing instance found, creating a new one"
)
if instance_id is None:
selected = settle_and_fetch_placements(
client,
full_model_id,
@@ -1198,7 +1495,7 @@ def main() -> int:
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
-nodes_used_in_instance(p["instance"]),
nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
@@ -1225,6 +1522,18 @@ def main() -> int:
if download_duration is not None:
logger.info(f"Download: {download_duration:.1f}s")
# Delete any existing instances to free resources before placing
try:
state = client.request_json("GET", "/state")
for old_id in list(state.get("instances", {}).keys()):
logger.info(f"Deleting stale instance {old_id}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{old_id}")
if state.get("instances"):
time.sleep(2)
except Exception as e:
logger.warning(f"Failed to clean up stale instances: {e}")
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
@@ -1234,10 +1543,9 @@ def main() -> int:
client.request_json("DELETE", f"/instance/{instance_id}")
return 1
time.sleep(1)
cluster_snapshot = capture_cluster_snapshot(client)
else:
full_model_id = args.model
cluster_snapshot = None
created_instance = True
cluster_snapshot = capture_cluster_snapshot(client)
# Auto-detect reasoning from model config
model_config = load_model_config(full_model_id)
@@ -1291,16 +1599,57 @@ def main() -> int:
reasoning_effort = str(cfg["reasoning_effort"])
else:
reasoning_effort = "high" if is_reasoning else None
if args.top_k is not None:
top_k: int | None = args.top_k
elif "top_k" in cfg:
top_k = int(cfg["top_k"])
else:
top_k = None
if args.min_p is not None:
min_p: float | None = args.min_p
elif "min_p" in cfg:
min_p = float(cfg["min_p"])
else:
min_p = None
if args.enable_thinking is not None:
enable_thinking: bool | None = args.enable_thinking
elif "enable_thinking" in cfg:
enable_thinking = bool(cfg["enable_thinking"])
else:
enable_thinking = None
base_url = f"http://{args.host}:{args.port}"
logger.info(f"Model: {full_model_id}")
logger.info(
f"Settings: temperature={temperature}, max_tokens={max_tokens}, "
+ (f"top_p={top_p}, " if top_p is not None else "")
+ (f"top_k={top_k}, " if top_k is not None else "")
+ (f"min_p={min_p}, " if min_p is not None else "")
+ f"reasoning={'yes' if is_reasoning else 'no'}"
+ (f", reasoning_effort={reasoning_effort}" if reasoning_effort else "")
+ (
f", enable_thinking={enable_thinking}"
if enable_thinking is not None
else ""
)
)
# Common kwargs for evaluate_benchmark
eval_kwargs: dict[str, Any] = {
"reasoning_effort": reasoning_effort,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
"enable_thinking": enable_thinking,
"difficulty": args.difficulty,
"offset": args.offset,
"release_version": args.release_version,
}
try:
if args.compare_concurrency:
concurrency_levels = parse_int_list(args.compare_concurrency)
@@ -1309,6 +1658,11 @@ def main() -> int:
for c in concurrency_levels:
logger.info(f"\n{'=' * 50}")
logger.info(f"Running {task_name} at concurrency={c}")
checkpoint_path = _checkpoint_path(
args.results_dir, task_name, full_model_id, c
)
if args.force and checkpoint_path.exists():
checkpoint_path.unlink()
results = asyncio.run(
evaluate_benchmark(
task_name,
@@ -1319,9 +1673,8 @@ def main() -> int:
concurrency=c,
limit=args.limit,
timeout=args.request_timeout,
reasoning_effort=reasoning_effort,
top_p=top_p,
difficulty=args.difficulty,
checkpoint_path=checkpoint_path,
**eval_kwargs,
)
)
if results:
@@ -1336,10 +1689,18 @@ def main() -> int:
cluster=cluster_snapshot,
)
results_by_c[c] = results
# Clean up checkpoint on success
if checkpoint_path.exists():
checkpoint_path.unlink()
if len(results_by_c) >= 2:
print_comparison(task_name, results_by_c)
else:
for task_name in task_names:
checkpoint_path = _checkpoint_path(
args.results_dir, task_name, full_model_id, args.num_concurrent
)
if args.force and checkpoint_path.exists():
checkpoint_path.unlink()
results = asyncio.run(
evaluate_benchmark(
task_name,
@@ -1350,9 +1711,8 @@ def main() -> int:
concurrency=args.num_concurrent,
limit=args.limit,
timeout=args.request_timeout,
reasoning_effort=reasoning_effort,
top_p=top_p,
difficulty=args.difficulty,
checkpoint_path=checkpoint_path,
**eval_kwargs,
)
)
if results:
@@ -1366,14 +1726,25 @@ def main() -> int:
scores,
cluster=cluster_snapshot,
)
# Clean up checkpoint on success
if checkpoint_path.exists():
checkpoint_path.unlink()
finally:
if instance_id is not None:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
if created_instance and instance_id is not None:
if args.keep_instance:
logger.info(f"Keeping instance {instance_id} (--keep-instance)")
else:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
try:
wait_for_instance_gone(client, instance_id)
except TimeoutError:
logger.warning(
f"Timed out waiting for instance {instance_id} to be deleted"
)
return 0

View File

@@ -6,6 +6,7 @@ import http.client
import json
import os
import time
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlencode
@@ -69,6 +70,30 @@ class ExoClient:
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def stream_bench_chat_completions(self, payload: dict[str, Any]) -> Iterator[str]:
"""POST /bench/chat/completions with stream=True, yielding raw SSE lines."""
payload = {**payload, "stream": True}
data = json.dumps(payload).encode("utf-8")
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
conn.request(
"POST",
"/bench/chat/completions",
body=data,
headers={
"Content-Type": "application/json",
"Accept": "text/event-stream",
},
)
resp = conn.getresponse()
if resp.status >= 400:
raw = resp.read().decode("utf-8", errors="replace")
raise ExoHttpError(resp.status, resp.reason, raw[:300])
for line in resp:
yield line.decode("utf-8", errors="replace")
finally:
conn.close()
def get_state_path(self, path: str) -> Any:
try:
return self.request_json("GET", f"/state/{path}")
@@ -462,9 +487,8 @@ def run_planning_phase(
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
# Wait for downloads (no timeout — poll until complete or failed)
while True:
all_done = True
for node_id in node_ids:
node_downloads = client.get_node_downloads(node_id) or []
@@ -514,9 +538,24 @@ def run_planning_phase(
if download_t0 is not None:
return time.perf_counter() - download_t0
return None
time.sleep(1)
time.sleep(10)
raise TimeoutError("Downloads did not complete in time")
def find_existing_instance(client: ExoClient, model_id: str) -> str | None:
"""Find an existing running instance for the given model."""
try:
state = client.request_json("GET", "/state")
except Exception:
return None
for inst_id, inst in state.get("instances", {}).items():
# Instance structure is nested: {"MlxJacclInstance": {"shardAssignments": {"modelId": ...}}}
for _inst_type, inner in inst.items():
if not isinstance(inner, dict):
continue
sa = inner.get("shardAssignments", {})
if sa.get("modelId") == model_id:
return inst_id
return None
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
@@ -572,3 +611,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
ap.add_argument(
"--reuse-instance",
action="store_true",
help="Reuse an existing running instance for this model instead of creating a new one.",
)

View File

@@ -336,7 +336,9 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/bench/chat/completions", response_model=None)(
self.bench_chat_completions
)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
@@ -829,7 +831,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
) -> BenchChatCompletionResponse:
) -> BenchChatCompletionResponse | StreamingResponse:
task_params = await chat_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
@@ -846,6 +848,22 @@ class API:
command = await self._send_text_generation_with_images(task_params)
if payload.stream:
return StreamingResponse(
with_sse_keepalive(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await self._collect_text_generation_with_stats(command.command_id)
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId: