Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-18 23:06:23 -05:00)

Compare commits: iroh...leo/add-ca — 10 commits

- 1f2f3571cf
- efcb6b9393
- 72fd7f0e1f
- 025ed9fd82
- 19bc09550d
- 7cadca4f27
- 24e99ce197
- 315992549b
- ce5a65d3b9
- c2f2111b88
@@ -200,7 +200,7 @@ class Module(dict):
    ) -> mx.MX_ARRAY_TREE:  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
        """Return the submodules that do not contain other modules."""

-   def update(self, parameters: dict, strict: bool = ...) -> Module:
+   def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.
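For context, `update` is typically fed the same dict-of-dicts-and-lists structure that `parameters()` returns. A minimal sketch with real MLX APIs (the specific module and scaling factor are illustrative only):

```python
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(4, 2)

# Scale every parameter tensor, then write the result back into the module.
scaled = tree_map(lambda p: p * 0.5, model.parameters())
model.update(scaled)
```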
@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE

def tree_map(
-   fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
+   fn: Callable[..., Any],
+   tree: Any,
+   *rest: Any,
+   is_leaf: Callable[..., bool] | None = ...,
) -> Any:
    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
    returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
    """

def tree_map_with_path(
-   fn: Callable,
+   fn: Callable[..., Any],
    tree: Any,
    *rest: Any,
-   is_leaf: Optional[Callable] = ...,
-   path: Optional[Any] = ...,
+   is_leaf: Callable[..., bool] | None = ...,
+   path: str | None = ...,
) -> Any:
    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
    returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
    tree: Any,
    prefix: str = ...,
-   is_leaf: Optional[Callable] = ...,
-   destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
-) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
+   is_leaf: Callable[..., bool] | None = ...,
+   destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
+) -> list[tuple[str, Any]] | dict[str, Any]:
    """Flattens a Python tree to a list of key, value tuples.

    The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
    the Python tree.
    """

-def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
+def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
    """Recreate a Python tree from its flat representation.

    .. code-block:: python
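A minimal sketch of the round trip these signatures describe, using the standard list-of-tuples form (the `destination`/dict variant is specific to these stubs and not shown):

```python
from mlx.utils import tree_flatten, tree_unflatten

params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}

# Flatten to dot-notation keys: [("layers.0.w", 1.0), ("layers.0.b", 0.0), ...]
flat = tree_flatten(params)

# Rebuild the nested dict/list structure from the flat (key, value) pairs.
restored = tree_unflatten(flat)
assert restored == params
```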
bench/eval_tool_calls.py (1088 lines; the full diff is large, excerpted below)
@@ -1,29 +1,47 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.

Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary

Start exo with a model first, then run:
    uv run python tool_call_eval.py --model <model-id>
    uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
    uv run python tool_call_eval.py --model <model-id> --repeat 3
    uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""

from __future__ import annotations

import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode

from harness import (
    ExoClient,
    ExoHttpError,
    add_common_instance_args,
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer

# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0

# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
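For orientation, a hedged sketch of the kind of tool-calling request the eval sends against the OpenAI-compatible endpoint; the field layout follows the standard OpenAI tools schema and the weather tool defined in bench/scenarios.toml, while the exact payload built by the script may differ:

```python
import json
import urllib.request

payload = {
    "model": "<model-id>",
    "messages": [{"role": "user", "content": "What's the weather like in Tokyo right now?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
}

req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    reply = json.loads(resp.read())

# A passing run returns an assistant message whose tool_calls name
# get_current_weather with JSON arguments containing "location".
print(reply["choices"][0]["message"].get("tool_calls"))
```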
@@ -103,154 +121,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
@@ -269,184 +139,6 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
return items
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("name").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
@@ -538,76 +230,12 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
add_common_instance_args(ap)
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
@@ -620,34 +248,6 @@ def main() -> int:
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
@@ -657,9 +257,6 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
@@ -674,17 +271,6 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -719,24 +305,10 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
@@ -760,16 +332,6 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
bench/harness.py (new file, 327 lines)
@@ -0,0 +1,327 @@
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
raise
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if (m.get("name") or "").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def settle_and_fetch_placements(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
args: argparse.Namespace,
|
||||
settle_timeout: float = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
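Taken together, the harness is meant to be composed by individual bench scripts. A sketch of that wiring using only the entry points defined above (the real scripts may add further arguments and steps):

```python
import argparse

from harness import (
    ExoClient,
    add_common_instance_args,
    resolve_model_short_id,
    settle_and_fetch_placements,
)

ap = argparse.ArgumentParser(prog="my-bench")
add_common_instance_args(ap)  # --host, --port, --model, placement filters, timeouts
args = ap.parse_args()

client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_id = resolve_model_short_id(client, args.model)

# Retries with exponential backoff until the cluster reports a usable placement,
# or returns immediately when --settle-timeout is 0.
placements = settle_and_fetch_placements(
    client, full_id, args, settle_timeout=args.settle_timeout
)
```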
@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
    "httpx>=0.27.0",
    "loguru>=0.7.3",
    "transformers>=5.0.0",
    "huggingface-hub>=0.33.4",
bench/scenarios.toml (new file, 306 lines)
@@ -0,0 +1,306 @@
# Tool definitions — each becomes an OpenAI function tool.
|
||||
# All scenarios get all tools unless they specify a `tools` list.
|
||||
|
||||
[tools.get_current_weather]
|
||||
description = "Get the current weather in a given location"
|
||||
required = ["location"]
|
||||
|
||||
[tools.get_current_weather.properties.location]
|
||||
type = "string"
|
||||
description = "City and state, e.g. San Francisco, CA"
|
||||
|
||||
[tools.get_current_weather.properties.unit]
|
||||
type = "string"
|
||||
enum = ["celsius", "fahrenheit"]
|
||||
description = "Temperature unit"
|
||||
|
||||
[tools.calculate]
|
||||
description = "Evaluate a mathematical expression and return the numeric result"
|
||||
required = ["expression"]
|
||||
|
||||
[tools.calculate.properties.expression]
|
||||
type = "string"
|
||||
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
|
||||
|
||||
[tools.search_products]
|
||||
description = "Search for products in a catalog by query, category, and price"
|
||||
required = ["query"]
|
||||
|
||||
[tools.search_products.properties.query]
|
||||
type = "string"
|
||||
description = "Search query string"
|
||||
|
||||
[tools.search_products.properties.category]
|
||||
type = "string"
|
||||
enum = ["electronics", "clothing", "food", "books"]
|
||||
description = "Product category to filter by"
|
||||
|
||||
[tools.search_products.properties.max_price]
|
||||
type = "number"
|
||||
description = "Maximum price in USD"
|
||||
|
||||
[tools.create_todos]
|
||||
description = "Create a structured todo list"
|
||||
required = ["todos"]
|
||||
|
||||
[tools.create_todos.properties.todos]
|
||||
type = "array"
|
||||
description = "List of todo items"
|
||||
|
||||
[tools.create_todos.properties.todos.items]
|
||||
type = "object"
|
||||
required = ["content", "status", "priority"]
|
||||
|
||||
[tools.create_todos.properties.todos.items.properties.content]
|
||||
type = "string"
|
||||
description = "The todo item text"
|
||||
|
||||
[tools.create_todos.properties.todos.items.properties.status]
|
||||
type = "string"
|
||||
description = "Status: pending, in_progress, or completed"
|
||||
|
||||
[tools.create_todos.properties.todos.items.properties.priority]
|
||||
type = "string"
|
||||
description = "Priority: low, normal, or high"
|
||||
|
||||
# -- Should call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_simple"
|
||||
description = "Basic weather query -> get_current_weather"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather like in Tokyo right now?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_simple"
|
||||
description = "Math question -> calculate"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 3847 * 926 + 17293"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_with_filters"
|
||||
description = "Product search with category and price filter"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Find me electronics under $50"
|
||||
|
||||
# -- Multi-turn: tool call then follow-up --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_multi_turn"
|
||||
description = "Weather query -> tool result -> natural language summary"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
temperature = "18C"
|
||||
condition = "partly cloudy"
|
||||
humidity = "65%"
|
||||
wind = "12 km/h NW"
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Paris?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_multi_turn"
|
||||
description = "Math query -> tool result -> model reports the answer"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
result = 491682
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 1847 * 263 + 5921"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_multi_turn"
|
||||
description = "Search query -> tool result -> model summarizes products"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Hands-On Machine Learning"
|
||||
price = 45.99
|
||||
rating = 4.8
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Deep Learning with Python"
|
||||
price = 39.99
|
||||
rating = 4.6
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Search for books about machine learning"
|
||||
|
||||
# -- Sequential tool calls --
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_same"
|
||||
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare the weather in Tokyo and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check both cities. Let me start with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_1"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_1"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_different"
|
||||
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll handle both. Let me check Berlin's weather first."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_2"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Berlin" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_2"
|
||||
content = '{"temperature": "12C", "condition": "rainy"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_three"
|
||||
description = "Two prior thinking+tool calls -> results -> model must make a third"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare weather in Tokyo, Paris, and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check all three cities. Starting with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_3"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_3"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "Got Tokyo. Now checking Paris."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_4"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Paris" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_4"
|
||||
content = '{"temperature": "18C", "condition": "cloudy"}'
|
||||
|
||||
# -- Nested object schema (regression for lossy chat template rendering) --
|
||||
|
||||
[[scenarios]]
|
||||
name = "nested_schema_tool_call"
|
||||
description = "Tool call with nested object array schema -> create_todos"
|
||||
expect_tool_call = true
|
||||
expected_function = "create_todos"
|
||||
required_arg_keys = ["todos"]
|
||||
nested_array_key = "todos"
|
||||
required_item_keys = ["content", "status", "priority"]
|
||||
tools = ["create_todos"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Create a todo list with 3 items to learn Python"
|
||||
|
||||
# -- Tool name integrity (regression for harmony token leaking into name) --
|
||||
|
||||
[tools.glob]
|
||||
description = "Search for files matching a glob pattern in the codebase"
|
||||
required = ["pattern"]
|
||||
|
||||
[tools.glob.properties.pattern]
|
||||
type = "string"
|
||||
description = "The glob pattern to match files against, e.g. '**/*.py'"
|
||||
|
||||
[tools.glob.properties.path]
|
||||
type = "string"
|
||||
description = "The directory to search in"
|
||||
|
||||
[[scenarios]]
|
||||
name = "tool_name_integrity"
|
||||
description = "Tool name must not contain harmony tokens like <|channel|>"
|
||||
expect_tool_call = true
|
||||
expected_function = "glob"
|
||||
required_arg_keys = ["pattern"]
|
||||
tools = ["glob"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Find all Python files in the src directory"
|
||||
|
||||
# -- Should NOT call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_joke"
|
||||
description = "Joke request should NOT trigger any tool"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Tell me a funny joke about cats."
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_factual"
|
||||
description = "Factual question answerable from training data"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What is the capital of Japan?"
|
||||
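The eval script consumes this file by turning each `[tools.*]` table into an OpenAI function tool and walking the `[[scenarios]]` entries. A hedged sketch of that conversion (the loader in eval_tool_calls.py may structure it differently):

```python
import tomllib  # stdlib; the bench package requires Python >= 3.13

with open("bench/scenarios.toml", "rb") as f:
    config = tomllib.load(f)

def to_openai_tool(name: str, spec: dict) -> dict:
    """Map a [tools.<name>] table onto the OpenAI function-tool schema."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": spec["description"],
            "parameters": {
                "type": "object",
                "properties": spec.get("properties", {}),
                "required": spec.get("required", []),
            },
        },
    }

all_tools = [to_openai_tool(name, spec) for name, spec in config["tools"].items()]

# Scenarios may restrict themselves to a subset of tools via their `tools` list.
scenario = next(s for s in config["scenarios"] if s["name"] == "weather_simple")
```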
@@ -14,6 +14,7 @@
    totalTokens,
    thinkingEnabled as thinkingEnabledStore,
    setConversationThinking,
+   stopGeneration,
  } from "$lib/stores/app.svelte";
  import ChatAttachments from "./ChatAttachments.svelte";
  import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -653,86 +654,92 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => stopGeneration()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
<rect x="6" y="6" width="12" height="12" rx="1" />
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
<span class="hidden sm:inline">Cancel</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -3,16 +3,17 @@
    messages,
    currentResponse,
    isLoading,
    prefillProgress,
    deleteMessage,
    editAndRegenerate,
    regenerateLastResponse,
    regenerateFromToken,
    setEditingImage,
  } from "$lib/stores/app.svelte";
  import type { Message } from "$lib/stores/app.svelte";
  import type { MessageAttachment } from "$lib/stores/app.svelte";
  import MarkdownContent from "./MarkdownContent.svelte";
  import TokenHeatmap from "./TokenHeatmap.svelte";
  import PrefillProgressBar from "./PrefillProgressBar.svelte";
  import ImageLightbox from "./ImageLightbox.svelte";

  interface Props {
@@ -25,6 +26,7 @@
  const messageList = $derived(messages());
  const response = $derived(currentResponse());
  const loading = $derived(isLoading());
+ const prefill = $derived(prefillProgress());

  // Scroll management - user controls scroll, show button when not at bottom
  const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
  {:else}
    <!-- Assistant message styling -->
    <div class="p-3 sm:p-4">
+     {#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
+       <PrefillProgressBar progress={prefill} class="mb-3" />
+     {/if}
      {#if message.thinking && message.thinking.trim().length > 0}
        <div
          class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"
@@ -26,7 +26,8 @@
    downloadedOnNodes = [],
  }: HuggingFaceResultItemProps = $props();

- function formatNumber(num: number): string {
+ function formatNumber(num: number | undefined): string {
+   if (num == null) return "0";
    if (num >= 1000000) {
      return `${(num / 1000000).toFixed(1)}M`;
    } else if (num >= 1000) {
dashboard/src/lib/components/PrefillProgressBar.svelte (new file, 52 lines)
@@ -0,0 +1,52 @@
<script lang="ts">
  import type { PrefillProgress } from "$lib/stores/app.svelte";

  interface Props {
    progress: PrefillProgress;
    class?: string;
  }

  let { progress, class: className = "" }: Props = $props();

  const percentage = $derived(
    progress.total > 0
      ? Math.round((progress.processed / progress.total) * 100)
      : 0,
  );

  function formatTokenCount(count: number | undefined): string {
    if (count == null) return "0";
    if (count >= 1000) {
      return `${(count / 1000).toFixed(1)}k`;
    }
    return count.toString();
  }
</script>

<div class="prefill-progress {className}">
  <div
    class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
  >
    <span>Processing prompt</span>
    <span class="font-mono">
      {formatTokenCount(progress.processed)} / {formatTokenCount(
        progress.total,
      )} tokens
    </span>
  </div>
  <div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
    <div
      class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
      style="width: {percentage}%"
    ></div>
  </div>
  <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
    {percentage}%
  </div>
</div>

<style>
  .prefill-progress {
    width: 100%;
  }
</style>
@@ -273,6 +273,11 @@ export interface TokenData {
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
export interface PrefillProgress {
|
||||
processed: number;
|
||||
total: number;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -520,6 +525,10 @@ class AppStore {
|
||||
ttftMs = $state<number | null>(null); // Time to first token in ms
|
||||
tps = $state<number | null>(null); // Tokens per second
|
||||
totalTokens = $state<number>(0); // Total tokens in current response
|
||||
prefillProgress = $state<PrefillProgress | null>(null);
|
||||
|
||||
// Abort controller for stopping generation
|
||||
private currentAbortController: AbortController | null = null;
|
||||
|
||||
// Topology state
|
||||
topologyData = $state<TopologyData | null>(null);
|
||||
@@ -2005,6 +2014,7 @@ class AppStore {
|
||||
reader: ReadableStreamDefaultReader<Uint8Array>,
|
||||
targetConversationId: string,
|
||||
onChunk: (parsed: T) => void,
|
||||
onEvent?: Record<string, (data: unknown) => void>,
|
||||
): Promise<void> {
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
@@ -2025,6 +2035,24 @@ class AppStore {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) continue;
|
||||
|
||||
// Handle SSE comments (": key json") for prefill progress etc.
|
||||
if (trimmed.startsWith(": ") && onEvent) {
|
||||
const comment = trimmed.slice(2);
|
||||
const spaceIdx = comment.indexOf(" ");
|
||||
if (spaceIdx > 0) {
|
||||
const key = comment.slice(0, spaceIdx);
|
||||
if (onEvent[key]) {
|
||||
try {
|
||||
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
|
||||
onEvent[key](parsed);
|
||||
} catch {
|
||||
// Skip malformed JSON in comment
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (trimmed.startsWith("data: ")) {
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
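The comment lines parsed above are emitted by the API server ahead of the regular `data:` chunks. A hedged Python sketch of the emitting side, assuming the `: prefill_progress {json}` framing and the snake_case, TaggedModel-wrapped payload the dashboard handler expects:

```python
import json

def sse_prefill_comment(processed: int, total: int) -> bytes:
    """Build an SSE comment line carrying prefill progress.

    Comment lines (leading ':') are ignored by standard SSE clients, so they
    can interleave with 'data:' chunks without breaking the stream.
    """
    body = {"PrefillProgressChunk": {"processed_tokens": processed, "total_tokens": total}}
    return f": prefill_progress {json.dumps(body)}\n\n".encode("utf-8")

# e.g. while prefilling a 12,288-token prompt in 4,096-token steps:
chunk = sse_prefill_comment(processed=4096, total=12288)
```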
@@ -2256,6 +2284,9 @@ class AppStore {
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = 0;
|
||||
|
||||
const abortController = new AbortController();
|
||||
this.currentAbortController = abortController;
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -2272,6 +2303,7 @@ class AppStore {
|
||||
enable_thinking: enableThinking,
|
||||
}),
|
||||
}),
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -2309,6 +2341,11 @@ class AppStore {
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
// Clear prefill progress when first token data arrives
|
||||
if (this.prefillProgress) {
|
||||
this.prefillProgress = null;
|
||||
}
|
||||
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
|
||||
@@ -2371,8 +2408,26 @@ class AppStore {
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
},
|
||||
{
|
||||
prefill_progress: (data) => {
|
||||
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
|
||||
// model_dump_json() uses snake_case (by_alias defaults to False)
|
||||
const raw = data as Record<string, unknown>;
|
||||
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
|
||||
processed_tokens: number;
|
||||
total_tokens: number;
|
||||
};
|
||||
this.prefillProgress = {
|
||||
processed: inner.processed_tokens,
|
||||
total: inner.total_tokens,
|
||||
};
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Clear prefill progress after stream ends
|
||||
this.prefillProgress = null;
|
||||
|
||||
// Calculate final TPS
|
||||
if (firstTokenTime !== null && tokenCount > 1) {
|
||||
const totalGenerationTime = performance.now() - firstTokenTime;
|
||||
@@ -2403,20 +2458,31 @@ class AppStore {
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error sending message:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
"Failed to get response",
|
||||
);
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
// User stopped generation — not an error
|
||||
} else {
|
||||
console.error("Error sending message:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
"Failed to get response",
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.prefillProgress = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
stopGeneration(): void {
|
||||
this.currentAbortController?.abort();
|
||||
this.currentAbortController = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an image using the image generation API
|
||||
*/
|
||||
@@ -3043,6 +3109,7 @@ export const isLoading = () => appStore.isLoading;
|
||||
export const ttftMs = () => appStore.ttftMs;
|
||||
export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
export const prefillProgress = () => appStore.prefillProgress;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const runners = () => appStore.runners;
|
||||
@@ -3060,6 +3127,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
|
||||
// Actions
|
||||
export const stopGeneration = () => appStore.stopGeneration();
|
||||
export const startChat = () => appStore.startChat();
|
||||
export const sendMessage = (
|
||||
content: string,
|
||||
|
||||
@@ -932,13 +932,6 @@
|
||||
};
|
||||
}
|
||||
|
||||
// Debug: Log downloads data when it changes
|
||||
$effect(() => {
|
||||
if (downloadsData && Object.keys(downloadsData).length > 0) {
|
||||
console.log("[Download Debug] Current downloads:", downloadsData);
|
||||
}
|
||||
});
|
||||
|
||||
// Helper to get download status for an instance
|
||||
function getInstanceDownloadStatus(
|
||||
instanceId: string,
|
||||
|
||||
@@ -158,6 +158,7 @@
|
||||
exo-test-env = testVenv;
|
||||
} // {
|
||||
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 128666664960
@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 185826705408
@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 242986745856
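The cards are plain TOML, so external tooling can read them without exo's own loaders. A minimal sketch with the standard-library tomllib (the file path is hypothetical; card locations are not shown in this changeset):

import tomllib
from pathlib import Path

card_path = Path("minimax-m2.5-8bit.toml")  # hypothetical local copy of the card above
with card_path.open("rb") as f:
    card = tomllib.load(f)

print(card["model_id"])                        # mlx-community/MiniMax-M2.5-8bit
print(card["n_layers"], card["quantization"])  # 62 8bit
print(card["storage_size"]["in_bytes"] / 1e9)  # roughly 243 GB on disk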
|
||||
@@ -123,14 +123,17 @@ class DownloadCoordinator:
            tg.start_soon(self._check_internet_connection)

    def _test_internet_connection(self) -> None:
        try:
            socket.create_connection(("1.1.1.1", 443), timeout=3).close()
            self.shard_downloader.set_internet_connection(True)
        except OSError:
            self.shard_downloader.set_internet_connection(False)
        logger.debug(
            f"Internet connectivity: {self.shard_downloader.internet_connection}"
        )
        # Try multiple endpoints since some ISPs/networks block specific IPs
        for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
            try:
                socket.create_connection((host, 443), timeout=3).close()
                self.shard_downloader.set_internet_connection(True)
                logger.debug(f"Internet connectivity: True (via {host})")
                return
            except OSError:
                continue
        self.shard_downloader.set_internet_connection(False)
        logger.debug("Internet connectivity: False")

    async def _check_internet_connection(self) -> None:
        first_connection = True
|
||||
|
||||
@@ -19,7 +19,12 @@ from exo.shared.types.api import (
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
@@ -123,67 +128,81 @@ def chunk_to_response(
|
||||
|
||||
async def generate_chat_stream(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
# Use SSE comment so third-party clients ignore it
|
||||
yield f": prefill_progress {chunk.model_dump_json()}\n\n"
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
]
|
||||
tool_response = ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=tool_call_deltas,
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
case ErrorChunk():
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
case ToolCallChunk():
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
]
|
||||
tool_response = ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=tool_call_deltas,
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
case TokenChunk():
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(
|
||||
update={"usage": last_usage}
|
||||
)
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
@@ -197,38 +216,43 @@ async def collect_chat_response(
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
case ErrorChunk():
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
case TokenChunk():
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
last_usage = chunk.usage or last_usage
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
)
|
||||
)
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
case ToolCallChunk():
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
last_usage = chunk.usage or last_usage
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import FinishReason, Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
|
||||
async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
@@ -172,6 +179,9 @@ async def collect_claude_response(
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
@@ -230,7 +240,9 @@ async def collect_claude_response(
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
@@ -256,6 +268,9 @@ async def generate_claude_stream(
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
@@ -5,7 +5,12 @@ from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
FunctionCallInputItem,
|
||||
@@ -121,7 +126,9 @@ def responses_request_to_text_generation(
|
||||
async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
@@ -134,6 +141,9 @@ async def collect_responses_response(
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
@@ -189,7 +199,9 @@ async def collect_responses_response(
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
@@ -243,6 +255,9 @@ async def generate_responses_stream(
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
@@ -137,6 +138,7 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
PrefillProgress,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -145,6 +147,7 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
@@ -220,7 +223,8 @@ class API:
|
||||
)
|
||||
|
||||
self._text_generation_queues: dict[
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
CommandId,
|
||||
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
@@ -526,19 +530,23 @@ class API:
|
||||
|
||||
async def _token_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
) -> AsyncGenerator[
|
||||
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
|
||||
]:
|
||||
"""Yield chunks for a given command until completion.
|
||||
|
||||
This is the internal low-level stream used by all API adapters.
|
||||
"""
|
||||
try:
|
||||
self._text_generation_queues[command_id], recv = channel[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
@@ -565,6 +573,9 @@ class API:
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._token_chunk_stream(command_id):
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -1292,8 +1303,18 @@ class API:
|
||||
|
||||
return total_available
|
||||
|
||||
async def get_models(self) -> ModelList:
|
||||
"""Returns list of available models."""
|
||||
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
|
||||
"""Returns list of available models, optionally filtered by being downloaded."""
|
||||
cards = await get_model_cards()
|
||||
|
||||
if status == "downloaded":
|
||||
downloaded_model_ids: set[str] = set()
|
||||
for node_downloads in self.state.downloads.values():
|
||||
for dl in node_downloads:
|
||||
if isinstance(dl, DownloadCompleted):
|
||||
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
|
||||
cards = [c for c in cards if c.model_id in downloaded_model_ids]
|
||||
|
||||
return ModelList(
|
||||
data=[
|
||||
ModelListModel(
|
||||
@@ -1311,7 +1332,7 @@ class API:
|
||||
base_model=card.base_model,
|
||||
capabilities=card.capabilities,
|
||||
)
|
||||
for card in await get_model_cards()
|
||||
for card in cards
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1435,6 +1456,21 @@ class API:
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
elif isinstance(event, PrefillProgress):
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
try:
|
||||
await queue.send(
|
||||
PrefillProgressChunk(
|
||||
model=event.model,
|
||||
processed_tokens=event.processed_tokens,
|
||||
total_tokens=event.total_tokens,
|
||||
)
|
||||
)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from exo.shared.types.events import (
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
PrefillProgress,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
@@ -64,6 +65,7 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| ChunkGenerated()
|
||||
| TaskAcknowledged()
|
||||
| InputChunkReceived()
|
||||
| PrefillProgress()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
): # Pass-through events that don't modify state
|
||||
|
||||
@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
class PrefillProgressChunk(BaseChunk):
|
||||
"""Data class for prefill progress events during streaming."""
|
||||
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
GenerationChunk = (
|
||||
TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -102,6 +102,13 @@ class InputChunkReceived(BaseEvent):
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class PrefillProgress(BaseEvent):
|
||||
command_id: CommandId
|
||||
model: ModelId
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
@@ -148,6 +155,7 @@ Event = (
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| InputChunkReceived
|
||||
| PrefillProgress
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
|
||||
@@ -4,10 +4,13 @@ from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
|
||||
@@ -67,3 +67,8 @@ class ToolCallResponse(BaseRunnerResponse):
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
|
||||
class PrefillProgressResponse(BaseRunnerResponse):
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass, field
|
||||
from math import inf
|
||||
@@ -132,7 +133,8 @@ class MpSender[T]:
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
self._state.closed.set()
|
||||
self._state.buffer.put(_MpEndOfStream())
|
||||
with contextlib.suppress(Exception):
|
||||
self._state.buffer.put_nowait(_MpEndOfStream())
|
||||
self._state.buffer.close()
|
||||
|
||||
# == unique to Mp channels ==
|
||||
@@ -204,6 +206,8 @@ class MpReceiver[T]:
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
self._state.closed.set()
|
||||
with contextlib.suppress(Exception):
|
||||
self._state.buffer.put_nowait(_MpEndOfStream())
|
||||
self._state.buffer.close()
|
||||
|
||||
# == unique to Mp channels ==
|
||||
|
||||
@@ -5,6 +5,7 @@ import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
    ArraysCache,
    CacheList,
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
@@ -17,10 +18,22 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger

# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9

# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = psutil.virtual_memory().total / (1024**3)
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
        return 0.80
    if total_gb >= 32:
        return 0.75
    return 0.70


_MEMORY_THRESHOLD = float(
    os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
    os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
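The tiered default can still be pinned explicitly through the environment variable read above; a minimal sketch (values illustrative):

import os

# Pin a fixed eviction threshold regardless of machine size; unset to fall back
# to the RAM-tiered default of 0.70-0.85 computed by _default_memory_threshold().
os.environ["EXO_MEMORY_THRESHOLD"] = "0.8"

threshold = float(os.environ.get("EXO_MEMORY_THRESHOLD", 0.75))
assert 0.0 < threshold <= 1.0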
|
||||
|
||||
|
||||
@@ -64,7 +77,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
def __init__(self, group: mx.distributed.Group | None):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
@@ -156,15 +169,15 @@ class KVPrefixCache:
        best_length = 0
        is_exact = False

        # Find best cache
        # Find best cache match
        for i, cached_prompt in enumerate(self.prompts):
            length = get_prefix_length(prompt_tokens, cached_prompt)
            if length >= max_length - 1:
                best_index, best_length = i, length
                is_exact = True
                break
            if length > best_length:
                best_index, best_length = i, length
            if length == max_length:
                is_exact = True
                best_index, best_length = i, length
                break

        if best_index is None:
            return make_kv_cache(model), prompt_tokens, None
@@ -172,11 +185,12 @@
        # For exact match: trim to max_length-1 so remaining has the last token
        # For partial match: trim to best_length, remaining has suffix to prefill
        # This ensures stream_generate always has at least one token to start with
        target = (max_length - 1) if is_exact else best_length
        has_ssm = has_non_kv_caches(self.caches[best_index])
        target = (max_length - 1) if is_exact and not has_ssm else best_length
        restore_pos, restore_snap = self._get_snapshot(best_index, target)

        # No usable snapshot — need fresh cache
        if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
        if restore_snap is None and has_ssm:
            return make_kv_cache(model), prompt_tokens, None

        prompt_cache = deepcopy(self.caches[best_index])
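A toy illustration of the matching rule above, using plain Python lists rather than the real cache types:

cached = [1, 2, 3, 4, 5]            # prompt tokens already held in the prefix cache
incoming = [1, 2, 3, 4, 5, 6, 7]    # new request sharing a prefix

shared = 0
for a, b in zip(cached, incoming):
    if a != b:
        break
    shared += 1

# shared == 5: a partial hit, so only the suffix [6, 7] still needs prefill.
# On an exact hit the code instead trims to max_length - 1 so that
# stream_generate always has at least one remaining token to consume.
print(shared, incoming[shared:])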
|
||||
@@ -257,10 +271,21 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    return mx.array(prompt_tokens)


def _entry_length(
    c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
    if hasattr(c, "offset"):
        return c.offset
    # For CacheList
    if hasattr(c, "size"):
        return int(c.size())  # type: ignore
    return 0


def cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
    return max(getattr(c, "offset", 0) for c in cache)
    return max(_entry_length(c) for c in cache)
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
|
||||
@@ -48,7 +48,11 @@ from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
|
||||
|
||||
|
||||
class PrefillCancelled(BaseException):
|
||||
"""Raised when prefill is cancelled via the progress callback."""
|
||||
|
||||
|
||||
def prefill(
|
||||
@@ -57,7 +61,8 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None = None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -65,7 +70,7 @@ def prefill(
|
||||
then trims off the extra generated token.
|
||||
|
||||
Returns:
|
||||
tokens_per_sec
|
||||
(tokens_per_sec, num_tokens, snapshots)
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
@@ -76,6 +81,7 @@ def prefill(
|
||||
has_ssm = has_non_kv_caches(cache)
|
||||
snapshots: list[CacheSnapshot] = []
|
||||
|
||||
# TODO(evan): kill the callbacks/runner refactor
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
@@ -85,6 +91,9 @@ def prefill(
|
||||
if has_ssm:
|
||||
snapshots.append(snapshot_ssm_states(cache))
|
||||
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
@@ -92,19 +101,23 @@ def prefill(
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=8192,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
try:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=4096,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
except PrefillCancelled:
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
raise
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
|
||||
@@ -133,7 +146,7 @@ def prefill(
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -255,8 +268,9 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -311,7 +325,13 @@ def mlx_generate(
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
group,
|
||||
on_prefill_progress,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
@@ -436,9 +456,14 @@ def mlx_generate(
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
hit_ratio = (
|
||||
prefix_hit_length / len(all_prompt_tokens)
|
||||
if len(all_prompt_tokens) > 0
|
||||
else 0.0
|
||||
)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -292,6 +293,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
elif "gpt-oss" in model_id_lower:
|
||||
return [200002, 200012]
|
||||
return None
|
||||
|
||||
|
||||
@@ -405,6 +408,56 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
|
||||
func["arguments"] = json.loads(args)
|
||||
|
||||
|
||||
def _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:
|
||||
names: set[str] = set()
|
||||
properties: dict[str, Any] = schema.get("properties", {}) # type: ignore[reportAny]
|
||||
for prop_spec in properties.values(): # pyright: ignore[reportAny]
|
||||
if not isinstance(prop_spec, dict):
|
||||
continue
|
||||
if prop_spec.get("type") == "array": # type: ignore[reportAny]
|
||||
items: dict[str, Any] | None = prop_spec.get("items") # type: ignore[reportAny]
|
||||
if isinstance(items, dict) and items.get("type") == "object": # type: ignore[reportAny]
|
||||
inner_props: dict[str, Any] = items.get("properties", {}) # type: ignore[reportAny]
|
||||
for k in inner_props: # pyright: ignore[reportUnknownVariableType]
|
||||
names.add(str(k)) # pyright: ignore[reportUnknownArgumentType]
|
||||
names.update(_collect_nested_property_names(items)) # pyright: ignore[reportUnknownArgumentType]
|
||||
return names
|
||||
|
||||
|
||||
def _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if nested property names from any tool schema are absent."""
|
||||
for tool in tools:
|
||||
fn: dict[str, Any] = tool.get("function", {}) # type: ignore
|
||||
params: dict[str, Any] = fn.get("parameters", {}) # type: ignore
|
||||
nested = _collect_nested_property_names(params)
|
||||
if nested and not all(name in prompt for name in nested):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_LOSSY_TEMPLATE_PATTERN = re.compile(
|
||||
r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
|
||||
)
|
||||
|
||||
|
||||
def _patch_lossy_chat_template(template: str) -> str | None:
|
||||
"""Patch chat templates that collapse nested object schemas to ``any[]``.
|
||||
|
||||
Some templates (e.g., GPT-OSS) have a guard like::
|
||||
|
||||
inner_type == "object | object" or inner_type|length > 50
|
||||
|
||||
The length check silently drops complex array-of-object schemas.
|
||||
We remove the length guard, keeping only the object-union check.
|
||||
Returns the patched template, or *None* if no patch was needed.
|
||||
"""
|
||||
patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
|
||||
lambda m: m.group(0).split(" or ")[0], # keep only the object-union check
|
||||
template,
|
||||
)
|
||||
return patched if n > 0 else None
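A quick demonstration of the patch on a hypothetical template fragment shaped like the guard described above (not the actual GPT-OSS template):

template = (
    '{% if inner_type == "object | object" or inner_type|length > 50 %}any[]{% endif %}'
)
patched = _LOSSY_TEMPLATE_PATTERN.sub(
    lambda m: m.group(0).split(" or ")[0], template
)
print(patched)
# {% if inner_type == "object | object" %}any[]{% endif %}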
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task_params: TextGenerationTaskParams,
|
||||
@@ -451,14 +504,28 @@ def apply_chat_template(
|
||||
extra_kwargs["enable_thinking"] = task_params.enable_thinking
|
||||
extra_kwargs["thinking"] = task_params.enable_thinking
|
||||
|
||||
patched_template: str | None = None
|
||||
if task_params.tools:
|
||||
original_template: str | None = getattr(tokenizer, "chat_template", None)
|
||||
if isinstance(original_template, str):
|
||||
patched_template = _patch_lossy_chat_template(original_template)
|
||||
if patched_template is not None:
|
||||
logger.info(
|
||||
"Patched lossy chat template (removed inner_type length guard)"
|
||||
)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=task_params.tools,
|
||||
**({"chat_template": patched_template} if patched_template is not None else {}),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):
|
||||
logger.warning("Chat template lost nested tool schemas even after patching")
|
||||
|
||||
if partial_assistant_content:
|
||||
prompt += partial_assistant_content
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
@@ -25,6 +26,7 @@ from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
PrefillProgress,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
@@ -80,7 +82,11 @@ from exo.worker.engines.image import (
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
PrefillCancelled,
|
||||
mlx_generate,
|
||||
warmup_inference,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
detect_thinking_prompt_suffix,
|
||||
@@ -297,6 +303,32 @@ def main(
|
||||
assert tokenizer
|
||||
assert check_for_cancel_every
|
||||
|
||||
# Define callback to send prefill progress events
|
||||
# and check for cancellation between prefill chunks.
|
||||
# TODO(evan): kill the callbacks/runner refactor
|
||||
# Specifically the part that this is literally duplicated code.
|
||||
def on_prefill_progress(
|
||||
processed: int,
|
||||
total: int,
|
||||
_task_id: TaskId = task.task_id,
|
||||
_group: mx.distributed.Group | None = group,
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
PrefillProgress(
|
||||
command_id=command_id,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
)
|
||||
)
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (_task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, _group):
|
||||
raise PrefillCancelled()
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params)
|
||||
|
||||
@@ -310,6 +342,7 @@ def main(
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
group=group,
|
||||
)
|
||||
|
||||
@@ -391,6 +424,8 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
except PrefillCancelled:
|
||||
logger.info(f"Prefill cancelled for task {task.task_id}")
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
@@ -588,17 +623,31 @@ def parse_gpt_oss(
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
stream.process(response.token)
|
||||
try:
|
||||
stream.process(response.token)
|
||||
except HarmonyError:
|
||||
logger.error("Encountered critical Harmony Error, returning early")
|
||||
return
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
recipient = stream.current_recipient
|
||||
|
||||
# Debug: log every token with state
|
||||
logger.debug(
|
||||
f"parse_gpt_oss token={response.token} text={response.text!r} "
|
||||
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
|
||||
f"state={stream.state} current_tool={current_tool_name!r}"
|
||||
)
|
||||
|
||||
if recipient != current_tool_name:
|
||||
if current_tool_name is not None:
|
||||
prefix = "functions."
|
||||
if current_tool_name.startswith(prefix):
|
||||
current_tool_name = current_tool_name[len(prefix) :]
|
||||
logger.info(
|
||||
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
|
||||
)
|
||||
yield ToolCallResponse(
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
|
||||
@@ -103,7 +103,7 @@ class RunnerSupervisor:
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
self.runner_process.join(5)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
|
||||
@@ -123,7 +123,12 @@ def run_gpt_oss_pipeline_device(
|
||||
generated_text = ""
|
||||
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
@@ -194,6 +199,8 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(None)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,7 +142,9 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
@@ -161,9 +163,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
@@ -176,9 +180,11 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
# Exact match returns last token(s) — for models with SSM/rotating caches,
|
||||
# snapshot availability constrains how far back we can trim, so remaining
|
||||
# may be 1 or 2 tokens depending on the model.
|
||||
assert len(remaining_tokens) >= 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
@@ -194,10 +200,10 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
@@ -238,9 +244,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -276,9 +284,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -301,7 +311,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -318,6 +328,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
generated_tokens += 1
|
||||
|
||||
@@ -331,7 +342,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -347,6 +358,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -368,7 +380,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -395,6 +407,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task1,
|
||||
prompt=prompt1,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
@@ -427,6 +440,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task2,
|
||||
prompt=prompt2,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
second_gen_time = time.perf_counter() - t0
|
||||
@@ -447,7 +461,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -462,6 +476,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -474,6 +489,7 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -484,7 +500,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -497,7 +513,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
@@ -522,7 +538,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
import copy
|
||||
import gc
|
||||
import importlib
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
load_tokenizer_for_model_id,
|
||||
)
|
||||
|
||||
HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
|
||||
|
||||
# ── Config reduction ──────────────────────────────────────────────────────── #
|
||||
|
||||
_REDUCE = {
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 256,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 512,
|
||||
"moe_intermediate_size": 128,
|
||||
"num_experts": 4,
|
||||
"num_experts_per_tok": 2,
|
||||
"n_routed_experts": 4,
|
||||
"num_local_experts": 4,
|
||||
"num_nextn_predict_layers": 0,
|
||||
"first_k_dense_replace": 0,
|
||||
"linear_num_key_heads": 2,
|
||||
"linear_num_value_heads": 2,
|
||||
"num_attention_groups": 4,
|
||||
}
|
||||
|
||||
|
||||
def _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
result = dict(cfg)
|
||||
for key, val in _REDUCE.items():
|
||||
if key in result:
|
||||
result[key] = val
|
||||
return result
|
||||
|
||||
|
||||
def _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
result = _reduce_dict(cfg)
|
||||
n_layers = cast(int, result.get("num_hidden_layers", 4))
|
||||
|
||||
if "text_config" in result and isinstance(result["text_config"], dict):
|
||||
result["text_config"] = _reduce_dict(
|
||||
cast(dict[str, Any], result["text_config"])
|
||||
)
|
||||
tc: dict[str, Any] = result["text_config"]
|
||||
if "num_nextn_predict_layers" in tc:
|
||||
tc["num_nextn_predict_layers"] = 0
|
||||
|
||||
if "layer_types" in result and isinstance(result["layer_types"], list):
|
||||
result["layer_types"] = result["layer_types"][:n_layers]
|
||||
|
||||
if "attention_other_setting" in result and isinstance(
|
||||
result["attention_other_setting"], dict
|
||||
):
|
||||
aos: dict[str, Any] = dict(
|
||||
cast(dict[str, Any], result["attention_other_setting"])
|
||||
)
|
||||
if "num_attention_heads" in aos:
|
||||
aos["num_attention_heads"] = result.get("num_attention_heads", 4)
|
||||
if "num_attention_groups" in aos:
|
||||
aos["num_attention_groups"] = result.get(
|
||||
"num_attention_groups", cast(int, aos["num_attention_groups"])
|
||||
)
|
||||
result["attention_other_setting"] = aos
|
||||
|
||||
if "moe_layers_enum" in result and isinstance(result["moe_layers_enum"], str):
|
||||
indices = [int(x) for x in result["moe_layers_enum"].split(",") if x.strip()]
|
||||
valid = [i for i in indices if i < n_layers]
|
||||
result["moe_layers_enum"] = ",".join(str(i) for i in valid) if valid else ""
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────── #
|
||||
|
||||
|
||||
def _find_snapshot(hub_name: str) -> Path | None:
|
||||
model_dir = HF_CACHE / f"models--mlx-community--{hub_name}"
|
||||
snaps = model_dir / "snapshots"
|
||||
if not snaps.exists():
|
||||
return None
|
||||
children = sorted(snaps.iterdir())
|
||||
return children[0] if children else None
|
||||
|
||||
|
||||
def _copy_tokenizer(src: Path, dst: Path) -> None:
|
||||
for f in src.iterdir():
|
||||
name = f.name
|
||||
if (
|
||||
"tokeniz" in name.lower()
|
||||
or "tiktoken" in name.lower()
|
||||
or name.startswith("vocab")
|
||||
or name.endswith(".jinja")
|
||||
or "tool_declaration" in name
|
||||
) and f.is_file():
|
||||
shutil.copy2(f, dst / name)
|
||||
|
||||
|
||||
def _build_model(module_name: str, cfg: dict[str, Any]) -> Model:
|
||||
mod = importlib.import_module(f"mlx_lm.models.{module_name}")
|
||||
args = mod.ModelArgs.from_dict(cfg) # pyright: ignore[reportAny]
|
||||
model: nn.Module = mod.Model(args) # pyright: ignore[reportAny]
|
||||
flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))
|
||||
random_weights = [
|
||||
(k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat
|
||||
]
|
||||
model.update(cast(dict[str, Any], tree_unflatten(random_weights)))
|
||||
mx.eval(model.parameters())
|
||||
return cast(Model, model)
|
||||
|
||||
|
||||
def _collect_tokens(
    model: Model,
    tokenizer: TokenizerWrapper,
    task: TextGenerationTaskParams,
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None,
) -> list[int]:
    tokens: list[int] = []
    for resp in mlx_generate(
        model=model,
        tokenizer=tokenizer,
        task=task,
        prompt=prompt,
        kv_prefix_cache=kv_prefix_cache,
        group=None,
    ):
        tokens.append(resp.token)
        if resp.finish_reason is not None:
            break
    return tokens

# ── Architecture definitions ──────────────────────────────────────────────── #


@dataclass(frozen=True)
class ArchSpec:
    name: str
    hub_name: str
    module: str
    tokenizer_hub: str | None = None  # fallback for models without bundled tokenizer


ARCHITECTURES: list[ArchSpec] = [
    ArchSpec("llama", "Llama-3.2-1B-Instruct-4bit", "llama"),
    ArchSpec("glm_moe_dsa", "GLM-5-MXFP4-Q8", "glm_moe_dsa"),
    ArchSpec(
        "glm4_moe", "GLM-4.5-Air-8bit", "glm4_moe", tokenizer_hub="GLM-4.7-8bit-gs32"
    ),
    ArchSpec(
        "glm4_moe_lite",
        "GLM-4.7-Flash-8bit",
        "glm4_moe_lite",
        tokenizer_hub="GLM-4.7-8bit-gs32",
    ),
    ArchSpec("glm4_moe_47", "GLM-4.7-8bit-gs32", "glm4_moe"),
    ArchSpec("qwen3", "Qwen3-4B-Instruct-2507-4bit", "qwen3"),
    ArchSpec("qwen3_moe", "Qwen3-30B-A3B-4bit", "qwen3_moe"),
    ArchSpec("qwen3_next", "Qwen3-Next-80B-A3B-Thinking-4bit", "qwen3_next"),
    ArchSpec("minimax", "MiniMax-M2.1-3bit", "minimax"),
    ArchSpec("gpt_oss", "gpt-oss-20b-MXFP4-Q8", "gpt_oss"),
    ArchSpec("step3p5", "Step-3.5-Flash-4bit", "step3p5"),
    ArchSpec("kimi_k25", "Kimi-K2.5", "kimi_k25"),
]

def _arch_available(spec: ArchSpec) -> bool:
    snap = _find_snapshot(spec.hub_name)
    if snap is None:
        return False
    if spec.tokenizer_hub is not None:
        return _find_snapshot(spec.tokenizer_hub) is not None
    return True

def _make_task() -> TextGenerationTaskParams:
    return TextGenerationTaskParams(
        model=ModelId("test"),
        input=[
            InputMessage(
                role="user",
                content="Use the calculator to compute 1847 * 263 + 5921",
            )
        ],
        max_output_tokens=20,
        temperature=0.0,
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "Evaluate a mathematical expression",
                    "parameters": {
                        "type": "object",
                        "properties": {"expression": {"type": "string"}},
                        "required": ["expression"],
                    },
                },
            }
        ],
    )

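
# For reference only: the task above corresponds roughly to this OpenAI-style
# chat request against exo's API (a sketch; field names mirror the OpenAI chat
# completions schema, and the tests below never go over HTTP):
_EQUIVALENT_REQUEST_BODY: dict[str, Any] = {
    "model": "test",
    "messages": [
        {"role": "user", "content": "Use the calculator to compute 1847 * 263 + 5921"}
    ],
    "max_tokens": 20,
    "temperature": 0.0,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "calculate",
                "description": "Evaluate a mathematical expression",
                "parameters": {
                    "type": "object",
                    "properties": {"expression": {"type": "string"}},
                    "required": ["expression"],
                },
            },
        }
    ],
}
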
# ── Test class ────────────────────────────────────────────────────────────── #


@pytest.mark.slow
class TestPrefixCacheArchitectures:
    """Verify prefix cache produces identical output to fresh generation for every architecture."""

    @pytest.fixture(autouse=True)
    def _cleanup(self):
        yield
        mx.clear_cache()
        gc.collect()

    @pytest.mark.parametrize(
        "spec",
        ARCHITECTURES,
        ids=[a.name for a in ARCHITECTURES],
    )
    def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:
        if not _arch_available(spec):
            pytest.skip(f"Model {spec.hub_name} not cached locally")

        snapshot = _find_snapshot(spec.hub_name)
        assert snapshot is not None

        tmpdir = Path(tempfile.mkdtemp(prefix=f"exo_test_{spec.name}_"))
        try:
            # Build reduced config
            with open(snapshot / "config.json") as f:
                cfg = cast(dict[str, Any], json.load(f))
            reduced = _reduce_config(copy.deepcopy(cfg))
            (tmpdir / "config.json").write_text(json.dumps(reduced))

            # Copy tokenizer
            tok_src = snapshot
            if spec.tokenizer_hub is not None:
                alt = _find_snapshot(spec.tokenizer_hub)
                if alt is not None:
                    tok_src = alt
            _copy_tokenizer(tok_src, tmpdir)

            # Load tokenizer and model
            model_id = ModelId(f"mlx-community/{spec.hub_name}")
            tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)
            mx.random.seed(0)
            model = _build_model(spec.module, reduced)

            task = _make_task()
            prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)

            # Run 1: fresh
            mx.random.seed(42)
            fresh = _collect_tokens(model, tokenizer, task, prompt, None)
            assert len(fresh) > 0, "Fresh generation produced no tokens"

            # Run 2: populate cache
            kv = KVPrefixCache(None)
            mx.random.seed(42)
            populate = _collect_tokens(model, tokenizer, task, prompt, kv)

            # Run 3: exact cache hit
            mx.random.seed(42)
            cached = _collect_tokens(model, tokenizer, task, prompt, kv)

            assert fresh == populate, (
                f"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}"
            )
            assert fresh == cached, (
                f"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}"
            )
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
@@ -343,8 +343,16 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
    """Test GLM tokenizer with its specific EOS tokens."""

    def contains(card: ModelCard, x: str):
        return x in card.model_id.lower()

    glm_model_cards = [
        card for card in await get_model_cards() if "glm" in card.model_id.lower()
        card
        for card in await get_model_cards()
        if contains(card, "glm")
        and not contains(card, "-5")
        and not contains(card, "4.7")
    ]

    if not glm_model_cards:
162
src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py
Normal file
@@ -0,0 +1,162 @@
from collections.abc import Generator

from exo.shared.types.worker.runner_response import (
    GenerationResponse,
    ToolCallResponse,
)
from exo.worker.runner.runner import parse_gpt_oss

# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
_CHANNEL = 200005  # <|channel|>
_START = 200006  # <|start|>
_MESSAGE = 200008  # <|message|>
_CALL = 200012  # <|call|>
_END = 200007  # <|end|>
_ASSISTANT = 173781  # "assistant"

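
# How these IDs can be re-derived if the vocabulary ever needs re-checking
# (a sketch only, not exercised by the tests; assumes the transformers package
# is installed and the hub tokenizer is downloadable):
def _verify_special_token_ids() -> None:
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/gpt-oss-20b-MXFP4-Q8")
    assert tok.convert_tokens_to_ids("<|channel|>") == _CHANNEL
    assert tok.convert_tokens_to_ids("<|call|>") == _CALL
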
# fmt: off
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_A_TOKENS: list[tuple[int, str]] = [
    (316, " to"),
    (28, "="),
    (44580, "functions"),
    (775, ".get"),
    (23981, "_current"),
    (170154, "_weather"),
    (_CHANNEL, "<|channel|>"),
    (12606, "comment"),
    (815, "ary"),
    (5701, " json"),
    (_MESSAGE, "<|message|>"),
    (10848, '{"'),
    (7693, "location"),
    (1243, '":'),
    (392, ' "'),
    (173844, "Tokyo"),
    (18583, '"}'),
    (_CALL, "<|call|>"),
]

# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_B_TOKENS: list[tuple[int, str]] = [
    (_CHANNEL, "<|channel|>"),
    (12606, "comment"),
    (815, "ary"),
    (316, " to"),
    (28, "="),
    (44580, "functions"),
    (775, ".get"),
    (23981, "_current"),
    (170154, "_weather"),
    (5701, " json"),
    (_MESSAGE, "<|message|>"),
    (10848, '{"'),
    (7693, "location"),
    (1243, '":'),
    (392, ' "'),
    (173844, "Tokyo"),
    (18583, '"}'),
    (_CALL, "<|call|>"),
]

# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
# Full analysis-then-tool-call as the model actually generates it.
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
    (_CHANNEL, "<|channel|>"),
    (35644, "analysis"),
    (_MESSAGE, "<|message|>"),
    (12845, "Let"),
    (668, " me"),
    (2411, " think"),
    (1078, " about"),
    (495, " this"),
    (13, "."),
    (_END, "<|end|>"),
    # Model generates a new message header for the tool call:
    (_START, "<|start|>"),
    (_ASSISTANT, "assistant"),
    *FORMAT_B_TOKENS,
]
# fmt: on

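
# Sanity sketch (not part of the tests): joining the text halves of the pairs
# reproduces the raw Harmony strings quoted in the comments above.
assert "".join(t for _, t in FORMAT_A_TOKENS).endswith('{"location": "Tokyo"}<|call|>')
assert "".join(t for _, t in FORMAT_B_TOKENS).startswith("<|channel|>commentary to=")
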
def _make_gen_responses(
    tokens: list[tuple[int, str]],
) -> list[GenerationResponse]:
    """Build GenerationResponse list from (token_id, text) pairs."""
    responses: list[GenerationResponse] = []
    for i, (tid, text) in enumerate(tokens):
        is_last = i == len(tokens) - 1
        responses.append(
            GenerationResponse(
                text=text,
                token=tid,
                finish_reason="stop" if is_last else None,
                usage=None,
            )
        )
    return responses

def _collect(
    tokens: list[tuple[int, str]],
) -> list[GenerationResponse | ToolCallResponse]:
    """Feed tokens through parse_gpt_oss and collect all yielded responses."""

    def _gen() -> Generator[GenerationResponse, None, None]:
        yield from _make_gen_responses(tokens)

    return list(parse_gpt_oss(_gen()))


def _get_tool_call(
    results: list[GenerationResponse | ToolCallResponse],
) -> ToolCallResponse:
    """Extract the single ToolCallResponse from results."""
    tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
    assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
    return tool_calls[0]

class TestParseGptOssRecipientPlacement:
    """Both Harmony recipient placements must produce identical tool calls."""

    def test_format_a_yields_tool_call(self):
        results = _collect(FORMAT_A_TOKENS)
        tc = _get_tool_call(results)
        assert tc.tool_calls[0].name == "get_current_weather"
        assert '"location"' in tc.tool_calls[0].arguments
        assert "Tokyo" in tc.tool_calls[0].arguments

    def test_format_b_yields_tool_call(self):
        results = _collect(FORMAT_B_TOKENS)
        tc = _get_tool_call(results)
        assert tc.tool_calls[0].name == "get_current_weather"
        assert '"location"' in tc.tool_calls[0].arguments
        assert "Tokyo" in tc.tool_calls[0].arguments

    def test_both_formats_produce_identical_tool_calls(self):
        tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
        tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
        assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
        assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments


class TestParseGptOssThinkingThenToolCall:
    """Analysis (thinking) followed by a tool call must yield both."""

    def test_thinking_then_tool_call(self):
        results = _collect(THINKING_THEN_TOOL_TOKENS)

        # Should have thinking tags + content + tool call
        text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
        combined = "".join(text_parts)
        assert "<think>" in combined
        assert "</think>" in combined
        assert "Let me think about this." in combined

        # And the tool call
        tc = _get_tool_call(results)
        assert tc.tool_calls[0].name == "get_current_weather"
        assert "Tokyo" in tc.tool_calls[0].arguments
55
tests/eval_tool_calls.sh
Executable file
@@ -0,0 +1,55 @@
#!/usr/bin/env bash

[ $# -lt 1 ] && {
  echo "Usage: $0 host1 [host2 ...]"
  exit 1
}

[ -z "$(git status --porcelain)" ] || {
  echo "Uncommitted changes"
  exit 1
}

commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
  echo "Not pushed to origin"
  exit 1
}
hosts=("$@")
cleanup() {
  for host in "${hosts[@]}"; do
    ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
  done
  sleep 1
  jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM

for host; do
  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
    "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
    "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done

for host; do
  echo "Waiting for $host..." 1>&2
  until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done

echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
eval_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
  echo "running eval for $model" 1>&2
  ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
    "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
    >>"./bench/$commit/${model//\//--}-eval.json"
  echo
done
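
The readiness loop in the script above (polling /models on port 52415 with curl until it answers) can also be expressed in Python with httpx, one of the bench dependencies listed below. This is a minimal sketch, not part of the script; the wait_for_exo name and the timeout default are illustrative assumptions.

import time

import httpx


def wait_for_exo(host: str, timeout: float = 300.0) -> None:
    """Poll the exo API on `host` until /models responds, or raise on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if httpx.get(f"http://{host}:52415/models", timeout=5.0).status_code == 200:
                return
        except httpx.HTTPError:
            pass
        time.sleep(1.0)
    raise TimeoutError(f"exo API on {host} did not come up within {timeout:.0f}s")
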
2
uv.lock
generated
@@ -447,6 +447,7 @@ name = "exo-bench"
version = "0.1.0"
source = { editable = "bench" }
dependencies = [
    { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -456,6 +457,7 @@ dependencies = [

[package.metadata]
requires-dist = [
    { name = "httpx", specifier = ">=0.27.0" },
    { name = "huggingface-hub", specifier = ">=0.33.4" },
    { name = "jinja2", specifier = ">=3.1.0" },
    { name = "loguru", specifier = ">=0.7.3" },