Compare commits


6 Commits

Author SHA1 Message Date
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905


## Test Plan

### Manual Testing
**Without filter**
![Screenshot 2026-02-18 at 22 26 22](https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7)

**With filter**
![Screenshot 2026-02-18 at 22 26 45](https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe)
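
The filter can also be checked programmatically. A minimal sketch, assuming the API is reachable on `localhost:52415` (the port the bench tooling defaults to); only the `?status=downloaded` query parameter comes from this PR, the rest is generic `http.client` usage:

```python
# Hedged usage sketch: list only models whose downloads have completed.
# Host/port are assumptions (bench defaults); field names follow the
# existing /models consumers in this repo.
import http.client
import json

conn = http.client.HTTPConnection("localhost", 52415, timeout=10)
try:
    conn.request("GET", "/models?status=downloaded", headers={"Accept": "application/json"})
    resp = conn.getresponse()
    models = json.loads(resp.read().decode("utf-8"))
    for m in models.get("data", []):
        print(m.get("name"), m.get("hugging_face_id"))
finally:
    conn.close()
```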
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds
- Returns early on first success for minimal latency in the common case

Fixes #1425
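
A minimal standalone sketch of the fallback described in the summary above (the real change is in `DownloadCoordinator._test_internet_connection`, shown in the diff further down); endpoints and timeout mirror the PR, everything else is illustrative:

```python
# Sketch only — not the exo implementation. Probe several well-known
# endpoints and treat the first successful TCP connect as "online".
import socket

def has_internet(hosts=("1.1.1.1", "8.8.8.8", "1.0.0.1"), port=443, timeout=3) -> bool:
    for host in hosts:
        try:
            socket.create_connection((host, port), timeout=timeout).close()
            return True  # early return on the first reachable endpoint
        except OSError:
            continue  # this endpoint may be blocked; try the next
    return False
```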

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
rltakashige
24e99ce197 Cleanup mistakes (#1537)
Oops
2026-02-18 22:05:26 +00:00
Alex Cheema
315992549b fix: unblock MpReceiver.close() to prevent shutdown hang (#1511)
## Summary

- `MpReceiver.close()` did not unblock threads stuck on `queue.get()` in
`receive_async()`, causing abandoned threads (via
`abandon_on_cancel=True`) to keep the Python process alive indefinitely
after tests pass
- This caused the `aarch64-darwin` CI jobs in PR #1462 to hang for ~6
hours until the GitHub Actions timeout killed them
- Sends an `_MpEndOfStream` sentinel before closing the buffer,
mirroring what `MpSender.close()` already does
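
A self-contained sketch of the pattern, with illustrative names rather than exo's actual `MpReceiver` internals: push a sentinel onto the queue before closing it so a thread blocked in `get()` wakes up instead of keeping the process alive.

```python
# Hedged sketch of "sentinel before close"; _EndOfStream and consume() are
# illustrative stand-ins, not exo's channel implementation.
import multiprocessing as mp
import threading

class _EndOfStream:
    """Sentinel telling a blocked consumer the stream is finished."""

def consume(q) -> None:
    while True:
        item = q.get()  # would block forever without the sentinel
        if isinstance(item, _EndOfStream):
            return
        print("got", item)

if __name__ == "__main__":
    q = mp.Queue()
    t = threading.Thread(target=consume, args=(q,))
    t.start()
    q.put("work")
    q.put(_EndOfStream())  # unblock the consumer before closing
    q.close()
    t.join()  # returns promptly instead of hanging at shutdown
```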

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt` — 0 changed
- [x] `uv run pytest` — 188 passed, 1 skipped in 12s (no hang)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-18 21:59:02 +00:00
Alex Cheema
ce5a65d3b9 Add MiniMax M2.5 model cards (#1514)
## Summary
- Adds model cards for MiniMax M2.5 in three quantizations: 4bit (~129
GB), 6bit (~186 GB), 8bit (~243 GB)
- No code changes needed — `MiniMaxM2ForCausalLM` is already in the
tensor parallel whitelist and `MiniMaxShardingStrategy` is already
implemented in `auto_parallel.py`
- Credit to @vskiwi for confirming MiniMax M2.5 works out of the box
with existing code

Closes #1480

## Test plan
- [x] `basedpyright` passes with 0 errors
- [x] `ruff check` passes
- [x] `pytest` passes (260 passed, 1 skipped)
- [ ] Verify MiniMax M2.5 models appear in model selector on dashboard

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 21:11:13 +00:00
rltakashige
c2f2111b88 Fix tool calling (#1529)
## Motivation

GPT OSS tool calling issues.

## Changes

Fixes those issues and adds a suite of evals for tool calling.
Fixes GLM5 prefix caching, where `CacheList` wasn't being handled
properly.
Extracts much of exo bench's setup functionality into a shared harness
that can be reused elsewhere, such as in the tool-calling eval.
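
A hedged sketch of reusing that harness from another script, using only helpers that `bench/harness.py` defines in this diff (the model id is a placeholder and the script is assumed to run from `bench/`):

```python
# Sketch: drive the bench harness from another eval script. All imports exist
# in bench/harness.py as added by this PR; "some-model-id" is a placeholder.
import argparse

from harness import (
    ExoClient,
    add_common_instance_args,
    resolve_model_short_id,
    settle_and_fetch_placements,
)

ap = argparse.ArgumentParser(prog="my-eval")
add_common_instance_args(ap)  # --host/--port/--model plus placement filters
args = ap.parse_args(["--model", "some-model-id"])

client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_id = resolve_model_short_id(client, args.model)
placements = settle_and_fetch_placements(
    client, full_id, args, settle_timeout=args.settle_timeout
)
print(f"{short_id}: {len(placements)} candidate placement(s)")
```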

## Test Plan
### Automated Testing
Let's run the evals for all models
2026-02-18 20:29:18 +00:00
35 changed files with 2407 additions and 581 deletions

View File

@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict, strict: bool = ...) -> Module:
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable,
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

bench/eval_tool_calls.py (new file, 1046 lines)
View File

File diff suppressed because it is too large.

View File

@@ -1,29 +1,47 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -103,154 +121,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +139,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +230,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +248,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +257,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +271,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +305,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,16 +332,6 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

bench/harness.py (new file, 327 lines)
View File

@@ -0,0 +1,327 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

bench/scenarios.toml (new file, 240 lines)
View File

@@ -0,0 +1,240 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"

View File

@@ -158,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,6 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -21,7 +22,7 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -45,8 +46,8 @@ class DownloadCoordinator:
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -122,14 +123,17 @@ class DownloadCoordinator:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True
@@ -294,16 +298,15 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self._system_id,
origin=self.node_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)

View File

@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self
import anyio
from anyio.abc import TaskGroup
@@ -37,11 +38,12 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> Self:
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -55,6 +57,9 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -63,6 +68,7 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
@@ -89,6 +95,7 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -126,6 +133,7 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
@@ -205,6 +213,7 @@ class Node:
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -215,6 +224,7 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
@@ -232,6 +242,7 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -131,7 +131,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -145,6 +145,7 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -183,7 +184,6 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -234,7 +234,6 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -548,7 +547,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -893,7 +892,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -979,7 +978,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1294,8 +1293,18 @@ class API:
return total_available
async def get_models(self) -> ModelList:
"""Returns list of available models."""
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
"""Returns list of available models, optionally filtered by being downloaded."""
cards = await get_model_cards()
if status == "downloaded":
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [c for c in cards if c.model_id in downloaded_model_ids]
return ModelList(
data=[
ModelListModel(
@@ -1313,7 +1322,7 @@ class API:
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in cards
]
)
@@ -1410,7 +1419,7 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
@@ -1474,12 +1483,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self._system_id, command=command)
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(

View File

@@ -29,7 +29,7 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -90,8 +90,7 @@ class Master:
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -289,7 +288,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id, command=cmd
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -416,7 +415,7 @@ class Master:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=self._system_id,
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -429,7 +428,7 @@ class Master:
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self._system_id,
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,

View File

@@ -15,7 +15,7 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -75,12 +75,13 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=SystemId("Worker"),
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
@@ -107,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
PlaceInstance(
command_id=CommandId(),
@@ -132,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: DownloadCommand

View File

@@ -25,10 +25,6 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -166,6 +166,6 @@ class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: SystemId
origin: NodeId
session: SessionId
event: Event

View File

@@ -4,10 +4,13 @@ from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]

View File

@@ -1,3 +1,4 @@
import contextlib
import multiprocessing as mp
from dataclasses import dataclass, field
from math import inf
@@ -132,7 +133,8 @@ class MpSender[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
self._state.buffer.put(_MpEndOfStream())
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==
@@ -204,6 +206,8 @@ class MpReceiver[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
@@ -17,10 +18,22 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
@@ -64,7 +77,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None = None):
def __init__(self, group: mx.distributed.Group | None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
@@ -156,15 +169,15 @@ class KVPrefixCache:
best_length = 0
is_exact = False
# Find best cache
# Find best cache match
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(prompt_tokens, cached_prompt)
if length >= max_length - 1:
best_index, best_length = i, length
is_exact = True
break
if length > best_length:
best_index, best_length = i, length
if length == max_length:
is_exact = True
best_index, best_length = i, length
break
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
@@ -172,11 +185,12 @@ class KVPrefixCache:
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
target = (max_length - 1) if is_exact else best_length
has_ssm = has_non_kv_caches(self.caches[best_index])
target = (max_length - 1) if is_exact and not has_ssm else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
# No usable snapshot — need fresh cache
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
if restore_snap is None and has_ssm:
return make_kv_cache(model), prompt_tokens, None
prompt_cache = deepcopy(self.caches[best_index])
@@ -257,10 +271,21 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(prompt_tokens)
def _entry_length(
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
if hasattr(c, "offset"):
return c.offset
# For CacheList
if hasattr(c, "size"):
return int(c.size()) # type: ignore
return 0
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
return max(getattr(c, "offset", 0) for c in cache)
return max(_entry_length(c) for c in cache)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:

View File

@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
def prefill(
@@ -57,7 +57,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -133,7 +133,7 @@ def prefill(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -255,8 +255,8 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -436,9 +436,14 @@ def mlx_generate(
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index,

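The update gate added in the mlx_generate hunk above combines an absolute and a relative threshold. A self-contained restatement, with the thresholds copied from the diff; should_update_cache is an illustrative name, not a function in exo:

```python
_MIN_PREFIX_HIT_TO_UPDATE = 1000
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5

def should_update_cache(
    prefix_hit_length: int, prompt_length: int, matched_index: int | None
) -> bool:
    # Only refresh an existing cache entry when the hit was both large in
    # absolute terms (>= 1000 tokens) and covered at least half the prompt,
    # so short or mostly-missed prompts don't churn stored entries.
    if matched_index is None:
        return False
    hit_ratio = prefix_hit_length / prompt_length if prompt_length > 0 else 0.0
    return (
        prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
        and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
    )

assert should_update_cache(1200, 2000, matched_index=0) is True   # 60% of a long prompt
assert should_update_cache(1200, 5000, matched_index=0) is False  # long hit, but only 24%
assert should_update_cache(400, 600, matched_index=0) is False    # high ratio, short hit
```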

@@ -292,6 +292,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None
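A hedged sketch of how a per-model EOS list like this is typically consumed during decoding; the stop check below is illustrative rather than exo's actual call site:

```python
# 200012 is <|call|> (see the parser test fixtures further down);
# 200002 is Harmony's end-of-response token.
GPT_OSS_EOS_TOKEN_IDS = [200002, 200012]

def is_stop_token(token_id: int, eos_token_ids: list[int] | None) -> bool:
    # With no model-specific override, callers fall back to the tokenizer's own EOS handling.
    return eos_token_ids is not None and token_id in eos_token_ids

assert is_stop_token(200012, GPT_OSS_EOS_TOKEN_IDS)
assert not is_stop_token(42, GPT_OSS_EOS_TOKEN_IDS)
```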


@@ -1,6 +1,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -16,7 +17,7 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
@@ -63,12 +64,14 @@ class Worker:
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
@@ -83,8 +86,6 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -131,7 +132,7 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
@@ -211,7 +212,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id,
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -311,7 +312,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
origin=self.node_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -338,16 +339,15 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self._system_id,
origin=self.node_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
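The event_index_counter threaded through the constructor is just a shared iterator. A minimal sketch of the idea using itertools.count, which is one natural way to build such a counter; the factory name below is hypothetical, and the diff does not show how exo actually constructs it:

```python
import itertools
from typing import Iterator

def make_event_index_counter() -> Iterator[int]:
    # A single counter shared by every component that publishes events, so
    # events published from different places in the same process get
    # distinct, monotonically increasing origin_idx values.
    return itertools.count()

counter = make_event_index_counter()
first = [next(counter) for _ in range(3)]   # e.g. events from one publisher
second = [next(counter) for _ in range(2)]  # later events from another publisher
assert first == [0, 1, 2] and second == [3, 4]
```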


@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -588,7 +589,11 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel


@@ -103,7 +103,7 @@ class RunnerSupervisor:
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(1)
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
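The longer join(5) gives the runner more time to exit cleanly before anything harsher happens. A generic escalation pattern along those lines, offered as a sketch rather than the full RunnerSupervisor logic:

```python
import multiprocessing as mp

def stop_process(proc: mp.Process, grace_seconds: float = 5.0) -> None:
    # First wait for a clean exit, then escalate only if the process is still alive.
    proc.join(grace_seconds)
    if not proc.is_alive():
        return
    proc.terminate()   # ask politely (SIGTERM on POSIX)
    proc.join(1.0)
    if proc.is_alive():
        proc.kill()    # last resort (SIGKILL on POSIX)
        proc.join()
```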


@@ -123,7 +123,12 @@ def run_gpt_oss_pipeline_device(
generated_text = ""
for response in mlx_generate(
model=model, tokenizer=tokenizer, task=task, prompt=prompt
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:
@@ -194,6 +199,8 @@ def run_gpt_oss_tensor_parallel_device(
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:


@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.clear()
assert len(cache.prompts) == 0
@@ -142,7 +142,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
@@ -161,9 +163,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
@@ -176,9 +180,11 @@ class TestKVPrefixCacheWithModel:
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
# Exact match returns last token(s) — for models with SSM/rotating caches,
# snapshot availability constrains how far back we can trim, so remaining
# may be 1 or 2 tokens depending on the model.
assert len(remaining_tokens) >= 1
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
@@ -194,10 +200,10 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
@@ -238,9 +244,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -276,9 +284,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -301,7 +311,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -318,6 +328,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
generated_tokens += 1
@@ -331,7 +342,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -347,6 +358,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -368,7 +380,7 @@ class TestKVPrefixCacheWithModel:
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -395,6 +407,7 @@ class TestKVPrefixCacheWithModel:
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
first_gen_time = time.perf_counter() - t0
@@ -427,6 +440,7 @@ class TestKVPrefixCacheWithModel:
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
second_gen_time = time.perf_counter() - t0
@@ -447,7 +461,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -462,6 +476,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -474,6 +489,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -484,7 +500,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -497,7 +513,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -522,7 +538,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)


@@ -0,0 +1,297 @@
import copy
import gc
import importlib
import json
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
import pytest
from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
load_tokenizer_for_model_id,
)
HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
# ── Config reduction ──────────────────────────────────────────────────────── #
_REDUCE = {
"num_hidden_layers": 4,
"hidden_size": 256,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 512,
"moe_intermediate_size": 128,
"num_experts": 4,
"num_experts_per_tok": 2,
"n_routed_experts": 4,
"num_local_experts": 4,
"num_nextn_predict_layers": 0,
"first_k_dense_replace": 0,
"linear_num_key_heads": 2,
"linear_num_value_heads": 2,
"num_attention_groups": 4,
}
def _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:
result = dict(cfg)
for key, val in _REDUCE.items():
if key in result:
result[key] = val
return result
def _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:
result = _reduce_dict(cfg)
n_layers = cast(int, result.get("num_hidden_layers", 4))
if "text_config" in result and isinstance(result["text_config"], dict):
result["text_config"] = _reduce_dict(
cast(dict[str, Any], result["text_config"])
)
tc: dict[str, Any] = result["text_config"]
if "num_nextn_predict_layers" in tc:
tc["num_nextn_predict_layers"] = 0
if "layer_types" in result and isinstance(result["layer_types"], list):
result["layer_types"] = result["layer_types"][:n_layers]
if "attention_other_setting" in result and isinstance(
result["attention_other_setting"], dict
):
aos: dict[str, Any] = dict(
cast(dict[str, Any], result["attention_other_setting"])
)
if "num_attention_heads" in aos:
aos["num_attention_heads"] = result.get("num_attention_heads", 4)
if "num_attention_groups" in aos:
aos["num_attention_groups"] = result.get(
"num_attention_groups", cast(int, aos["num_attention_groups"])
)
result["attention_other_setting"] = aos
if "moe_layers_enum" in result and isinstance(result["moe_layers_enum"], str):
indices = [int(x) for x in result["moe_layers_enum"].split(",") if x.strip()]
valid = [i for i in indices if i < n_layers]
result["moe_layers_enum"] = ",".join(str(i) for i in valid) if valid else ""
return result
# ── Helpers ───────────────────────────────────────────────────────────────── #
def _find_snapshot(hub_name: str) -> Path | None:
model_dir = HF_CACHE / f"models--mlx-community--{hub_name}"
snaps = model_dir / "snapshots"
if not snaps.exists():
return None
children = sorted(snaps.iterdir())
return children[0] if children else None
def _copy_tokenizer(src: Path, dst: Path) -> None:
for f in src.iterdir():
name = f.name
if (
"tokeniz" in name.lower()
or "tiktoken" in name.lower()
or name.startswith("vocab")
or name.endswith(".jinja")
or "tool_declaration" in name
) and f.is_file():
shutil.copy2(f, dst / name)
def _build_model(module_name: str, cfg: dict[str, Any]) -> Model:
mod = importlib.import_module(f"mlx_lm.models.{module_name}")
args = mod.ModelArgs.from_dict(cfg) # pyright: ignore[reportAny]
model: nn.Module = mod.Model(args) # pyright: ignore[reportAny]
flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))
random_weights = [
(k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat
]
model.update(cast(dict[str, Any], tree_unflatten(random_weights)))
mx.eval(model.parameters())
return cast(Model, model)
def _collect_tokens(
model: Model,
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None,
) -> list[int]:
tokens: list[int] = []
for resp in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
tokens.append(resp.token)
if resp.finish_reason is not None:
break
return tokens
# ── Architecture definitions ──────────────────────────────────────────────── #
@dataclass(frozen=True)
class ArchSpec:
name: str
hub_name: str
module: str
tokenizer_hub: str | None = None # fallback for models without bundled tokenizer
ARCHITECTURES: list[ArchSpec] = [
ArchSpec("llama", "Llama-3.2-1B-Instruct-4bit", "llama"),
ArchSpec("glm_moe_dsa", "GLM-5-MXFP4-Q8", "glm_moe_dsa"),
ArchSpec(
"glm4_moe", "GLM-4.5-Air-8bit", "glm4_moe", tokenizer_hub="GLM-4.7-8bit-gs32"
),
ArchSpec(
"glm4_moe_lite",
"GLM-4.7-Flash-8bit",
"glm4_moe_lite",
tokenizer_hub="GLM-4.7-8bit-gs32",
),
ArchSpec("glm4_moe_47", "GLM-4.7-8bit-gs32", "glm4_moe"),
ArchSpec("qwen3", "Qwen3-4B-Instruct-2507-4bit", "qwen3"),
ArchSpec("qwen3_moe", "Qwen3-30B-A3B-4bit", "qwen3_moe"),
ArchSpec("qwen3_next", "Qwen3-Next-80B-A3B-Thinking-4bit", "qwen3_next"),
ArchSpec("minimax", "MiniMax-M2.1-3bit", "minimax"),
ArchSpec("gpt_oss", "gpt-oss-20b-MXFP4-Q8", "gpt_oss"),
ArchSpec("step3p5", "Step-3.5-Flash-4bit", "step3p5"),
ArchSpec("kimi_k25", "Kimi-K2.5", "kimi_k25"),
]
def _arch_available(spec: ArchSpec) -> bool:
snap = _find_snapshot(spec.hub_name)
if snap is None:
return False
if spec.tokenizer_hub is not None:
return _find_snapshot(spec.tokenizer_hub) is not None
return True
def _make_task() -> TextGenerationTaskParams:
return TextGenerationTaskParams(
model=ModelId("test"),
input=[
InputMessage(
role="user",
content="Use the calculator to compute 1847 * 263 + 5921",
)
],
max_output_tokens=20,
temperature=0.0,
tools=[
{
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a mathematical expression",
"parameters": {
"type": "object",
"properties": {"expression": {"type": "string"}},
"required": ["expression"],
},
},
}
],
)
# ── Test class ────────────────────────────────────────────────────────────── #
@pytest.mark.slow
class TestPrefixCacheArchitectures:
"""Verify prefix cache produces identical output to fresh generation for every architecture."""
@pytest.fixture(autouse=True)
def _cleanup(self):
yield
mx.clear_cache()
gc.collect()
@pytest.mark.parametrize(
"spec",
ARCHITECTURES,
ids=[a.name for a in ARCHITECTURES],
)
def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:
if not _arch_available(spec):
pytest.skip(f"Model {spec.hub_name} not cached locally")
snapshot = _find_snapshot(spec.hub_name)
assert snapshot is not None
tmpdir = Path(tempfile.mkdtemp(prefix=f"exo_test_{spec.name}_"))
try:
# Build reduced config
with open(snapshot / "config.json") as f:
cfg = cast(dict[str, Any], json.load(f))
reduced = _reduce_config(copy.deepcopy(cfg))
(tmpdir / "config.json").write_text(json.dumps(reduced))
# Copy tokenizer
tok_src = snapshot
if spec.tokenizer_hub is not None:
alt = _find_snapshot(spec.tokenizer_hub)
if alt is not None:
tok_src = alt
_copy_tokenizer(tok_src, tmpdir)
# Load tokenizer and model
model_id = ModelId(f"mlx-community/{spec.hub_name}")
tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)
mx.random.seed(0)
model = _build_model(spec.module, reduced)
task = _make_task()
prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)
# Run 1: fresh
mx.random.seed(42)
fresh = _collect_tokens(model, tokenizer, task, prompt, None)
assert len(fresh) > 0, "Fresh generation produced no tokens"
# Run 2: populate cache
kv = KVPrefixCache(None)
mx.random.seed(42)
populate = _collect_tokens(model, tokenizer, task, prompt, kv)
# Run 3: exact cache hit
mx.random.seed(42)
cached = _collect_tokens(model, tokenizer, task, prompt, kv)
assert fresh == populate, (
f"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}"
)
assert fresh == cached, (
f"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}"
)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
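As a quick worked example of the config reduction defined in this test file, here is _reduce_dict applied to a hypothetical full-size config; the input values are made up for illustration:

```python
full_config = {
    "num_hidden_layers": 48,
    "hidden_size": 4096,
    "vocab_size": 151936,       # keys not listed in _REDUCE pass through untouched
    "num_attention_heads": 32,
}
reduced = _reduce_dict(full_config)
assert reduced == {
    "num_hidden_layers": 4,
    "hidden_size": 256,
    "vocab_size": 151936,
    "num_attention_heads": 4,
}
```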


@@ -343,8 +343,16 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
def contains(card: ModelCard, x: str):
return x in card.model_id.lower()
glm_model_cards = [
card for card in await get_model_cards() if "glm" in card.model_id.lower()
card
for card in await get_model_cards()
if contains(card, "glm")
and not contains(card, "-5")
and not contains(card, "4.7")
]
if not glm_model_cards:


@@ -0,0 +1,162 @@
from collections.abc import Generator
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.runner import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
_CHANNEL = 200005 # <|channel|>
_START = 200006 # <|start|>
_MESSAGE = 200008 # <|message|>
_CALL = 200012 # <|call|>
_END = 200007 # <|end|>
_ASSISTANT = 173781 # "assistant"
# fmt: off
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_A_TOKENS: list[tuple[int, str]] = [
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_B_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
# Full analysis-then-tool-call as the model actually generates it.
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(35644, "analysis"),
(_MESSAGE, "<|message|>"),
(12845, "Let"),
(668, " me"),
(2411, " think"),
(1078, " about"),
(495, " this"),
(13, "."),
(_END, "<|end|>"),
# Model generates a new message header for the tool call:
(_START, "<|start|>"),
(_ASSISTANT, "assistant"),
*FORMAT_B_TOKENS,
]
# fmt: on
def _make_gen_responses(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse]:
"""Build GenerationResponse list from (token_id, text) pairs."""
responses: list[GenerationResponse] = []
for i, (tid, text) in enumerate(tokens):
is_last = i == len(tokens) - 1
responses.append(
GenerationResponse(
text=text,
token=tid,
finish_reason="stop" if is_last else None,
usage=None,
)
)
return responses
def _collect(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse | ToolCallResponse]:
"""Feed tokens through parse_gpt_oss and collect all yielded responses."""
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(parse_gpt_oss(_gen()))
def _get_tool_call(
results: list[GenerationResponse | ToolCallResponse],
) -> ToolCallResponse:
"""Extract the single ToolCallResponse from results."""
tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
return tool_calls[0]
class TestParseGptOssRecipientPlacement:
"""Both Harmony recipient placements must produce identical tool calls."""
def test_format_a_yields_tool_call(self):
results = _collect(FORMAT_A_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_format_b_yields_tool_call(self):
results = _collect(FORMAT_B_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_both_formats_produce_identical_tool_calls(self):
tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
class TestParseGptOssThinkingThenToolCall:
"""Analysis (thinking) followed by a tool call must yield both."""
def test_thinking_then_tool_call(self):
results = _collect(THINKING_THEN_TOOL_TOKENS)
# Should have thinking tags + content + tool call
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
combined = "".join(text_parts)
assert "<think>" in combined
assert "</think>" in combined
assert "Let me think about this." in combined
# And the tool call
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert "Tokyo" in tc.tool_calls[0].arguments

tests/eval_tool_calls.sh (new executable file, 55 lines)

@@ -0,0 +1,55 @@
#!/usr/bin/env bash
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
sleep 1
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done
for host; do
echo "Waiting for $host..." 1>&2
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
eval_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
echo "running eval for $model" 1>&2
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
>>"./bench/$commit/${model//\//--}-eval.json"
echo
done

uv.lock (generated file, 2 changed lines)

@@ -447,6 +447,7 @@ name = "exo-bench"
version = "0.1.0"
source = { editable = "bench" }
dependencies = [
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -456,6 +457,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "jinja2", specifier = ">=3.1.0" },
{ name = "loguru", specifier = ">=0.7.3" },