mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
Compare commits
11 Commits
remove-nig
...
leo/fix-to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c5efcb155 | ||
|
|
461806ef84 | ||
|
|
f56c67eda0 | ||
|
|
15f92ccbf5 | ||
|
|
c577ac5dd5 | ||
|
|
7b4c29d402 | ||
|
|
e3490441f5 | ||
|
|
4aba17581a | ||
|
|
4c1b385c79 | ||
|
|
f52e9b9aa0 | ||
|
|
f31c8e8165 |
1046
bench/eval_tool_calls.py
Normal file
1046
bench/eval_tool_calls.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,26 +4,29 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from harness import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
add_common_instance_args,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Backoff constants for cluster settling retry
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
@@ -103,154 +106,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
@@ -269,184 +124,6 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
return items
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("name").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
@@ -538,76 +215,12 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
add_common_instance_args(ap)
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
@@ -620,34 +233,6 @@ def main() -> int:
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
@@ -657,9 +242,6 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
@@ -674,17 +256,6 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -719,24 +290,10 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
@@ -760,16 +317,6 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
327
bench/harness.py
Normal file
327
bench/harness.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
raise
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if (m.get("name") or "").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def settle_and_fetch_placements(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
args: argparse.Namespace,
|
||||
settle_timeout: float = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
@@ -4,6 +4,7 @@ version = "0.1.0"
|
||||
description = "Benchmarking tool for exo distributed inference"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"loguru>=0.7.3",
|
||||
"transformers>=5.0.0",
|
||||
"huggingface-hub>=0.33.4",
|
||||
|
||||
240
bench/scenarios.toml
Normal file
240
bench/scenarios.toml
Normal file
@@ -0,0 +1,240 @@
|
||||
# Tool definitions — each becomes an OpenAI function tool.
|
||||
# All scenarios get all tools unless they specify a `tools` list.
|
||||
|
||||
[tools.get_current_weather]
|
||||
description = "Get the current weather in a given location"
|
||||
required = ["location"]
|
||||
|
||||
[tools.get_current_weather.properties.location]
|
||||
type = "string"
|
||||
description = "City and state, e.g. San Francisco, CA"
|
||||
|
||||
[tools.get_current_weather.properties.unit]
|
||||
type = "string"
|
||||
enum = ["celsius", "fahrenheit"]
|
||||
description = "Temperature unit"
|
||||
|
||||
[tools.calculate]
|
||||
description = "Evaluate a mathematical expression and return the numeric result"
|
||||
required = ["expression"]
|
||||
|
||||
[tools.calculate.properties.expression]
|
||||
type = "string"
|
||||
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
|
||||
|
||||
[tools.search_products]
|
||||
description = "Search for products in a catalog by query, category, and price"
|
||||
required = ["query"]
|
||||
|
||||
[tools.search_products.properties.query]
|
||||
type = "string"
|
||||
description = "Search query string"
|
||||
|
||||
[tools.search_products.properties.category]
|
||||
type = "string"
|
||||
enum = ["electronics", "clothing", "food", "books"]
|
||||
description = "Product category to filter by"
|
||||
|
||||
[tools.search_products.properties.max_price]
|
||||
type = "number"
|
||||
description = "Maximum price in USD"
|
||||
|
||||
# -- Should call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_simple"
|
||||
description = "Basic weather query -> get_current_weather"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather like in Tokyo right now?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_simple"
|
||||
description = "Math question -> calculate"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 3847 * 926 + 17293"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_with_filters"
|
||||
description = "Product search with category and price filter"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Find me electronics under $50"
|
||||
|
||||
# -- Multi-turn: tool call then follow-up --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_multi_turn"
|
||||
description = "Weather query -> tool result -> natural language summary"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
temperature = "18C"
|
||||
condition = "partly cloudy"
|
||||
humidity = "65%"
|
||||
wind = "12 km/h NW"
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Paris?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_multi_turn"
|
||||
description = "Math query -> tool result -> model reports the answer"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
result = 491682
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 1847 * 263 + 5921"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_multi_turn"
|
||||
description = "Search query -> tool result -> model summarizes products"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Hands-On Machine Learning"
|
||||
price = 45.99
|
||||
rating = 4.8
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Deep Learning with Python"
|
||||
price = 39.99
|
||||
rating = 4.6
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Search for books about machine learning"
|
||||
|
||||
# -- Sequential tool calls --
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_same"
|
||||
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare the weather in Tokyo and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check both cities. Let me start with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_1"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_1"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_different"
|
||||
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll handle both. Let me check Berlin's weather first."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_2"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Berlin" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_2"
|
||||
content = '{"temperature": "12C", "condition": "rainy"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_three"
|
||||
description = "Two prior thinking+tool calls -> results -> model must make a third"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare weather in Tokyo, Paris, and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check all three cities. Starting with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_3"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_3"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "Got Tokyo. Now checking Paris."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_4"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Paris" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_4"
|
||||
content = '{"temperature": "18C", "condition": "cloudy"}'
|
||||
|
||||
# -- Should NOT call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_joke"
|
||||
description = "Joke request should NOT trigger any tool"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Tell me a funny joke about cats."
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_factual"
|
||||
description = "Factual question answerable from training data"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What is the capital of Japan?"
|
||||
@@ -158,6 +158,7 @@
|
||||
exo-test-env = testVenv;
|
||||
} // {
|
||||
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
|
||||
@@ -4,10 +4,13 @@ from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
|
||||
@@ -5,6 +5,7 @@ import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
@@ -17,10 +18,22 @@ from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in.
|
||||
# Smaller machines need more aggressive eviction.
|
||||
def _default_memory_threshold() -> float:
|
||||
total_gb = psutil.virtual_memory().total / (1024**3)
|
||||
if total_gb >= 128:
|
||||
return 0.85
|
||||
if total_gb >= 64:
|
||||
return 0.80
|
||||
if total_gb >= 32:
|
||||
return 0.75
|
||||
return 0.70
|
||||
|
||||
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
|
||||
)
|
||||
|
||||
|
||||
@@ -64,7 +77,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
def __init__(self, group: mx.distributed.Group | None):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
@@ -156,15 +169,15 @@ class KVPrefixCache:
|
||||
best_length = 0
|
||||
is_exact = False
|
||||
|
||||
# Find best cache
|
||||
# Find best cache match
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(prompt_tokens, cached_prompt)
|
||||
if length >= max_length - 1:
|
||||
best_index, best_length = i, length
|
||||
is_exact = True
|
||||
break
|
||||
if length > best_length:
|
||||
best_index, best_length = i, length
|
||||
if length == max_length:
|
||||
is_exact = True
|
||||
best_index, best_length = i, length
|
||||
break
|
||||
|
||||
if best_index is None:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
@@ -172,11 +185,12 @@ class KVPrefixCache:
|
||||
# For exact match: trim to max_length-1 so remaining has the last token
|
||||
# For partial match: trim to best_length, remaining has suffix to prefill
|
||||
# This ensures stream_generate always has at least one token to start with
|
||||
target = (max_length - 1) if is_exact else best_length
|
||||
has_ssm = has_non_kv_caches(self.caches[best_index])
|
||||
target = (max_length - 1) if is_exact and not has_ssm else best_length
|
||||
restore_pos, restore_snap = self._get_snapshot(best_index, target)
|
||||
|
||||
# No usable snapshot — need fresh cache
|
||||
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
|
||||
if restore_snap is None and has_ssm:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
prompt_cache = deepcopy(self.caches[best_index])
|
||||
@@ -257,10 +271,21 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
return mx.array(prompt_tokens)
|
||||
|
||||
|
||||
def _entry_length(
|
||||
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
|
||||
) -> int:
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
if hasattr(c, "offset"):
|
||||
return c.offset
|
||||
# For CacheList
|
||||
if hasattr(c, "size"):
|
||||
return int(c.size()) # type: ignore
|
||||
return 0
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
return max(getattr(c, "offset", 0) for c in cache)
|
||||
return max(_entry_length(c) for c in cache)
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
|
||||
@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
|
||||
|
||||
|
||||
def prefill(
|
||||
@@ -57,7 +57,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None = None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -133,7 +133,7 @@ def prefill(
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -255,8 +255,8 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -436,9 +436,14 @@ def mlx_generate(
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
hit_ratio = (
|
||||
prefix_hit_length / len(all_prompt_tokens)
|
||||
if len(all_prompt_tokens) > 0
|
||||
else 0.0
|
||||
)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
|
||||
@@ -292,6 +292,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
elif "gpt-oss" in model_id_lower:
|
||||
return [200002, 200012]
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
@@ -588,7 +589,11 @@ def parse_gpt_oss(
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
stream.process(response.token)
|
||||
try:
|
||||
stream.process(response.token)
|
||||
except HarmonyError:
|
||||
logger.error("Encountered critical Harmony Error, returning early")
|
||||
return
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
@@ -103,7 +103,7 @@ class RunnerSupervisor:
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
self.runner_process.join(5)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
|
||||
@@ -123,7 +123,12 @@ def run_gpt_oss_pipeline_device(
|
||||
generated_text = ""
|
||||
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
@@ -194,6 +199,8 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,7 +142,9 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
@@ -161,9 +163,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
@@ -176,9 +180,11 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
# Exact match returns last token(s) — for models with SSM/rotating caches,
|
||||
# snapshot availability constrains how far back we can trim, so remaining
|
||||
# may be 1 or 2 tokens depending on the model.
|
||||
assert len(remaining_tokens) >= 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
@@ -194,10 +200,10 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
@@ -238,9 +244,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -276,9 +284,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -301,7 +311,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -318,6 +328,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
generated_tokens += 1
|
||||
|
||||
@@ -331,7 +342,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -347,6 +358,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -368,7 +380,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -395,6 +407,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task1,
|
||||
prompt=prompt1,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
@@ -427,6 +440,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task2,
|
||||
prompt=prompt2,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
second_gen_time = time.perf_counter() - t0
|
||||
@@ -447,7 +461,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -462,6 +476,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -474,6 +489,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -484,7 +500,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -497,7 +513,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
@@ -522,7 +538,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
|
||||
@@ -343,8 +343,16 @@ async def test_kimi_tokenizer_specifically():
|
||||
@pytest.mark.asyncio
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
|
||||
def contains(card: ModelCard, x: str):
|
||||
return x in card.model_id.lower()
|
||||
|
||||
glm_model_cards = [
|
||||
card for card in await get_model_cards() if "glm" in card.model_id.lower()
|
||||
card
|
||||
for card in await get_model_cards()
|
||||
if contains(card, "glm")
|
||||
and not contains(card, "-5")
|
||||
and not contains(card, "4.7")
|
||||
]
|
||||
|
||||
if not glm_model_cards:
|
||||
|
||||
162
src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py
Normal file
162
src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.runner.runner import parse_gpt_oss
|
||||
|
||||
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
|
||||
# These are stable since they come from the model's vocabulary.
|
||||
_CHANNEL = 200005 # <|channel|>
|
||||
_START = 200006 # <|start|>
|
||||
_MESSAGE = 200008 # <|message|>
|
||||
_CALL = 200012 # <|call|>
|
||||
_END = 200007 # <|end|>
|
||||
_ASSISTANT = 173781 # "assistant"
|
||||
|
||||
# fmt: off
|
||||
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
|
||||
FORMAT_A_TOKENS: list[tuple[int, str]] = [
|
||||
(316, " to"),
|
||||
(28, "="),
|
||||
(44580, "functions"),
|
||||
(775, ".get"),
|
||||
(23981, "_current"),
|
||||
(170154, "_weather"),
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(12606, "comment"),
|
||||
(815, "ary"),
|
||||
(5701, " json"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(10848, '{"'),
|
||||
(7693, "location"),
|
||||
(1243, '":'),
|
||||
(392, ' "'),
|
||||
(173844, "Tokyo"),
|
||||
(18583, '"}'),
|
||||
(_CALL, "<|call|>"),
|
||||
]
|
||||
|
||||
# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
|
||||
FORMAT_B_TOKENS: list[tuple[int, str]] = [
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(12606, "comment"),
|
||||
(815, "ary"),
|
||||
(316, " to"),
|
||||
(28, "="),
|
||||
(44580, "functions"),
|
||||
(775, ".get"),
|
||||
(23981, "_current"),
|
||||
(170154, "_weather"),
|
||||
(5701, " json"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(10848, '{"'),
|
||||
(7693, "location"),
|
||||
(1243, '":'),
|
||||
(392, ' "'),
|
||||
(173844, "Tokyo"),
|
||||
(18583, '"}'),
|
||||
(_CALL, "<|call|>"),
|
||||
]
|
||||
|
||||
# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
|
||||
# Full analysis-then-tool-call as the model actually generates it.
|
||||
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(35644, "analysis"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(12845, "Let"),
|
||||
(668, " me"),
|
||||
(2411, " think"),
|
||||
(1078, " about"),
|
||||
(495, " this"),
|
||||
(13, "."),
|
||||
(_END, "<|end|>"),
|
||||
# Model generates a new message header for the tool call:
|
||||
(_START, "<|start|>"),
|
||||
(_ASSISTANT, "assistant"),
|
||||
*FORMAT_B_TOKENS,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _make_gen_responses(
|
||||
tokens: list[tuple[int, str]],
|
||||
) -> list[GenerationResponse]:
|
||||
"""Build GenerationResponse list from (token_id, text) pairs."""
|
||||
responses: list[GenerationResponse] = []
|
||||
for i, (tid, text) in enumerate(tokens):
|
||||
is_last = i == len(tokens) - 1
|
||||
responses.append(
|
||||
GenerationResponse(
|
||||
text=text,
|
||||
token=tid,
|
||||
finish_reason="stop" if is_last else None,
|
||||
usage=None,
|
||||
)
|
||||
)
|
||||
return responses
|
||||
|
||||
|
||||
def _collect(
|
||||
tokens: list[tuple[int, str]],
|
||||
) -> list[GenerationResponse | ToolCallResponse]:
|
||||
"""Feed tokens through parse_gpt_oss and collect all yielded responses."""
|
||||
|
||||
def _gen() -> Generator[GenerationResponse, None, None]:
|
||||
yield from _make_gen_responses(tokens)
|
||||
|
||||
return list(parse_gpt_oss(_gen()))
|
||||
|
||||
|
||||
def _get_tool_call(
|
||||
results: list[GenerationResponse | ToolCallResponse],
|
||||
) -> ToolCallResponse:
|
||||
"""Extract the single ToolCallResponse from results."""
|
||||
tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
|
||||
return tool_calls[0]
|
||||
|
||||
|
||||
class TestParseGptOssRecipientPlacement:
|
||||
"""Both Harmony recipient placements must produce identical tool calls."""
|
||||
|
||||
def test_format_a_yields_tool_call(self):
|
||||
results = _collect(FORMAT_A_TOKENS)
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert '"location"' in tc.tool_calls[0].arguments
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
|
||||
def test_format_b_yields_tool_call(self):
|
||||
results = _collect(FORMAT_B_TOKENS)
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert '"location"' in tc.tool_calls[0].arguments
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
|
||||
def test_both_formats_produce_identical_tool_calls(self):
|
||||
tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
|
||||
tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
|
||||
assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
|
||||
assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
|
||||
|
||||
|
||||
class TestParseGptOssThinkingThenToolCall:
|
||||
"""Analysis (thinking) followed by a tool call must yield both."""
|
||||
|
||||
def test_thinking_then_tool_call(self):
|
||||
results = _collect(THINKING_THEN_TOOL_TOKENS)
|
||||
|
||||
# Should have thinking tags + content + tool call
|
||||
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
|
||||
combined = "".join(text_parts)
|
||||
assert "<think>" in combined
|
||||
assert "</think>" in combined
|
||||
assert "Let me think about this." in combined
|
||||
|
||||
# And the tool call
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
55
tests/eval_tool_calls.sh
Executable file
55
tests/eval_tool_calls.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ $# -lt 1 ] && {
|
||||
echo "Usage: $0 host1 [host2 ...]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
[ -z "$(git status --porcelain)" ] || {
|
||||
echo "Uncommitted changes"
|
||||
exit 1
|
||||
}
|
||||
|
||||
commit=$(git rev-parse HEAD)
|
||||
git fetch -q origin
|
||||
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
|
||||
echo "Not pushed to origin"
|
||||
exit 1
|
||||
}
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
sleep 1
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
}
|
||||
trap 'cleanup' EXIT INT TERM
|
||||
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
|
||||
done
|
||||
wait
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..." 1>&2
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
|
||||
echo "Waiting 30s for cluster setup" 1>&2
|
||||
sleep 30
|
||||
echo "EXO loaded" 1>&2
|
||||
eval_runner="${hosts[0]}"
|
||||
mkdir -p "./bench/$commit"
|
||||
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
|
||||
echo "running eval for $model" 1>&2
|
||||
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
|
||||
>>"./bench/$commit/${model//\//--}-eval.json"
|
||||
echo
|
||||
done
|
||||
691
tool_call_eval.py
Normal file
691
tool_call_eval.py
Normal file
@@ -0,0 +1,691 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tool-calling eval for exo's OpenAI-compatible API.
|
||||
|
||||
Tests whether models correctly:
|
||||
- Trigger tool calls when appropriate
|
||||
- Return valid JSON arguments matching function schemas
|
||||
- Handle multi-turn tool use (call -> result -> final answer)
|
||||
- Avoid calling tools when unnecessary
|
||||
|
||||
Start exo with a model first, then run:
|
||||
uv run python tool_call_eval.py --model <model-id>
|
||||
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
|
||||
uv run python tool_call_eval.py --model <model-id> --repeat 3
|
||||
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature unit",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
CALCULATOR_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Evaluate a mathematical expression and return the numeric result",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "The math expression to evaluate, e.g. '2 + 3 * 4'",
|
||||
},
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_products",
|
||||
"description": "Search for products in a catalog by query, category, and price",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["electronics", "clothing", "food", "books"],
|
||||
"description": "Product category to filter by",
|
||||
},
|
||||
"max_price": {
|
||||
"type": "number",
|
||||
"description": "Maximum price in USD",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scenario:
|
||||
name: str
|
||||
description: str
|
||||
messages: list[dict[str, object]]
|
||||
tools: list[dict[str, object]]
|
||||
expect_tool_call: bool
|
||||
expected_function: str | None = None
|
||||
required_arg_keys: list[str] | None = None
|
||||
# For multi-turn: fake tool result to inject, then verify the follow-up.
|
||||
tool_result: str | None = None
|
||||
|
||||
|
||||
SCENARIOS = [
|
||||
# -- Should call a tool --------------------------------------------------
|
||||
Scenario(
|
||||
name="weather_simple",
|
||||
description="Basic weather query -> get_current_weather",
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather like in Tokyo right now?"}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
Scenario(
|
||||
name="calculator_simple",
|
||||
description="Math question -> calculate",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Use the calculator to compute 3847 * 926 + 17293",
|
||||
}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
),
|
||||
Scenario(
|
||||
name="search_with_filters",
|
||||
description="Product search with category and price filter",
|
||||
messages=[{"role": "user", "content": "Find me electronics under $50"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="search_products",
|
||||
required_arg_keys=["query"],
|
||||
),
|
||||
# -- Multi-turn: tool call then follow-up --------------------------------
|
||||
Scenario(
|
||||
name="weather_multi_turn",
|
||||
description="Weather query -> tool result -> natural language summary",
|
||||
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
tool_result=json.dumps(
|
||||
{
|
||||
"temperature": "18C",
|
||||
"condition": "partly cloudy",
|
||||
"humidity": "65%",
|
||||
"wind": "12 km/h NW",
|
||||
}
|
||||
),
|
||||
),
|
||||
Scenario(
|
||||
name="calculator_multi_turn",
|
||||
description="Math query -> tool result -> model reports the answer",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Use the calculator to compute 1847 * 263 + 5921",
|
||||
}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
tool_result=json.dumps({"result": 491682}),
|
||||
),
|
||||
Scenario(
|
||||
name="search_multi_turn",
|
||||
description="Search query -> tool result -> model summarizes products",
|
||||
messages=[
|
||||
{"role": "user", "content": "Search for books about machine learning"}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="search_products",
|
||||
required_arg_keys=["query"],
|
||||
tool_result=json.dumps(
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"name": "Hands-On Machine Learning",
|
||||
"price": 45.99,
|
||||
"rating": 4.8,
|
||||
},
|
||||
{
|
||||
"name": "Deep Learning with Python",
|
||||
"price": 39.99,
|
||||
"rating": 4.6,
|
||||
},
|
||||
]
|
||||
}
|
||||
),
|
||||
),
|
||||
# -- Sequential tool calls: thinking + tool call, NO final answer ----------
|
||||
# This is the critical scenario for the Harmony recipient placement fix.
|
||||
#
|
||||
# When an assistant message has both thinking content and a tool_call,
|
||||
# AND there is no subsequent final-answer assistant message, the Jinja
|
||||
# template renders BOTH the analysis and the tool call:
|
||||
#
|
||||
# <|start|>assistant<|channel|>analysis<|message|>thinking...<|end|>
|
||||
# <|start|>assistant to=functions.X<|channel|>commentary json<|message|>...<|call|>
|
||||
#
|
||||
# The two consecutive assistant messages have INCONSISTENT start patterns
|
||||
# (one has <|channel|> immediately, the other has to= first).
|
||||
# This confuses the model when it needs to generate its own tool call.
|
||||
#
|
||||
# The reformat fix makes both start with <|start|>assistant<|channel|>,
|
||||
# only differing in the channel name (analysis vs commentary).
|
||||
Scenario(
|
||||
name="chained_tool_calls_same",
|
||||
description="Thinking + weather(Tokyo) -> result -> model must call weather(London)",
|
||||
messages=[
|
||||
{"role": "user", "content": "Compare the weather in Tokyo and London."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll check both cities. Let me start with Tokyo.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Tokyo"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
Scenario(
|
||||
name="chained_tool_calls_different",
|
||||
description="Thinking + weather(Berlin) -> result -> model must call calculator",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll handle both. Let me check Berlin's weather first.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Berlin"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_2",
|
||||
"content": json.dumps({"temperature": "12C", "condition": "rainy"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
),
|
||||
Scenario(
|
||||
name="chained_tool_calls_three",
|
||||
description="Two prior thinking+tool calls -> results -> model must make a third",
|
||||
messages=[
|
||||
{"role": "user", "content": "Compare weather in Tokyo, Paris, and London."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll check all three cities. Starting with Tokyo.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Tokyo"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_3",
|
||||
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Got Tokyo. Now checking Paris.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_4",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Paris"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_4",
|
||||
"content": json.dumps({"temperature": "18C", "condition": "cloudy"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
# -- Should NOT call a tool ----------------------------------------------
|
||||
Scenario(
|
||||
name="no_tool_joke",
|
||||
description="Joke request should NOT trigger any tool",
|
||||
messages=[{"role": "user", "content": "Tell me a funny joke about cats."}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=False,
|
||||
),
|
||||
Scenario(
|
||||
name="no_tool_factual",
|
||||
description="Factual question answerable from training data",
|
||||
messages=[{"role": "user", "content": "What is the capital of Japan?"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=False,
|
||||
),
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
phase: str # "tool_call" or "follow_up"
|
||||
passed: bool
|
||||
checks: dict[str, bool] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
latency_ms: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evaluation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str | None]:
|
||||
"""Parse JSON arguments and check required keys exist."""
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
return False, f"Invalid JSON: {e}"
|
||||
if not isinstance(args, dict):
|
||||
return False, f"Expected dict, got {type(args).__name__}"
|
||||
missing = [k for k in required_keys if k not in args]
|
||||
if missing:
|
||||
return False, f"Missing keys: {missing}"
|
||||
return True, None
|
||||
|
||||
|
||||
def call_api(
|
||||
client: httpx.Client,
|
||||
base_url: str,
|
||||
model: str,
|
||||
messages: list[dict[str, object]],
|
||||
tools: list[dict[str, object]],
|
||||
timeout: float,
|
||||
) -> tuple[dict[str, object], float]:
|
||||
"""POST to /chat/completions, return (response_json, latency_ms)."""
|
||||
url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
body: dict[str, object] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
t0 = time.monotonic()
|
||||
resp = client.post(url, json=body, timeout=timeout)
|
||||
latency = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
return resp.json(), latency
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_scenario(
|
||||
client: httpx.Client,
|
||||
base_url: str,
|
||||
model: str,
|
||||
scenario: Scenario,
|
||||
timeout: float,
|
||||
verbose: bool,
|
||||
) -> list[ScenarioResult]:
|
||||
results: list[ScenarioResult] = []
|
||||
|
||||
# --- Phase 1: initial request ---
|
||||
try:
|
||||
data, latency = call_api(
|
||||
client, base_url, model, scenario.messages, scenario.tools, timeout
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="tool_call",
|
||||
passed=False,
|
||||
error=f"API error: {e}",
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if verbose:
|
||||
print(f" response: {json.dumps(data, indent=2)}")
|
||||
|
||||
choice = data["choices"][0]
|
||||
finish_reason = choice.get("finish_reason")
|
||||
message = choice.get("message", {})
|
||||
tool_calls = message.get("tool_calls")
|
||||
content = message.get("content")
|
||||
|
||||
checks: dict[str, bool] = {}
|
||||
|
||||
if scenario.expect_tool_call:
|
||||
checks["finish_reason_tool_calls"] = finish_reason == "tool_calls"
|
||||
checks["has_tool_call"] = isinstance(tool_calls, list) and len(tool_calls) > 0
|
||||
|
||||
args_err: str | None = None
|
||||
if checks["has_tool_call"]:
|
||||
tc = tool_calls[0]
|
||||
fn = tc.get("function", {})
|
||||
checks["correct_function"] = (
|
||||
scenario.expected_function is None
|
||||
or fn.get("name") == scenario.expected_function
|
||||
)
|
||||
if scenario.required_arg_keys:
|
||||
ok, args_err = validate_args(
|
||||
fn.get("arguments", ""), scenario.required_arg_keys
|
||||
)
|
||||
checks["valid_arguments"] = ok
|
||||
else:
|
||||
checks["valid_arguments"] = True
|
||||
else:
|
||||
checks["correct_function"] = False
|
||||
checks["valid_arguments"] = False
|
||||
args_err = "No tool call returned"
|
||||
|
||||
passed = all(checks.values())
|
||||
error = args_err if not passed else None
|
||||
else:
|
||||
checks["finish_reason_stop"] = finish_reason == "stop"
|
||||
checks["no_tool_call"] = tool_calls is None or len(tool_calls) == 0
|
||||
checks["has_content"] = isinstance(content, str) and len(content.strip()) > 0
|
||||
passed = all(checks.values())
|
||||
error = (
|
||||
None
|
||||
if passed
|
||||
else (
|
||||
f"finish_reason={finish_reason}, "
|
||||
f"tool_calls={'yes' if tool_calls else 'no'}, "
|
||||
f"content={'yes' if content else 'no'}"
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="tool_call",
|
||||
passed=passed,
|
||||
checks=checks,
|
||||
error=error,
|
||||
latency_ms=latency,
|
||||
)
|
||||
)
|
||||
|
||||
# --- Phase 2: multi-turn follow-up ---
|
||||
if scenario.tool_result is not None and checks.get("has_tool_call"):
|
||||
tc = tool_calls[0]
|
||||
fn = tc.get("function", {})
|
||||
follow_up_messages: list[dict[str, object]] = list(scenario.messages) + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.get("id", "call_0"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fn.get("name", ""),
|
||||
"arguments": fn.get("arguments", "{}"),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.get("id", "call_0"),
|
||||
"content": scenario.tool_result,
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
data2, latency2 = call_api(
|
||||
client,
|
||||
base_url,
|
||||
model,
|
||||
follow_up_messages,
|
||||
scenario.tools,
|
||||
timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="follow_up",
|
||||
passed=False,
|
||||
error=f"API error: {e}",
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if verbose:
|
||||
print(f" follow_up response: {json.dumps(data2, indent=2)}")
|
||||
|
||||
choice2 = data2["choices"][0]
|
||||
message2 = choice2.get("message", {})
|
||||
checks2: dict[str, bool] = {}
|
||||
checks2["finish_reason_stop"] = choice2.get("finish_reason") == "stop"
|
||||
tc2 = message2.get("tool_calls")
|
||||
checks2["no_tool_call"] = tc2 is None or len(tc2) == 0
|
||||
c2 = message2.get("content")
|
||||
checks2["has_content"] = isinstance(c2, str) and len(c2.strip()) > 0
|
||||
|
||||
passed2 = all(checks2.values())
|
||||
error2 = None
|
||||
if not passed2:
|
||||
error2 = (
|
||||
f"finish_reason={choice2.get('finish_reason')}, "
|
||||
f"tool_calls={'yes' if tc2 else 'no'}, "
|
||||
f"content={'yes' if c2 else 'no'}"
|
||||
)
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="follow_up",
|
||||
passed=passed2,
|
||||
checks=checks2,
|
||||
error=error2,
|
||||
latency_ms=latency2,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Tool-calling eval for exo")
|
||||
parser.add_argument("--model", required=True, help="Model ID to test")
|
||||
parser.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout", type=float, default=120, help="Per-request timeout (seconds)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repeat each scenario N times"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scenarios", nargs="*", help="Run only these scenarios (by name)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print full API responses"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
scenarios = SCENARIOS
|
||||
if args.scenarios:
|
||||
scenarios = [s for s in SCENARIOS if s.name in args.scenarios]
|
||||
if not scenarios:
|
||||
print(f"No matching scenarios. Available: {[s.name for s in SCENARIOS]}")
|
||||
sys.exit(1)
|
||||
|
||||
base_url = f"http://{args.host}:{args.port}/v1"
|
||||
total_runs = len(scenarios) * args.repeat
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Endpoint: {base_url}")
|
||||
print(f"Scenarios: {len(scenarios)} x {args.repeat} = {total_runs} runs")
|
||||
print("=" * 64)
|
||||
|
||||
all_results: list[ScenarioResult] = []
|
||||
|
||||
with httpx.Client() as client:
|
||||
for run_idx in range(args.repeat):
|
||||
if args.repeat > 1:
|
||||
print(f"\n--- Run {run_idx + 1}/{args.repeat} ---")
|
||||
|
||||
for scenario in scenarios:
|
||||
print(f"\n {scenario.name}: {scenario.description}")
|
||||
|
||||
results = run_scenario(
|
||||
client,
|
||||
base_url,
|
||||
args.model,
|
||||
scenario,
|
||||
args.timeout,
|
||||
args.verbose,
|
||||
)
|
||||
all_results.extend(results)
|
||||
|
||||
for r in results:
|
||||
status = "PASS" if r.passed else "FAIL"
|
||||
print(f" [{r.phase:>10}] {status} ({r.latency_ms:.0f}ms)")
|
||||
for check_name, check_ok in r.checks.items():
|
||||
mark = "+" if check_ok else "-"
|
||||
print(f" {mark} {check_name}")
|
||||
if r.error:
|
||||
print(f" ! {r.error}")
|
||||
|
||||
# --- Summary ---
|
||||
print(f"\n{'=' * 64}")
|
||||
|
||||
total = len(all_results)
|
||||
passed = sum(1 for r in all_results if r.passed)
|
||||
|
||||
tool_call_results = [r for r in all_results if r.phase == "tool_call"]
|
||||
follow_up_results = [r for r in all_results if r.phase == "follow_up"]
|
||||
tc_passed = sum(1 for r in tool_call_results if r.passed)
|
||||
fu_passed = sum(1 for r in follow_up_results if r.passed)
|
||||
avg_latency = sum(r.latency_ms for r in all_results) / total if total else 0
|
||||
|
||||
print(f"Total: {passed}/{total} passed ({100 * passed / total:.0f}%)")
|
||||
print(f"Tool call: {tc_passed}/{len(tool_call_results)} passed")
|
||||
if follow_up_results:
|
||||
print(f"Follow-up: {fu_passed}/{len(follow_up_results)} passed")
|
||||
print(f"Avg latency: {avg_latency:.0f}ms")
|
||||
|
||||
if passed < total:
|
||||
print("\nFailed:")
|
||||
for r in all_results:
|
||||
if not r.passed:
|
||||
print(f" - {r.name} [{r.phase}]: {r.error}")
|
||||
|
||||
sys.exit(0 if passed == total else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -447,6 +447,7 @@ name = "exo-bench"
|
||||
version = "0.1.0"
|
||||
source = { editable = "bench" }
|
||||
dependencies = [
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -456,6 +457,7 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "jinja2", specifier = ">=3.1.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
|
||||
Reference in New Issue
Block a user