Compare commits

...

11 Commits

Author  SHA1  Message  Date
rltakashige  9c5efcb155  Merge branch 'main' into leo/fix-tool-calling  2026-02-18 19:50:27 +00:00
Ryuichi Leo Takashige  461806ef84  Add CacheList  2026-02-18 19:39:01 +00:00
Ryuichi Leo Takashige  f56c67eda0  Revert  2026-02-18 19:23:52 +00:00
Ryuichi Leo Takashige  15f92ccbf5  Fix some tests  2026-02-18 18:54:56 +00:00
Ryuichi Leo Takashige  c577ac5dd5  Fix GLM 5 tool calling  2026-02-18 18:31:12 +00:00
Ryuichi Leo Takashige  7b4c29d402  put scenarios in separate toml file and prioritise tensor jaccl (most nodes) > single node > pipeline jaccl (most nodes) > pipeline ring (most nodes) > tensor ring (most nodes)  2026-02-18 16:58:06 +00:00
Ryuichi Leo Takashige  e3490441f5  Do auto placement with exo eval too.  2026-02-18 16:43:26 +00:00
Ryuichi Leo Takashige  4aba17581a  Refactor exo bench and add tool call evals  2026-02-18 16:35:03 +00:00
Ryuichi Leo Takashige  4c1b385c79  Merge branch 'main' into leo/fix-tool-calling  2026-02-18 16:15:08 +00:00
Ryuichi Leo Takashige  f52e9b9aa0  the linkedin post was bs  2026-02-16 13:19:38 +00:00
Ryuichi Leo Takashige  f31c8e8165  fix gpt oss tool calling  2026-02-16 11:42:21 +00:00
19 changed files with 2656 additions and 513 deletions

bench/eval_tool_calls.py (new file, 1046 lines)

File diff suppressed because it is too large.


@@ -4,26 +4,29 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
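# A minimal sketch of the kind of shim the comment above describes; the actual
# patch in this diff is elided by the hunk below. It assumes the gpt2
# tokenization module still exists in transformers 5.x and that only
# bytes_to_unicode moved; the reimplementation is the standard GPT-2
# byte-to-unicode table.
import transformers.models.gpt2.tokenization_gpt2 as _gpt2_tok

if not hasattr(_gpt2_tok, "bytes_to_unicode"):
    def _bytes_to_unicode() -> dict[int, str]:
        # Printable byte values map to themselves; the rest are shifted above 255.
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, (chr(c) for c in cs)))

    _gpt2_tok.bytes_to_unicode = _bytes_to_unicode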
@@ -103,154 +106,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +124,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +215,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip single-node tensor/jaccl placements, as single-node pipeline ring already covers them
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +233,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +242,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +256,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +290,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,16 +317,6 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

bench/harness.py (new file, 327 lines)

@@ -0,0 +1,327 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip single-node tensor/jaccl placements, as single-node pipeline ring already covers them
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
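# Hedged usage sketch, not part of this diff: roughly how a bench script could
# wire these helpers together, mirroring the call sites in the exo-bench hunks
# above. The model id passed to parse_args is hypothetical.
import argparse

ap = argparse.ArgumentParser(prog="harness-example")
add_common_instance_args(ap)
args = ap.parse_args(["--model", "llama-3.1-8b"])  # hypothetical model short id

client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_id = resolve_model_short_id(client, args.model)
placements = settle_and_fetch_placements(
    client, full_id, args, settle_timeout=args.settle_timeout
)
for preview in placements:
    instance = preview["instance"]
    print(instance_id_from_instance(instance), nodes_used_in_instance(instance))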


@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

bench/scenarios.toml (new file, 240 lines)

@@ -0,0 +1,240 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"


@@ -158,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};


@@ -4,10 +4,13 @@ from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]


@@ -5,6 +5,7 @@ import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
@@ -17,10 +18,22 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
@@ -64,7 +77,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None = None):
def __init__(self, group: mx.distributed.Group | None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
@@ -156,15 +169,15 @@ class KVPrefixCache:
best_length = 0
is_exact = False
# Find best cache
# Find best cache match
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(prompt_tokens, cached_prompt)
if length >= max_length - 1:
best_index, best_length = i, length
is_exact = True
break
if length > best_length:
best_index, best_length = i, length
if length == max_length:
is_exact = True
best_index, best_length = i, length
break
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
@@ -172,11 +185,12 @@ class KVPrefixCache:
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
target = (max_length - 1) if is_exact else best_length
has_ssm = has_non_kv_caches(self.caches[best_index])
target = (max_length - 1) if is_exact and not has_ssm else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
# No usable snapshot — need fresh cache
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
if restore_snap is None and has_ssm:
return make_kv_cache(model), prompt_tokens, None
prompt_cache = deepcopy(self.caches[best_index])
@@ -257,10 +271,21 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(prompt_tokens)
def _entry_length(
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
if hasattr(c, "offset"):
return c.offset
# For CacheList
if hasattr(c, "size"):
return int(c.size()) # type: ignore
return 0
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
return max(getattr(c, "offset", 0) for c in cache)
return max(_entry_length(c) for c in cache)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:


@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
def prefill(
@@ -57,7 +57,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -133,7 +133,7 @@ def prefill(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -255,8 +255,8 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -436,9 +436,14 @@ def mlx_generate(
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index,


@@ -292,6 +292,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None


@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -588,7 +589,11 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel


@@ -103,7 +103,7 @@ class RunnerSupervisor:
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(1)
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return


@@ -123,7 +123,12 @@ def run_gpt_oss_pipeline_device(
generated_text = ""
for response in mlx_generate(
model=model, tokenizer=tokenizer, task=task, prompt=prompt
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:
@@ -194,6 +199,8 @@ def run_gpt_oss_tensor_parallel_device(
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:


@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.clear()
assert len(cache.prompts) == 0
@@ -142,7 +142,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
@@ -161,9 +163,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
@@ -176,9 +180,11 @@ class TestKVPrefixCacheWithModel:
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
# Exact match returns last token(s) — for models with SSM/rotating caches,
# snapshot availability constrains how far back we can trim, so remaining
# may be 1 or 2 tokens depending on the model.
assert len(remaining_tokens) >= 1
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
@@ -194,10 +200,10 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
@@ -238,9 +244,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -276,9 +284,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -301,7 +311,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -318,6 +328,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
generated_tokens += 1
@@ -331,7 +342,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -347,6 +358,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -368,7 +380,7 @@ class TestKVPrefixCacheWithModel:
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -395,6 +407,7 @@ class TestKVPrefixCacheWithModel:
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
first_gen_time = time.perf_counter() - t0
@@ -427,6 +440,7 @@ class TestKVPrefixCacheWithModel:
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
second_gen_time = time.perf_counter() - t0
@@ -447,7 +461,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -462,6 +476,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -474,6 +489,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -484,7 +500,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -497,7 +513,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -522,7 +538,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)


@@ -343,8 +343,16 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
def contains(card: ModelCard, x: str):
return x in card.model_id.lower()
glm_model_cards = [
card for card in await get_model_cards() if "glm" in card.model_id.lower()
card
for card in await get_model_cards()
if contains(card, "glm")
and not contains(card, "-5")
and not contains(card, "4.7")
]
if not glm_model_cards:


@@ -0,0 +1,162 @@
from collections.abc import Generator
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.runner import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
_CHANNEL = 200005 # <|channel|>
_START = 200006 # <|start|>
_MESSAGE = 200008 # <|message|>
_CALL = 200012 # <|call|>
_END = 200007 # <|end|>
_ASSISTANT = 173781 # "assistant"
# fmt: off
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_A_TOKENS: list[tuple[int, str]] = [
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_B_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
# Full analysis-then-tool-call as the model actually generates it.
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(35644, "analysis"),
(_MESSAGE, "<|message|>"),
(12845, "Let"),
(668, " me"),
(2411, " think"),
(1078, " about"),
(495, " this"),
(13, "."),
(_END, "<|end|>"),
# Model generates a new message header for the tool call:
(_START, "<|start|>"),
(_ASSISTANT, "assistant"),
*FORMAT_B_TOKENS,
]
# fmt: on
def _make_gen_responses(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse]:
"""Build GenerationResponse list from (token_id, text) pairs."""
responses: list[GenerationResponse] = []
for i, (tid, text) in enumerate(tokens):
is_last = i == len(tokens) - 1
responses.append(
GenerationResponse(
text=text,
token=tid,
finish_reason="stop" if is_last else None,
usage=None,
)
)
return responses
def _collect(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse | ToolCallResponse]:
"""Feed tokens through parse_gpt_oss and collect all yielded responses."""
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(parse_gpt_oss(_gen()))
def _get_tool_call(
results: list[GenerationResponse | ToolCallResponse],
) -> ToolCallResponse:
"""Extract the single ToolCallResponse from results."""
tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
return tool_calls[0]
class TestParseGptOssRecipientPlacement:
"""Both Harmony recipient placements must produce identical tool calls."""
def test_format_a_yields_tool_call(self):
results = _collect(FORMAT_A_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_format_b_yields_tool_call(self):
results = _collect(FORMAT_B_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_both_formats_produce_identical_tool_calls(self):
tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
class TestParseGptOssThinkingThenToolCall:
"""Analysis (thinking) followed by a tool call must yield both."""
def test_thinking_then_tool_call(self):
results = _collect(THINKING_THEN_TOOL_TOKENS)
# Should have thinking tags + content + tool call
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
combined = "".join(text_parts)
assert "<think>" in combined
assert "</think>" in combined
assert "Let me think about this." in combined
# And the tool call
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert "Tokyo" in tc.tool_calls[0].arguments

tests/eval_tool_calls.sh (new executable file, 55 lines)

@@ -0,0 +1,55 @@
#!/usr/bin/env bash
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
sleep 1
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done
for host; do
echo "Waiting for $host..." 1>&2
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
eval_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
echo "running eval for $model" 1>&2
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
>>"./bench/$commit/${model//\//--}-eval.json"
echo
done

tool_call_eval.py (new file, 691 lines)

@@ -0,0 +1,691 @@
#!/usr/bin/env python3
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import dataclass, field
import httpx
# ---------------------------------------------------------------------------
# Tool definitions
# ---------------------------------------------------------------------------
WEATHER_TOOL = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit",
},
},
"required": ["location"],
},
},
}
CALCULATOR_TOOL = {
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a mathematical expression and return the numeric result",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "The math expression to evaluate, e.g. '2 + 3 * 4'",
},
},
"required": ["expression"],
},
},
}
SEARCH_TOOL = {
"type": "function",
"function": {
"name": "search_products",
"description": "Search for products in a catalog by query, category, and price",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query string",
},
"category": {
"type": "string",
"enum": ["electronics", "clothing", "food", "books"],
"description": "Product category to filter by",
},
"max_price": {
"type": "number",
"description": "Maximum price in USD",
},
},
"required": ["query"],
},
},
}
ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL]
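# Hedged illustration, not part of this diff: roughly the request body a single
# scenario turns into when POSTed to exo's OpenAI-compatible chat-completions
# endpoint mentioned in the module docstring. The exact payload fields the eval
# sends are an assumption here.
_EXAMPLE_REQUEST = {
    "model": "<model-id>",
    "messages": [
        {"role": "user", "content": "What's the weather like in Tokyo right now?"}
    ],
    "tools": ALL_TOOLS,
    "tool_choice": "auto",
    "stream": False,
}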
# ---------------------------------------------------------------------------
# Scenarios
# ---------------------------------------------------------------------------
@dataclass
class Scenario:
name: str
description: str
messages: list[dict[str, object]]
tools: list[dict[str, object]]
expect_tool_call: bool
expected_function: str | None = None
required_arg_keys: list[str] | None = None
# For multi-turn: fake tool result to inject, then verify the follow-up.
tool_result: str | None = None
SCENARIOS = [
# -- Should call a tool --------------------------------------------------
Scenario(
name="weather_simple",
description="Basic weather query -> get_current_weather",
messages=[
{"role": "user", "content": "What's the weather like in Tokyo right now?"}
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="get_current_weather",
required_arg_keys=["location"],
),
Scenario(
name="calculator_simple",
description="Math question -> calculate",
messages=[
{
"role": "user",
"content": "Use the calculator to compute 3847 * 926 + 17293",
}
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="calculate",
required_arg_keys=["expression"],
),
Scenario(
name="search_with_filters",
description="Product search with category and price filter",
messages=[{"role": "user", "content": "Find me electronics under $50"}],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="search_products",
required_arg_keys=["query"],
),
# -- Multi-turn: tool call then follow-up --------------------------------
Scenario(
name="weather_multi_turn",
description="Weather query -> tool result -> natural language summary",
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="get_current_weather",
required_arg_keys=["location"],
tool_result=json.dumps(
{
"temperature": "18C",
"condition": "partly cloudy",
"humidity": "65%",
"wind": "12 km/h NW",
}
),
),
Scenario(
name="calculator_multi_turn",
description="Math query -> tool result -> model reports the answer",
messages=[
{
"role": "user",
"content": "Use the calculator to compute 1847 * 263 + 5921",
}
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="calculate",
required_arg_keys=["expression"],
tool_result=json.dumps({"result": 491682}),
),
Scenario(
name="search_multi_turn",
description="Search query -> tool result -> model summarizes products",
messages=[
{"role": "user", "content": "Search for books about machine learning"}
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="search_products",
required_arg_keys=["query"],
tool_result=json.dumps(
{
"results": [
{
"name": "Hands-On Machine Learning",
"price": 45.99,
"rating": 4.8,
},
{
"name": "Deep Learning with Python",
"price": 39.99,
"rating": 4.6,
},
]
}
),
),
# -- Sequential tool calls: thinking + tool call, NO final answer ----------
# This is the critical scenario for the Harmony recipient placement fix.
#
# When an assistant message has both thinking content and a tool_call,
# AND there is no subsequent final-answer assistant message, the Jinja
# template renders BOTH the analysis and the tool call:
#
# <|start|>assistant<|channel|>analysis<|message|>thinking...<|end|>
# <|start|>assistant to=functions.X<|channel|>commentary json<|message|>...<|call|>
#
# The two consecutive assistant messages have INCONSISTENT start patterns
# (one has <|channel|> immediately, the other has to= first).
# This confuses the model when it needs to generate its own tool call.
#
# The reformat fix makes both start with <|start|>assistant<|channel|>,
# only differing in the channel name (analysis vs commentary).
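    #
    # After the fix, both messages should render with the same prefix, e.g.
    # (sketch only; the exact recipient/format markers depend on the template):
    #
    #   <|start|>assistant<|channel|>analysis<|message|>thinking...<|end|>
    #   <|start|>assistant<|channel|>commentary to=functions.X json<|message|>...<|call|>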
Scenario(
name="chained_tool_calls_same",
description="Thinking + weather(Tokyo) -> result -> model must call weather(London)",
messages=[
{"role": "user", "content": "Compare the weather in Tokyo and London."},
{
"role": "assistant",
"content": "I'll check both cities. Let me start with Tokyo.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": json.dumps({"location": "Tokyo"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
},
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="get_current_weather",
required_arg_keys=["location"],
),
Scenario(
name="chained_tool_calls_different",
description="Thinking + weather(Berlin) -> result -> model must call calculator",
messages=[
{
"role": "user",
"content": "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291.",
},
{
"role": "assistant",
"content": "I'll handle both. Let me check Berlin's weather first.",
"tool_calls": [
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": json.dumps({"location": "Berlin"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_2",
"content": json.dumps({"temperature": "12C", "condition": "rainy"}),
},
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="calculate",
required_arg_keys=["expression"],
),
Scenario(
name="chained_tool_calls_three",
description="Two prior thinking+tool calls -> results -> model must make a third",
messages=[
{"role": "user", "content": "Compare weather in Tokyo, Paris, and London."},
{
"role": "assistant",
"content": "I'll check all three cities. Starting with Tokyo.",
"tool_calls": [
{
"id": "call_3",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": json.dumps({"location": "Tokyo"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_3",
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
},
{
"role": "assistant",
"content": "Got Tokyo. Now checking Paris.",
"tool_calls": [
{
"id": "call_4",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": json.dumps({"location": "Paris"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_4",
"content": json.dumps({"temperature": "18C", "condition": "cloudy"}),
},
],
tools=ALL_TOOLS,
expect_tool_call=True,
expected_function="get_current_weather",
required_arg_keys=["location"],
),
# -- Should NOT call a tool ----------------------------------------------
Scenario(
name="no_tool_joke",
description="Joke request should NOT trigger any tool",
messages=[{"role": "user", "content": "Tell me a funny joke about cats."}],
tools=ALL_TOOLS,
expect_tool_call=False,
),
Scenario(
name="no_tool_factual",
description="Factual question answerable from training data",
messages=[{"role": "user", "content": "What is the capital of Japan?"}],
tools=ALL_TOOLS,
expect_tool_call=False,
),
]
# ---------------------------------------------------------------------------
# Result tracking
# ---------------------------------------------------------------------------
@dataclass
class ScenarioResult:
name: str
phase: str # "tool_call" or "follow_up"
passed: bool
checks: dict[str, bool] = field(default_factory=dict)
error: str | None = None
latency_ms: float = 0.0
# ---------------------------------------------------------------------------
# Evaluation helpers
# ---------------------------------------------------------------------------
def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str | None]:
"""Parse JSON arguments and check required keys exist."""
try:
args = json.loads(args_str)
except (json.JSONDecodeError, TypeError) as e:
return False, f"Invalid JSON: {e}"
if not isinstance(args, dict):
return False, f"Expected dict, got {type(args).__name__}"
missing = [k for k in required_keys if k not in args]
if missing:
return False, f"Missing keys: {missing}"
return True, None
def call_api(
client: httpx.Client,
base_url: str,
model: str,
messages: list[dict[str, object]],
tools: list[dict[str, object]],
timeout: float,
) -> tuple[dict[str, object], float]:
"""POST to /chat/completions, return (response_json, latency_ms)."""
url = f"{base_url.rstrip('/')}/chat/completions"
body: dict[str, object] = {
"model": model,
"messages": messages,
"tools": tools,
"temperature": 0.0,
"max_tokens": 4096,
}
t0 = time.monotonic()
resp = client.post(url, json=body, timeout=timeout)
latency = (time.monotonic() - t0) * 1000
resp.raise_for_status()
return resp.json(), latency
# ---------------------------------------------------------------------------
# Scenario runner
# ---------------------------------------------------------------------------
def run_scenario(
client: httpx.Client,
base_url: str,
model: str,
scenario: Scenario,
timeout: float,
verbose: bool,
) -> list[ScenarioResult]:
results: list[ScenarioResult] = []
# --- Phase 1: initial request ---
try:
data, latency = call_api(
client, base_url, model, scenario.messages, scenario.tools, timeout
)
except Exception as e:
results.append(
ScenarioResult(
name=scenario.name,
phase="tool_call",
passed=False,
error=f"API error: {e}",
)
)
return results
if verbose:
print(f" response: {json.dumps(data, indent=2)}")
choice = data["choices"][0]
finish_reason = choice.get("finish_reason")
message = choice.get("message", {})
tool_calls = message.get("tool_calls")
content = message.get("content")
checks: dict[str, bool] = {}
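    # Positive scenarios must finish with finish_reason == "tool_calls" and a
    # well-formed call to the expected function; negative scenarios must answer
    # in plain text with no tool call at all.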
if scenario.expect_tool_call:
checks["finish_reason_tool_calls"] = finish_reason == "tool_calls"
checks["has_tool_call"] = isinstance(tool_calls, list) and len(tool_calls) > 0
args_err: str | None = None
if checks["has_tool_call"]:
tc = tool_calls[0]
fn = tc.get("function", {})
checks["correct_function"] = (
scenario.expected_function is None
or fn.get("name") == scenario.expected_function
)
if scenario.required_arg_keys:
ok, args_err = validate_args(
fn.get("arguments", ""), scenario.required_arg_keys
)
checks["valid_arguments"] = ok
else:
checks["valid_arguments"] = True
else:
checks["correct_function"] = False
checks["valid_arguments"] = False
args_err = "No tool call returned"
        passed = all(checks.values())
        failed = [name for name, ok in checks.items() if not ok]
        error = None if passed else (args_err or f"failed checks: {', '.join(failed)}")
else:
checks["finish_reason_stop"] = finish_reason == "stop"
checks["no_tool_call"] = tool_calls is None or len(tool_calls) == 0
checks["has_content"] = isinstance(content, str) and len(content.strip()) > 0
passed = all(checks.values())
error = (
None
if passed
else (
f"finish_reason={finish_reason}, "
f"tool_calls={'yes' if tool_calls else 'no'}, "
f"content={'yes' if content else 'no'}"
)
)
results.append(
ScenarioResult(
name=scenario.name,
phase="tool_call",
passed=passed,
checks=checks,
error=error,
latency_ms=latency,
)
)
# --- Phase 2: multi-turn follow-up ---
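    # Echo the model's own tool call back, append the canned tool result, and
    # expect a plain-text answer with no further tool calls.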
if scenario.tool_result is not None and checks.get("has_tool_call"):
tc = tool_calls[0]
fn = tc.get("function", {})
follow_up_messages: list[dict[str, object]] = list(scenario.messages) + [
{
"role": "assistant",
"tool_calls": [
{
"id": tc.get("id", "call_0"),
"type": "function",
"function": {
"name": fn.get("name", ""),
"arguments": fn.get("arguments", "{}"),
},
}
],
},
{
"role": "tool",
"tool_call_id": tc.get("id", "call_0"),
"content": scenario.tool_result,
},
]
try:
data2, latency2 = call_api(
client,
base_url,
model,
follow_up_messages,
scenario.tools,
timeout,
)
except Exception as e:
results.append(
ScenarioResult(
name=scenario.name,
phase="follow_up",
passed=False,
error=f"API error: {e}",
)
)
return results
if verbose:
print(f" follow_up response: {json.dumps(data2, indent=2)}")
choice2 = data2["choices"][0]
message2 = choice2.get("message", {})
checks2: dict[str, bool] = {}
checks2["finish_reason_stop"] = choice2.get("finish_reason") == "stop"
tc2 = message2.get("tool_calls")
checks2["no_tool_call"] = tc2 is None or len(tc2) == 0
c2 = message2.get("content")
checks2["has_content"] = isinstance(c2, str) and len(c2.strip()) > 0
passed2 = all(checks2.values())
error2 = None
if not passed2:
error2 = (
f"finish_reason={choice2.get('finish_reason')}, "
f"tool_calls={'yes' if tc2 else 'no'}, "
f"content={'yes' if c2 else 'no'}"
)
results.append(
ScenarioResult(
name=scenario.name,
phase="follow_up",
passed=passed2,
checks=checks2,
error=error2,
latency_ms=latency2,
)
)
return results
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
parser = argparse.ArgumentParser(description="Tool-calling eval for exo")
parser.add_argument("--model", required=True, help="Model ID to test")
parser.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
parser.add_argument(
"--port",
type=int,
default=int(os.environ.get("EXO_PORT", "52415")),
)
parser.add_argument(
"--timeout", type=float, default=120, help="Per-request timeout (seconds)"
)
parser.add_argument(
"--repeat", type=int, default=1, help="Repeat each scenario N times"
)
parser.add_argument(
"--scenarios", nargs="*", help="Run only these scenarios (by name)"
)
parser.add_argument(
"--verbose", action="store_true", help="Print full API responses"
)
args = parser.parse_args()
scenarios = SCENARIOS
if args.scenarios:
scenarios = [s for s in SCENARIOS if s.name in args.scenarios]
if not scenarios:
print(f"No matching scenarios. Available: {[s.name for s in SCENARIOS]}")
sys.exit(1)
base_url = f"http://{args.host}:{args.port}/v1"
total_runs = len(scenarios) * args.repeat
print(f"Model: {args.model}")
print(f"Endpoint: {base_url}")
print(f"Scenarios: {len(scenarios)} x {args.repeat} = {total_runs} runs")
print("=" * 64)
all_results: list[ScenarioResult] = []
with httpx.Client() as client:
for run_idx in range(args.repeat):
if args.repeat > 1:
print(f"\n--- Run {run_idx + 1}/{args.repeat} ---")
for scenario in scenarios:
print(f"\n {scenario.name}: {scenario.description}")
results = run_scenario(
client,
base_url,
args.model,
scenario,
args.timeout,
args.verbose,
)
all_results.extend(results)
for r in results:
status = "PASS" if r.passed else "FAIL"
print(f" [{r.phase:>10}] {status} ({r.latency_ms:.0f}ms)")
for check_name, check_ok in r.checks.items():
mark = "+" if check_ok else "-"
print(f" {mark} {check_name}")
if r.error:
print(f" ! {r.error}")
# --- Summary ---
print(f"\n{'=' * 64}")
total = len(all_results)
passed = sum(1 for r in all_results if r.passed)
tool_call_results = [r for r in all_results if r.phase == "tool_call"]
follow_up_results = [r for r in all_results if r.phase == "follow_up"]
tc_passed = sum(1 for r in tool_call_results if r.passed)
fu_passed = sum(1 for r in follow_up_results if r.passed)
avg_latency = sum(r.latency_ms for r in all_results) / total if total else 0
print(f"Total: {passed}/{total} passed ({100 * passed / total:.0f}%)")
print(f"Tool call: {tc_passed}/{len(tool_call_results)} passed")
if follow_up_results:
print(f"Follow-up: {fu_passed}/{len(follow_up_results)} passed")
print(f"Avg latency: {avg_latency:.0f}ms")
if passed < total:
print("\nFailed:")
for r in all_results:
if not r.passed:
print(f" - {r.name} [{r.phase}]: {r.error}")
sys.exit(0 if passed == total else 1)
if __name__ == "__main__":
main()

2
uv.lock generated
View File

@@ -447,6 +447,7 @@ name = "exo-bench"
version = "0.1.0"
source = { editable = "bench" }
dependencies = [
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -456,6 +457,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "jinja2", specifier = ">=3.1.0" },
{ name = "loguru", specifier = ">=0.7.3" },