Compare commits

...

14 Commits

Author SHA1 Message Date
Alex Cheema
23f295e684 feat: show ETA on prefill progress bar
Track when prefill starts via performance.now() and extrapolate
remaining time from observed tokens/sec. Displays "~Xs remaining"
(or "~Xm Ys remaining" for longer prompts) next to the percentage.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:47 -08:00
Alex Cheema
7637fb554f refactor: address PR #1181 review comments from Evanev7
- Rename PrefillProgressData to PrefillProgressChunk for consistency
- Convert isinstance chain to match/case in collect_chat_response (see the sketch below)
- Remove unused StreamEvent type alias from chunks.py
- Update docstrings to reflect new naming
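
For reference, the isinstance-to-match/case conversion looks roughly like the sketch below; the chunk classes here are placeholders, not exo's actual types (only `PrefillProgressChunk` is named in this commit):

```python
from dataclasses import dataclass

@dataclass
class TokenChunk:  # placeholder type for illustration
    text: str

@dataclass
class PrefillProgressChunk:  # name from this commit; fields are illustrative
    processed_tokens: int
    total_tokens: int

def describe(chunk: object) -> str:
    # match/case replaces an equivalent isinstance() chain
    match chunk:
        case TokenChunk(text=text):
            return f"token: {text!r}"
        case PrefillProgressChunk(processed_tokens=p, total_tokens=t):
            return f"prefill {p}/{t}"
        case _:
            return "unknown chunk"
```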

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:14 -08:00
Alex Cheema
1a9c5fa6fb fix: wire prefill progress callback to prefill stream_generate, not decode
- Move on_prefill_progress callback from decode stream_generate to prefill()
- Fix SSE parser to handle named event types (event: prefill_progress)
- Wire PrefillProgressBar component into ChatMessages
- Add prefillProgress reactive state to the store

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:30:18 -08:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper (sketched below)
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls
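
A minimal sketch of what such a helper might look like, based on the hardcoded strings it replaces (the import path and the `event.type` attribute are assumptions, not confirmed details of the exo codebase):

```python
from openai_responses import ResponsesStreamEvent  # assumed import location

def _format_sse(event: ResponsesStreamEvent) -> str:
    """Serialize a streaming event as a named Server-Sent Event (sketch)."""
    # Assumes each event is a pydantic model exposing its SSE name via `event.type`.
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
```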

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
Mustafa Alp Yılmaz
2e29605194 fix: finalize cancel tasks (#1498)
# Cancel task finalization (main.py)

After forwarding the cancel to the runner supervisor, emit TaskStatusUpdated(Complete) for the cancel task itself. This ensures the cancel task is properly removed from state.tasks.
2026-02-19 13:27:34 +00:00
Evan Quiney
cacb456cb2 remove nightly (#1538)
we have no good need for rust nightly (nor futures, for that matter)
2026-02-19 12:55:31 +00:00
rltakashige
51021f6fc6 Add cancellation button and the ability to cancel during prefill (#1540)
## Motivation
There's no way to easily use the cancellation features we added! Also,
prefill can take ages so let's allow cancelling out of that.

## Changes

Wiring up our existing functionality to easily cancel during generation,
and adding the equivalent plumbing to cancel during prefill.

## Test Plan

### Manual Testing
Tested that it works during both prefill and decode.

### Automated testing
Needs testing to see if this causes a GPU timeout error on large prefills
with large models in pipeline parallel. However, from manually testing GLM
5 pipeline ring on 2 nodes, and from reading the code, it does not seem
like this will be the case.
2026-02-19 11:40:59 +00:00
Alex Cheema
025ed9fd82 feat: add prefill progress bar for long prompts (#1181)
## Motivation

Users processing long prompts have no visibility into when token
generation will start. This feature adds a progress bar showing prefill
progress, giving users real-time feedback during prompt processing.

## Changes

### Backend
- Added `PrefillProgress` event type with `command_id`,
`processed_tokens`, `total_tokens` (see the sketch after this list)
- Added `PrefillProgressResponse` type (though now using direct callback
approach)
- Wired `prompt_progress_callback` through MLX's `stream_generate()`
- Progress events sent directly from callback for real-time updates (not
batched)
- API generates SSE named events: `event: prefill_progress\ndata: {...}`
- Added `PrefillProgressData` dataclass and `StreamEvent` union type in
API
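
For reference, a minimal sketch of the progress event payload as described in this list (the real definitions in exo may differ; `fraction_complete` is purely illustrative):

```python
from dataclasses import dataclass

@dataclass
class PrefillProgress:
    """Sketch of the prefill progress payload; fields follow the list above."""
    command_id: str
    processed_tokens: int
    total_tokens: int

    @property
    def fraction_complete(self) -> float:
        # Guard against a zero-token prompt; illustrative helper, not exo code.
        return self.processed_tokens / self.total_tokens if self.total_tokens else 0.0
```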

### Dashboard
- Added `PrefillProgress` interface to store
- Updated SSE parsing to handle `event:` lines (named events)
- Created `PrefillProgressBar.svelte` with animated progress bar
- Shows "Processing prompt: X/Y tokens" with percentage
- Progress bar disappears when first token arrives

## Why It Works

MLX's `stream_generate()` accepts a `prompt_progress_callback(processed,
total)` that's called after each prefill chunk. By sending events
directly from this callback (rather than yielding from the generator),
progress updates are sent in real-time during prefill.

Using SSE named events (`event: prefill_progress`) maintains full
OpenAI/Claude API compatibility - standard clients ignore named events
they don't recognize, while the exo dashboard explicitly listens for
them.
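
A minimal sketch of this wiring, assuming `stream_generate()` forwards `prompt_progress_callback(processed, total)` to its prefill loop as described above; the model id and the `emit` sink are placeholders, not exo's actual server code:

```python
import json
from mlx_lm import load, stream_generate

def emit(raw: str) -> None:
    # In the real API server this writes to the open SSE response; here we just print.
    print(raw, end="")

def on_prefill_progress(processed: int, total: int) -> None:
    # Sent directly from the callback, so updates reach the client during prefill.
    payload = {"processed_tokens": processed, "total_tokens": total}
    emit(f"event: prefill_progress\ndata: {json.dumps(payload)}\n\n")

model, tokenizer = load("mlx-community/placeholder-model-4bit")  # placeholder model id
for response in stream_generate(
    model,
    tokenizer,
    prompt="... a very long prompt ...",
    prompt_progress_callback=on_prefill_progress,
):
    emit(f"data: {json.dumps({'content': response.text})}\n\n")
emit("data: [DONE]\n\n")
```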

## Test Plan

### Manual Testing
- Hardware: MacBook Pro M3 Max
- Set `prefill_step_size=256` for more frequent updates
- Tested with long prompts (pasted large documents)
- Verified progress bar updates incrementally during prefill
- Confirmed progress bar disappears when generation starts
- Tested with curl - standard `data:` events still work normally

Here it is working:


https://github.com/user-attachments/assets/5cc6f075-c5b2-4a44-bb4d-9efb246bc5fe


### Automated Testing
- Type checker passes (0 errors)
- All 192 tests pass
- Dashboard builds successfully

### API Compatibility
- Named SSE events are ignored by OpenAI SDK clients (example stream below)
- Regular token data uses standard `data: {...}` format
- `[DONE]` sentinel works as expected
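
For illustration, the wire format a client sees might look like the excerpt below (payload values are made up; only the `event:`/`data:` framing and the `[DONE]` sentinel come from this PR):

```
event: prefill_progress
data: {"processed_tokens": 2048, "total_tokens": 8192}

data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hello"}}]}

data: [DONE]
```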

---

**Note:** `prefill_step_size` is temporarily set to 256 for testing.
Should be changed back to 2048 before merging for production
performance.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-19 03:18:25 +00:00
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905


## Test Plan

### Manual Testing
**Without filter**
![Screenshot 2026-02-18 at 22 26 22](https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7)

**With filter**
![Screenshot 2026-02-18 at 22 26 45](https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe)
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds (sketched below)
- Returns early on first success for minimal latency in the common case
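
A minimal sketch of the multi-endpoint probe, assuming all three endpoints are probed on port 443 like the original `1.1.1.1:443` check (names and timeout are illustrative, not exo's actual identifiers):

```python
import socket

# Endpoints from the PR description; port 443 mirrors the original 1.1.1.1:443 check.
_PROBE_ENDPOINTS = [("1.1.1.1", 443), ("8.8.8.8", 443), ("1.0.0.1", 443)]

def _test_internet_connection(timeout_s: float = 2.0) -> bool:
    """Return True as soon as any probe endpoint accepts a TCP connection."""
    for host, port in _PROBE_ENDPOINTS:
        try:
            with socket.create_connection((host, port), timeout=timeout_s):
                return True  # early exit on first success
        except OSError:
            continue  # blocked or unreachable; try the next endpoint
    return False
```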

Fixes #1425

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
rltakashige
24e99ce197 Cleanup mistakes (#1537)
Oops
2026-02-18 22:05:26 +00:00
Alex Cheema
315992549b fix: unblock MpReceiver.close() to prevent shutdown hang (#1511)
## Summary

- `MpReceiver.close()` did not unblock threads stuck on `queue.get()` in
`receive_async()`, causing abandoned threads (via
`abandon_on_cancel=True`) to keep the Python process alive indefinitely
after tests pass
- This caused the `aarch64-darwin` CI jobs in PR #1462 to hang for ~6
hours until the GitHub Actions timeout killed them
- Sends an `_MpEndOfStream` sentinel before closing the buffer,
mirroring what `MpSender.close()` already does (see the sketch below)
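
A hedged sketch of the sentinel pattern (not exo's actual `MpReceiver`; the real class also has an async receive path and abandon-on-cancel threads, omitted here):

```python
import multiprocessing as mp
from typing import Any

class _MpEndOfStream:
    """Sentinel placed on the queue so blocked consumers wake up."""

class ReceiverSketch:
    """Illustrative stand-in for MpReceiver; the real exo code differs."""

    def __init__(self) -> None:
        self._queue = mp.Queue()

    def receive(self) -> Any:
        item = self._queue.get()  # blocks until an item (or the sentinel) arrives
        return None if isinstance(item, _MpEndOfStream) else item

    def close(self) -> None:
        # Unblock any thread stuck in receive(), mirroring what MpSender.close() does.
        self._queue.put(_MpEndOfStream())
```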

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt` — 0 changed
- [x] `uv run pytest` — 188 passed, 1 skipped in 12s (no hang)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-18 21:59:02 +00:00
Alex Cheema
ce5a65d3b9 Add MiniMax M2.5 model cards (#1514)
## Summary
- Adds model cards for MiniMax M2.5 in three quantizations: 4bit (~129
GB), 6bit (~186 GB), 8bit (~243 GB)
- No code changes needed — `MiniMaxM2ForCausalLM` is already in the
tensor parallel whitelist and `MiniMaxShardingStrategy` is already
implemented in `auto_parallel.py`
- Credit to @vskiwi for confirming MiniMax M2.5 works out of the box
with existing code

Closes #1480

## Test plan
- [x] `basedpyright` passes with 0 errors
- [x] `ruff check` passes
- [x] `pytest` passes (260 passed, 1 skipped)
- [ ] Verify MiniMax M2.5 models appear in model selector on dashboard

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 21:11:13 +00:00
rltakashige
c2f2111b88 Fix tool calling (#1529)
## Motivation

GPT OSS tool calling issues.

## Changes

Fixes those issues and adds a set of evals for tool calling.
Fixes GLM5 prefix caching, where CacheList wasn't getting handled
properly.
Extracts much of exo bench's setup functionality into a harness that can
be reused elsewhere, such as in the tool-calling eval (a sketch of the
kind of check the eval performs follows below).
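
For illustration, the kind of per-scenario check such an eval performs might look like the sketch below (scenario field names follow `bench/scenarios.toml` further down in this diff; the checker itself is hypothetical, not the eval's actual code):

```python
import json
from typing import Any

def check_tool_call(response: dict[str, Any], scenario: dict[str, Any]) -> bool:
    """Hypothetical scenario check against an OpenAI-style chat completion response."""
    message = response["choices"][0]["message"]
    tool_calls = message.get("tool_calls") or []
    if not scenario["expect_tool_call"]:
        return not tool_calls  # the model should answer directly, without a tool
    if not tool_calls:
        return False
    call = tool_calls[0]["function"]
    if call["name"] != scenario["expected_function"]:
        return False  # wrong tool, or a mangled name (e.g. leaked harmony tokens)
    try:
        args = json.loads(call["arguments"])
    except json.JSONDecodeError:
        return False  # arguments must be valid JSON matching the schema
    return all(key in args for key in scenario.get("required_arg_keys", []))
```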

## Test Plan
### Automated Testing
Let's run the evals for all models
2026-02-18 20:29:18 +00:00
60 changed files with 3137 additions and 1017 deletions

View File

@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict, strict: bool = ...) -> Module:
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable,
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

13
Cargo.lock generated
View File

@@ -890,7 +890,7 @@ dependencies = [
"delegate",
"env_logger",
"extend",
"futures",
"futures-lite",
"libp2p",
"log",
"networking",
@@ -914,6 +914,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -2753,7 +2762,7 @@ dependencies = [
"delegate",
"either",
"extend",
"futures",
"futures-lite",
"futures-timer",
"keccak-const",
"libp2p",

View File

@@ -29,14 +29,13 @@ util = { path = "rust/util" }
# Macro dependecies
extend = "1.2"
delegate = "0.13"
pin-project = "1"
# Utility dependencies
keccak-const = "0.2"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"
# Data structures

View File

@@ -72,12 +72,19 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:
```bash
nix run .#exo
```
**Note:** To accept the Cachix binary cache (and avoid the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:
```
trusted-users = root (or your username)
experimental-features = nix-command flakes
```
Then restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

1088
bench/eval_tool_calls.py Normal file
View File

File diff suppressed because it is too large.

View File

@@ -1,29 +1,47 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -103,154 +121,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +139,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +230,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +248,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +257,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +271,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +305,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,16 +332,6 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

327
bench/harness.py Normal file
View File

@@ -0,0 +1,327 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

306
bench/scenarios.toml Normal file
View File

@@ -0,0 +1,306 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]
[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"
[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]
[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"
[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"
[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Nested object schema (regression for lossy chat template rendering) --
[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]
[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"
# -- Tool name integrity (regression for harmony token leaking into name) --
[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]
[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"
[tools.glob.properties.path]
type = "string"
description = "The directory to search in"
[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]
[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"

View File

@@ -14,6 +14,7 @@
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
stopGeneration,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -653,86 +654,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
fill="currentColor"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="6" y="6" width="12" height="12" rx="1" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">Cancel</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -3,16 +3,18 @@
messages,
currentResponse,
isLoading,
prefillProgress,
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
interface Props {
@@ -25,6 +27,7 @@
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
const prefill = $derived(prefillProgress());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
@@ -428,6 +431,9 @@
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
<PrefillProgressBar progress={prefill} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div
class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"
@@ -620,7 +626,9 @@
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
{#if loading && !message.content && prefill}
<PrefillProgressBar progress={prefill} class="mt-2" />
{:else if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>

View File

@@ -26,7 +26,8 @@
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
function formatNumber(num: number | undefined): string {
if (num == null) return "0";
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {

View File

@@ -0,0 +1,70 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
const etaText = $derived.by(() => {
if (progress.processed <= 0 || progress.total <= 0) return null;
const elapsedMs = performance.now() - progress.startedAt;
if (elapsedMs < 200) return null; // need a minimum sample window
const tokensPerMs = progress.processed / elapsedMs;
const remainingTokens = progress.total - progress.processed;
const remainingMs = remainingTokens / tokensPerMs;
const remainingSec = Math.ceil(remainingMs / 1000);
if (remainingSec <= 0) return null;
if (remainingSec < 60) return `~${remainingSec}s remaining`;
const mins = Math.floor(remainingSec / 60);
const secs = remainingSec % 60;
return `~${mins}m ${secs}s remaining`;
});
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -273,6 +273,13 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -515,12 +522,16 @@ class AppStore {
messages = $state<Message[]>([]);
currentResponse = $state("");
isLoading = $state(false);
prefillProgress = $state<PrefillProgress | null>(null);
// Performance metrics
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
// Topology state
topologyData = $state<TopologyData | null>(null);
instances = $state<Record<string, unknown>>({});
@@ -2005,9 +2016,11 @@ class AppStore {
reader: ReadableStreamDefaultReader<Uint8Array>,
targetConversationId: string,
onChunk: (parsed: T) => void,
onEvent?: Record<string, (data: unknown) => void>,
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
let currentEventType = "";
while (true) {
const { done, value } = await reader.read();
@@ -2023,18 +2036,52 @@ class AppStore {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed) {
currentEventType = "";
continue;
}
if (trimmed.startsWith("event: ")) {
currentEventType = trimmed.slice(7);
continue;
}
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
const comment = trimmed.slice(2);
const spaceIdx = comment.indexOf(" ");
if (spaceIdx > 0) {
const key = comment.slice(0, spaceIdx);
if (onEvent[key]) {
try {
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
onEvent[key](parsed);
} catch {
// Skip malformed JSON in comment
}
}
}
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
if (data === "[DONE]") {
currentEventType = "";
continue;
}
try {
const parsed = JSON.parse(data) as T;
onChunk(parsed);
const parsed = JSON.parse(data);
if (currentEventType && onEvent?.[currentEventType]) {
onEvent[currentEventType](parsed);
} else {
onChunk(parsed as T);
}
} catch {
// Skip malformed JSON
}
currentEventType = "";
}
}
}
@@ -2135,6 +2182,7 @@ class AppStore {
this.isLoading = true;
this.currentResponse = "";
this.prefillProgress = null;
this.ttftMs = null;
this.tps = null;
this.totalTokens = 0;
@@ -2256,6 +2304,9 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
const abortController = new AbortController();
this.currentAbortController = abortController;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -2272,6 +2323,7 @@ class AppStore {
enable_thinking: enableThinking,
}),
}),
signal: abortController.signal,
});
if (!response.ok) {
@@ -2309,6 +2361,11 @@ class AppStore {
reader,
targetConversationId,
(parsed) => {
// Clear prefill progress when first token data arrives
if (this.prefillProgress) {
this.prefillProgress = null;
}
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
@@ -2330,6 +2387,11 @@ class AppStore {
}
if (tokenContent) {
// Clear prefill progress once tokens start arriving
if (this.prefillProgress !== null) {
this.prefillProgress = null;
}
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -2371,8 +2433,27 @@ class AppStore {
this.persistConversation(targetConversationId);
}
},
{
prefill_progress: (data) => {
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
// model_dump_json() uses snake_case (by_alias defaults to False)
const raw = data as Record<string, unknown>;
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
processed_tokens: number;
total_tokens: number;
};
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
};
},
},
);
// Clear prefill progress after stream ends
this.prefillProgress = null;
// Calculate final TPS
if (firstTokenTime !== null && tokenCount > 1) {
const totalGenerationTime = performance.now() - firstTokenTime;
@@ -2403,20 +2484,32 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
if (error instanceof DOMException && error.name === "AbortError") {
// User stopped generation — not an error
} else {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
}
} finally {
this.currentAbortController = null;
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.prefillProgress = null;
this.saveConversationsToStorage();
}
}
stopGeneration(): void {
this.currentAbortController?.abort();
this.currentAbortController = null;
}
/**
* Generate an image using the image generation API
*/
@@ -3040,6 +3133,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const prefillProgress = () => appStore.prefillProgress;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
@@ -3060,6 +3154,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
content: string,

View File
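The updated stream parser has to accept three shapes on the wire: named events (`event: prefill_progress` followed by a `data:` line), SSE comment lines (`: prefill_progress {...}`), and plain `data:` chunks, and the prefill handler also unwraps the `{"PrefillProgressChunk": {...}}` envelope that the tagged-model serialization produces. A simplified Python mirror of that dispatch; error handling for malformed JSON is omitted and the field names follow the diff:

```python
import json
from typing import Any, Callable

Handler = Callable[[dict[str, Any]], None]

def dispatch_sse_lines(lines: list[str], on_chunk: Handler, on_event: dict[str, Handler]) -> None:
    """Simplified mirror of the store's SSE handling: named events, comments, data chunks."""
    current_event = ""
    for raw in lines:
        line = raw.strip()
        if not line:                       # blank line terminates the current event
            current_event = ""
            continue
        if line.startswith("event: "):     # named event, e.g. "event: prefill_progress"
            current_event = line[len("event: "):]
            continue
        if line.startswith(": "):          # SSE comment, e.g. ": prefill_progress {...}"
            key, _, payload = line[2:].partition(" ")
            if key in on_event and payload:
                on_event[key](json.loads(payload))
            continue
        if line.startswith("data: "):
            data = line[len("data: "):]
            if data == "[DONE]":
                current_event = ""
                continue
            parsed = json.loads(data)
            if current_event in on_event:
                on_event[current_event](parsed)
            else:
                on_chunk(parsed)
            current_event = ""

def on_prefill_progress(data: dict[str, Any]) -> None:
    inner = data.get("PrefillProgressChunk", data)  # unwrap the tagged envelope if present
    print(f"prefill {inner['processed_tokens']}/{inner['total_tokens']}")

dispatch_sse_lines(
    [
        ': prefill_progress {"processed_tokens": 4096, "total_tokens": 16384}',
        "event: prefill_progress",
        'data: {"processed_tokens": 8192, "total_tokens": 16384}',
        "",
        'data: {"choices": [{"delta": {"content": "Hello"}}]}',
        "data: [DONE]",
    ],
    on_chunk=lambda chunk: print("token:", chunk["choices"][0]["delta"]["content"]),
    on_event={"prefill_progress": on_prefill_progress},
)
```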

@@ -932,13 +932,6 @@
};
}
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log("[Download Debug] Current downloads:", downloadsData);
}
});
// Helper to get download status for an instance
function getInstanceDownloadStatus(
instanceId: string,

View File

@@ -74,7 +74,6 @@
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in

View File

@@ -158,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -27,7 +27,7 @@ networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
"nightly", # enables better-supported GIL integration
# "nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }
# utility dependencies
util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -2,7 +2,6 @@
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -26,8 +25,8 @@ where
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;

View File

@@ -4,25 +4,12 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod ident;
mod networking;
use crate::ident::ident_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -32,14 +19,6 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
@@ -180,7 +159,6 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs

View File

@@ -8,8 +8,8 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};

View File

@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -22,7 +22,7 @@ delegate = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies

View File

@@ -1,4 +1,4 @@
use futures::stream::StreamExt as _;
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
event = swarm.next() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e))) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
@@ -64,7 +64,7 @@ async fn main() {
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }

View File

@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)

View File

@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;

View File

@@ -1,11 +1,10 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
{ inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
rustToolchain = inputs'.fenix.packages.stable.withComponents [
"cargo"
"rustc"
"clippy"

View File

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"

View File

@@ -123,14 +123,17 @@ class DownloadCoordinator:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True

View File

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -123,67 +128,81 @@ def chunk_to_response(
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
match chunk:
case PrefillProgressChunk():
# Use SSE comment so third-party clients ignore it
yield f": prefill_progress {chunk.model_dump_json()}\n\n"
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
case ErrorChunk():
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
case ToolCallChunk():
last_usage = chunk.usage or last_usage
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case TokenChunk():
last_usage = chunk.usage or last_usage
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(
update={"usage": last_usage}
)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -197,38 +216,43 @@ async def collect_chat_response(
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
match chunk:
case PrefillProgressChunk():
continue
if model is None:
model = chunk.model
case ErrorChunk():
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
case TokenChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
case ToolCallChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)

View File
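`generate_chat_stream` now dispatches on the chunk type with `match`/`case` and sends prefill progress as an SSE comment, which standards-compliant clients skip while the bundled dashboard opts in. A self-contained sketch of that dispatch using stand-in dataclasses (the real chunks are pydantic models with more fields, so this only illustrates the shape of the frames):

```python
import json
from dataclasses import asdict, dataclass

@dataclass
class PrefillProgressChunk:   # stand-in for exo.shared.types.chunks.PrefillProgressChunk
    processed_tokens: int
    total_tokens: int

@dataclass
class TokenChunk:             # stand-in for exo.shared.types.chunks.TokenChunk
    text: str
    finish_reason: str | None = None

def sse_frames(chunks):
    """Yield SSE frames in the simplified shape generate_chat_stream uses."""
    for chunk in chunks:
        match chunk:
            case PrefillProgressChunk():
                # Comment frame: clients that only read "data:" lines ignore it.
                yield f": prefill_progress {json.dumps(asdict(chunk))}\n\n"
            case TokenChunk():
                yield f"data: {json.dumps(asdict(chunk))}\n\n"
                if chunk.finish_reason is not None:
                    yield "data: [DONE]\n\n"

for frame in sse_frames([
    PrefillProgressChunk(processed_tokens=8192, total_tokens=32768),
    TokenChunk(text="Hello"),
    TokenChunk(text="!", finish_reason="stop"),
]):
    print(frame, end="")
```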

@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -172,6 +179,9 @@ async def collect_claude_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -230,7 +240,9 @@ async def collect_claude_response(
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
@@ -256,6 +268,9 @@ async def generate_claude_stream(
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break

View File

@@ -5,7 +5,12 @@ from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
@@ -26,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -33,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -121,7 +132,9 @@ def responses_request_to_text_generation(
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -134,6 +147,9 @@ async def collect_responses_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -189,7 +205,9 @@ async def collect_responses_response(
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
@@ -207,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -224,7 +242,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -235,7 +253,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -243,6 +261,9 @@ async def generate_responses_stream(
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
break
@@ -266,7 +287,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -275,7 +296,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -285,7 +306,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -300,7 +321,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -316,7 +337,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -326,7 +347,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -337,7 +358,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -348,7 +369,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -373,4 +394,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -107,6 +107,7 @@ from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
@@ -137,6 +138,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -145,6 +147,7 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -220,7 +223,8 @@ class API:
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -526,19 +530,23 @@ class API:
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason is not None:
break
@@ -554,6 +562,8 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
@@ -565,6 +575,9 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -1292,8 +1305,18 @@ class API:
return total_available
async def get_models(self) -> ModelList:
"""Returns list of available models."""
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
"""Returns list of available models, optionally filtered by being downloaded."""
cards = await get_model_cards()
if status == "downloaded":
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [c for c in cards if c.model_id in downloaded_model_ids]
return ModelList(
data=[
ModelListModel(
@@ -1311,7 +1334,7 @@ class API:
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in cards
]
)
@@ -1435,6 +1458,21 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressChunk(
model=event.model,
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

View File
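The `status=downloaded` filter on `get_models` builds a set of model ids from completed downloads across every node and keeps only cards whose id is in that set. A compact sketch of the same set-based filter; the stand-in types below flatten the real `dl.shard_metadata.model_card.model_id` access for brevity:

```python
from dataclasses import dataclass

@dataclass
class DownloadCompleted:   # flattened stand-in for the worker's DownloadCompleted record
    model_id: str

@dataclass
class ModelCard:           # stand-in for the project's model card type
    model_id: str

def filter_downloaded(cards: list[ModelCard], downloads: dict[str, list[object]]) -> list[ModelCard]:
    """Keep only cards with at least one completed download on any node."""
    downloaded = {
        dl.model_id
        for node_downloads in downloads.values()
        for dl in node_downloads
        if isinstance(dl, DownloadCompleted)
    }
    return [card for card in cards if card.model_id in downloaded]

cards = [ModelCard("mlx-community/MiniMax-M2.5-4bit"), ModelCard("mlx-community/MiniMax-M2.5-8bit")]
downloads = {"node-a": [DownloadCompleted("mlx-community/MiniMax-M2.5-4bit")]}
print([c.model_id for c in filter_downloaded(cards, downloads)])  # only the 4bit card
```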

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -64,6 +65,7 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressChunk(BaseChunk):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
GenerationChunk = (
TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,6 +102,13 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -148,6 +155,7 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

View File

@@ -4,10 +4,13 @@ from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]

View File

@@ -67,3 +67,8 @@ class ToolCallResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -1,3 +1,4 @@
import contextlib
import multiprocessing as mp
from dataclasses import dataclass, field
from math import inf
@@ -132,7 +133,8 @@ class MpSender[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
self._state.buffer.put(_MpEndOfStream())
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==
@@ -204,6 +206,8 @@ class MpReceiver[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==

View File
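The channel `close()` change swaps the blocking `put` for `put_nowait` wrapped in `contextlib.suppress`, so closing a channel whose buffer is full (or whose queue is already closed) cannot hang or raise during shutdown; the sentinel still wakes a blocked receiver when there is room for it. A minimal illustration of that pattern on a plain `multiprocessing.Queue`; the names here are illustrative, not the project's channel API:

```python
import contextlib
import multiprocessing as mp
from multiprocessing.synchronize import Event

class _EndOfStream:
    """Sentinel pushed on close so a blocked receiver wakes up and exits."""

def close_channel(buffer: "mp.Queue[object]", closed: Event) -> None:
    if not closed.is_set():
        closed.set()
        # Best effort: a full or already-closed buffer must not block or raise at shutdown.
        with contextlib.suppress(Exception):
            buffer.put_nowait(_EndOfStream())
        buffer.close()

if __name__ == "__main__":
    q: "mp.Queue[object]" = mp.Queue(maxsize=1)
    q.put("pending item")         # buffer is now full; a plain put() here would block forever
    close_channel(q, mp.Event())  # put_nowait raises queue.Full, which is suppressed
```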

@@ -5,6 +5,7 @@ import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
@@ -17,10 +18,22 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
@@ -64,7 +77,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None = None):
def __init__(self, group: mx.distributed.Group | None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
@@ -156,15 +169,15 @@ class KVPrefixCache:
best_length = 0
is_exact = False
# Find best cache
# Find best cache match
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(prompt_tokens, cached_prompt)
if length >= max_length - 1:
best_index, best_length = i, length
is_exact = True
break
if length > best_length:
best_index, best_length = i, length
if length == max_length:
is_exact = True
best_index, best_length = i, length
break
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
@@ -172,11 +185,12 @@ class KVPrefixCache:
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
target = (max_length - 1) if is_exact else best_length
has_ssm = has_non_kv_caches(self.caches[best_index])
target = (max_length - 1) if is_exact and not has_ssm else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
# No usable snapshot — need fresh cache
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
if restore_snap is None and has_ssm:
return make_kv_cache(model), prompt_tokens, None
prompt_cache = deepcopy(self.caches[best_index])
@@ -257,10 +271,21 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(prompt_tokens)
def _entry_length(
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
if hasattr(c, "offset"):
return c.offset
# For CacheList
if hasattr(c, "size"):
return int(c.size()) # type: ignore
return 0
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
return max(getattr(c, "offset", 0) for c in cache)
return max(_entry_length(c) for c in cache)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:

View File
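The matching change above tightens what counts as an exact prefix hit (`length == max_length`) and only trims the cached entry back by one token when no SSM or rotating caches constrain where it can be cut; otherwise it falls back to the matched length. A simplified mirror of that selection, with plain integer lists standing in for `mx.array` prompts and `len(prompt)` assumed as the matcher's bound:

```python
def get_prefix_length(prompt: list[int], cached_prompt: list[int]) -> int:
    """Length of the common prefix between a new prompt and a cached one."""
    n = 0
    for a, b in zip(prompt, cached_prompt):
        if a != b:
            break
        n += 1
    return n

def choose_trim_target(prompt: list[int], cached_prompt: list[int], has_ssm: bool) -> tuple[int, bool]:
    """Exact hits keep the last token out of the cache so stream_generate has a token to
    start from, unless SSM state snapshots constrain the trim point."""
    max_length = len(prompt)  # assumption: the real matcher's bound may differ slightly
    length = get_prefix_length(prompt, cached_prompt)
    is_exact = length == max_length
    target = (max_length - 1) if is_exact and not has_ssm else length
    return target, is_exact

# Exact hit on a 6-token prompt: trim the cache to 5 tokens, decode the last one.
print(choose_trim_target([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], has_ssm=False))  # (5, True)
# Partial hit: reuse the 3 matching tokens and prefill the rest.
print(choose_trim_target([1, 2, 3, 9, 9, 9], [1, 2, 3, 4, 5, 6], has_ssm=False))  # (3, False)
```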

@@ -48,7 +48,11 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def prefill(
@@ -57,7 +61,8 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -65,7 +70,7 @@ def prefill(
then trims off the extra generated token.
Returns:
tokens_per_sec
(tokens_per_sec, num_tokens, snapshots)
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
@@ -76,6 +81,7 @@ def prefill(
has_ssm = has_non_kv_caches(cache)
snapshots: list[CacheSnapshot] = []
# TODO(evan): kill the callbacks/runner refactor
def progress_callback(processed: int, total: int) -> None:
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
@@ -84,6 +90,11 @@ def prefill(
)
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
set_pipeline_prefill(model, is_prefill=True)
@@ -92,19 +103,23 @@ def prefill(
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=8192,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_prefill(model, is_prefill=False)
raise
set_pipeline_prefill(model, is_prefill=False)
@@ -133,7 +148,7 @@ def prefill(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
group: mx.distributed.Group | None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -255,8 +270,9 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -311,7 +327,13 @@ def mlx_generate(
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
group,
on_prefill_progress,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
@@ -436,9 +458,14 @@ def mlx_generate(
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index,

View File
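Cancellation during prefill works by letting the progress callback raise `PrefillCancelled` between prefill chunks; `prefill()` resets the pipeline prefill flag and re-raises, and the runner catches it and logs. A toy stand-in for that control flow (the real callback is wired into `stream_generate` via `prompt_progress_callback` and checks a distributed cancel flag):

```python
class PrefillCancelled(BaseException):
    """Raised from the progress callback to abort prefill between chunks."""

def prefill_chunks(total_tokens: int, step: int, on_progress) -> int:
    """Toy chunked prefill: report progress after each step; the callback may abort."""
    processed = 0
    try:
        while processed < total_tokens:
            processed = min(processed + step, total_tokens)
            on_progress(processed, total_tokens)  # may raise PrefillCancelled
    except PrefillCancelled:
        # The real prefill() resets pipeline state here before re-raising.
        raise
    return processed

cancel_requested = {"flag": False}

def on_progress(processed: int, total: int) -> None:
    print(f"prefill {processed}/{total}")
    if cancel_requested["flag"]:
        raise PrefillCancelled()

cancel_requested["flag"] = True  # pretend the user pressed the stop button
try:
    prefill_chunks(total_tokens=16384, step=4096, on_progress=on_progress)
except PrefillCancelled:
    print("prefill cancelled between chunks")
```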

@@ -1,5 +1,6 @@
import json
import os
import re
import sys
import time
from pathlib import Path
@@ -292,6 +293,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None
@@ -405,6 +408,56 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
func["arguments"] = json.loads(args)
def _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:
names: set[str] = set()
properties: dict[str, Any] = schema.get("properties", {}) # type: ignore[reportAny]
for prop_spec in properties.values(): # pyright: ignore[reportAny]
if not isinstance(prop_spec, dict):
continue
if prop_spec.get("type") == "array": # type: ignore[reportAny]
items: dict[str, Any] | None = prop_spec.get("items") # type: ignore[reportAny]
if isinstance(items, dict) and items.get("type") == "object": # type: ignore[reportAny]
inner_props: dict[str, Any] = items.get("properties", {}) # type: ignore[reportAny]
for k in inner_props: # pyright: ignore[reportUnknownVariableType]
names.add(str(k)) # pyright: ignore[reportUnknownArgumentType]
names.update(_collect_nested_property_names(items)) # pyright: ignore[reportUnknownArgumentType]
return names
def _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:
"""Return True if nested property names from any tool schema are absent."""
for tool in tools:
fn: dict[str, Any] = tool.get("function", {}) # type: ignore
params: dict[str, Any] = fn.get("parameters", {}) # type: ignore
nested = _collect_nested_property_names(params)
if nested and not all(name in prompt for name in nested):
return True
return False
_LOSSY_TEMPLATE_PATTERN = re.compile(
r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)
def _patch_lossy_chat_template(template: str) -> str | None:
"""Patch chat templates that collapse nested object schemas to ``any[]``.
Some templates (e.g., GPT-OSS) have a guard like::
inner_type == "object | object" or inner_type|length > 50
The length check silently drops complex array-of-object schemas.
We remove the length guard, keeping only the object-union check.
Returns the patched template, or *None* if no patch was needed.
"""
patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
lambda m: m.group(0).split(" or ")[0], # keep only the object-union check
template,
)
return patched if n > 0 else None
def apply_chat_template(
tokenizer: TokenizerWrapper,
task_params: TextGenerationTaskParams,
@@ -451,14 +504,28 @@ def apply_chat_template(
extra_kwargs["enable_thinking"] = task_params.enable_thinking
extra_kwargs["thinking"] = task_params.enable_thinking
patched_template: str | None = None
if task_params.tools:
original_template: str | None = getattr(tokenizer, "chat_template", None)
if isinstance(original_template, str):
patched_template = _patch_lossy_chat_template(original_template)
if patched_template is not None:
logger.info(
"Patched lossy chat template (removed inner_type length guard)"
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
**({"chat_template": patched_template} if patched_template is not None else {}),
**extra_kwargs,
)
if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):
logger.warning("Chat template lost nested tool schemas even after patching")
if partial_assistant_content:
prompt += partial_assistant_content

View File
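The template patch above removes the `inner_type|length > 50` guard that makes some chat templates collapse nested array-of-object schemas into `any[]`, keeping only the object-union check. A quick demonstration of the regex substitution on an illustrative Jinja-style guard (the pattern is the one defined in the diff; the surrounding snippet is made up):

```python
import re

# Same pattern as _LOSSY_TEMPLATE_PATTERN above.
_LOSSY_TEMPLATE_PATTERN = re.compile(
    r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)

template = '{%- if inner_type == "object | object" or inner_type|length > 50 %}any[]{%- endif %}'

patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
    lambda m: m.group(0).split(" or ")[0],  # keep only the object-union check
    template,
)
print(n)        # 1 substitution made
print(patched)  # {%- if inner_type == "object | object" %}any[]{%- endif %}
```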

@@ -241,6 +241,11 @@ class Worker:
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id

View File

@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -25,6 +26,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -80,7 +82,11 @@ from exo.worker.engines.image import (
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
mlx_generate,
warmup_inference,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
@@ -297,6 +303,32 @@ def main(
assert tokenizer
assert check_for_cancel_every
# Define callback to send prefill progress events
# and check for cancellation between prefill chunks.
# TODO(evan): kill the callbacks/runner refactor
# Specifically the part where this code is literally duplicated.
def on_prefill_progress(
processed: int,
total: int,
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
model=shard_metadata.model_card.model_id,
processed_tokens=processed,
total_tokens=total,
)
)
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, _group):
raise PrefillCancelled()
try:
_check_for_debug_prompts(task_params)
@@ -310,6 +342,7 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
group=group,
)
@@ -391,6 +424,8 @@ def main(
)
)
except PrefillCancelled:
logger.info(f"Prefill cancelled for task {task.task_id}")
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
@@ -588,17 +623,31 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(

View File

@@ -103,7 +103,7 @@ class RunnerSupervisor:
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(1)
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return

View File

@@ -123,7 +123,12 @@ def run_gpt_oss_pipeline_device(
generated_text = ""
for response in mlx_generate(
model=model, tokenizer=tokenizer, task=task, prompt=prompt
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:
@@ -194,6 +199,8 @@ def run_gpt_oss_tensor_parallel_device(
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:

View File

@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache()
cache = KVPrefixCache(None)
cache.clear()
assert len(cache.prompts) == 0
@@ -142,7 +142,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
@@ -161,9 +163,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
@@ -176,9 +180,11 @@ class TestKVPrefixCacheWithModel:
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
# Exact match returns last token(s) — for models with SSM/rotating caches,
# snapshot availability constrains how far back we can trim, so remaining
# may be 1 or 2 tokens depending on the model.
assert len(remaining_tokens) >= 1
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
@@ -194,10 +200,10 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
@@ -238,9 +244,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -276,9 +284,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -301,7 +311,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -318,6 +328,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
generated_tokens += 1
@@ -331,7 +342,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -347,6 +358,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -368,7 +380,7 @@ class TestKVPrefixCacheWithModel:
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -395,6 +407,7 @@ class TestKVPrefixCacheWithModel:
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
first_gen_time = time.perf_counter() - t0
@@ -427,6 +440,7 @@ class TestKVPrefixCacheWithModel:
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
second_gen_time = time.perf_counter() - t0
@@ -447,7 +461,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -462,6 +476,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -474,6 +489,7 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -484,7 +500,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache = KVPrefixCache(None)
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -497,7 +513,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -522,7 +538,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)

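The cache-test updates above track two signature changes: prefill() now takes an explicit group argument, and KVPrefixCache is constructed with the (possibly None) distributed group. A condensed sketch of the populate path those tests exercise; the KVPrefixCache and apply_chat_template import paths come from the architecture test later in this diff, while prefill, encode_prompt and make_kv_cache are the same helpers the tests use (their import block is not shown here), so treat this as a sketch rather than the exact module layout.

```python
# Condensed from the test flow above (single-device case, hence the None group).
from mlx_lm.sample_utils import make_sampler

from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.utils_mlx import apply_chat_template


def populate_prefix_cache(model, tokenizer, task):
    kv_prefix_cache = KVPrefixCache(None)       # None = no distributed group
    prompt = apply_chat_template(tokenizer, task)
    tokens = encode_prompt(tokenizer, prompt)   # helper used by the tests above
    cache = make_kv_cache(model)                # helper used by the tests above
    # prefill() fills `cache` in place and returns per-layer snapshots that
    # add_kv_cache() can use when trimming back to an exact or partial prefix hit.
    _, _, snapshots = prefill(
        model, tokenizer, make_sampler(0.0), tokens, cache, group=None
    )
    kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
    return kv_prefix_cache
```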

@@ -0,0 +1,297 @@
import copy
import gc
import importlib
import json
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
import pytest
from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
load_tokenizer_for_model_id,
)
HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
# ── Config reduction ──────────────────────────────────────────────────────── #
_REDUCE = {
"num_hidden_layers": 4,
"hidden_size": 256,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 512,
"moe_intermediate_size": 128,
"num_experts": 4,
"num_experts_per_tok": 2,
"n_routed_experts": 4,
"num_local_experts": 4,
"num_nextn_predict_layers": 0,
"first_k_dense_replace": 0,
"linear_num_key_heads": 2,
"linear_num_value_heads": 2,
"num_attention_groups": 4,
}
def _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:
result = dict(cfg)
for key, val in _REDUCE.items():
if key in result:
result[key] = val
return result
def _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:
result = _reduce_dict(cfg)
n_layers = cast(int, result.get("num_hidden_layers", 4))
if "text_config" in result and isinstance(result["text_config"], dict):
result["text_config"] = _reduce_dict(
cast(dict[str, Any], result["text_config"])
)
tc: dict[str, Any] = result["text_config"]
if "num_nextn_predict_layers" in tc:
tc["num_nextn_predict_layers"] = 0
if "layer_types" in result and isinstance(result["layer_types"], list):
result["layer_types"] = result["layer_types"][:n_layers]
if "attention_other_setting" in result and isinstance(
result["attention_other_setting"], dict
):
aos: dict[str, Any] = dict(
cast(dict[str, Any], result["attention_other_setting"])
)
if "num_attention_heads" in aos:
aos["num_attention_heads"] = result.get("num_attention_heads", 4)
if "num_attention_groups" in aos:
aos["num_attention_groups"] = result.get(
"num_attention_groups", cast(int, aos["num_attention_groups"])
)
result["attention_other_setting"] = aos
if "moe_layers_enum" in result and isinstance(result["moe_layers_enum"], str):
indices = [int(x) for x in result["moe_layers_enum"].split(",") if x.strip()]
valid = [i for i in indices if i < n_layers]
result["moe_layers_enum"] = ",".join(str(i) for i in valid) if valid else ""
return result
# ── Helpers ───────────────────────────────────────────────────────────────── #
def _find_snapshot(hub_name: str) -> Path | None:
model_dir = HF_CACHE / f"models--mlx-community--{hub_name}"
snaps = model_dir / "snapshots"
if not snaps.exists():
return None
children = sorted(snaps.iterdir())
return children[0] if children else None
def _copy_tokenizer(src: Path, dst: Path) -> None:
for f in src.iterdir():
name = f.name
if (
"tokeniz" in name.lower()
or "tiktoken" in name.lower()
or name.startswith("vocab")
or name.endswith(".jinja")
or "tool_declaration" in name
) and f.is_file():
shutil.copy2(f, dst / name)
def _build_model(module_name: str, cfg: dict[str, Any]) -> Model:
mod = importlib.import_module(f"mlx_lm.models.{module_name}")
args = mod.ModelArgs.from_dict(cfg) # pyright: ignore[reportAny]
model: nn.Module = mod.Model(args) # pyright: ignore[reportAny]
flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))
random_weights = [
(k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat
]
model.update(cast(dict[str, Any], tree_unflatten(random_weights)))
mx.eval(model.parameters())
return cast(Model, model)
def _collect_tokens(
model: Model,
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None,
) -> list[int]:
tokens: list[int] = []
for resp in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
tokens.append(resp.token)
if resp.finish_reason is not None:
break
return tokens
# ── Architecture definitions ──────────────────────────────────────────────── #
@dataclass(frozen=True)
class ArchSpec:
name: str
hub_name: str
module: str
tokenizer_hub: str | None = None # fallback for models without bundled tokenizer
ARCHITECTURES: list[ArchSpec] = [
ArchSpec("llama", "Llama-3.2-1B-Instruct-4bit", "llama"),
ArchSpec("glm_moe_dsa", "GLM-5-MXFP4-Q8", "glm_moe_dsa"),
ArchSpec(
"glm4_moe", "GLM-4.5-Air-8bit", "glm4_moe", tokenizer_hub="GLM-4.7-8bit-gs32"
),
ArchSpec(
"glm4_moe_lite",
"GLM-4.7-Flash-8bit",
"glm4_moe_lite",
tokenizer_hub="GLM-4.7-8bit-gs32",
),
ArchSpec("glm4_moe_47", "GLM-4.7-8bit-gs32", "glm4_moe"),
ArchSpec("qwen3", "Qwen3-4B-Instruct-2507-4bit", "qwen3"),
ArchSpec("qwen3_moe", "Qwen3-30B-A3B-4bit", "qwen3_moe"),
ArchSpec("qwen3_next", "Qwen3-Next-80B-A3B-Thinking-4bit", "qwen3_next"),
ArchSpec("minimax", "MiniMax-M2.1-3bit", "minimax"),
ArchSpec("gpt_oss", "gpt-oss-20b-MXFP4-Q8", "gpt_oss"),
ArchSpec("step3p5", "Step-3.5-Flash-4bit", "step3p5"),
ArchSpec("kimi_k25", "Kimi-K2.5", "kimi_k25"),
]
def _arch_available(spec: ArchSpec) -> bool:
snap = _find_snapshot(spec.hub_name)
if snap is None:
return False
if spec.tokenizer_hub is not None:
return _find_snapshot(spec.tokenizer_hub) is not None
return True
def _make_task() -> TextGenerationTaskParams:
return TextGenerationTaskParams(
model=ModelId("test"),
input=[
InputMessage(
role="user",
content="Use the calculator to compute 1847 * 263 + 5921",
)
],
max_output_tokens=20,
temperature=0.0,
tools=[
{
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a mathematical expression",
"parameters": {
"type": "object",
"properties": {"expression": {"type": "string"}},
"required": ["expression"],
},
},
}
],
)
# ── Test class ────────────────────────────────────────────────────────────── #
@pytest.mark.slow
class TestPrefixCacheArchitectures:
"""Verify prefix cache produces identical output to fresh generation for every architecture."""
@pytest.fixture(autouse=True)
def _cleanup(self):
yield
mx.clear_cache()
gc.collect()
@pytest.mark.parametrize(
"spec",
ARCHITECTURES,
ids=[a.name for a in ARCHITECTURES],
)
def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:
if not _arch_available(spec):
pytest.skip(f"Model {spec.hub_name} not cached locally")
snapshot = _find_snapshot(spec.hub_name)
assert snapshot is not None
tmpdir = Path(tempfile.mkdtemp(prefix=f"exo_test_{spec.name}_"))
try:
# Build reduced config
with open(snapshot / "config.json") as f:
cfg = cast(dict[str, Any], json.load(f))
reduced = _reduce_config(copy.deepcopy(cfg))
(tmpdir / "config.json").write_text(json.dumps(reduced))
# Copy tokenizer
tok_src = snapshot
if spec.tokenizer_hub is not None:
alt = _find_snapshot(spec.tokenizer_hub)
if alt is not None:
tok_src = alt
_copy_tokenizer(tok_src, tmpdir)
# Load tokenizer and model
model_id = ModelId(f"mlx-community/{spec.hub_name}")
tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)
mx.random.seed(0)
model = _build_model(spec.module, reduced)
task = _make_task()
prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)
# Run 1: fresh
mx.random.seed(42)
fresh = _collect_tokens(model, tokenizer, task, prompt, None)
assert len(fresh) > 0, "Fresh generation produced no tokens"
# Run 2: populate cache
kv = KVPrefixCache(None)
mx.random.seed(42)
populate = _collect_tokens(model, tokenizer, task, prompt, kv)
# Run 3: exact cache hit
mx.random.seed(42)
cached = _collect_tokens(model, tokenizer, task, prompt, kv)
assert fresh == populate, (
f"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}"
)
assert fresh == cached, (
f"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}"
)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)


@@ -343,8 +343,16 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
def contains(card: ModelCard, x: str):
return x in card.model_id.lower()
glm_model_cards = [
card for card in await get_model_cards() if "glm" in card.model_id.lower()
card
for card in await get_model_cards()
if contains(card, "glm")
and not contains(card, "-5")
and not contains(card, "4.7")
]
if not glm_model_cards:


@@ -0,0 +1,162 @@
from collections.abc import Generator
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.runner import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
_CHANNEL = 200005 # <|channel|>
_START = 200006 # <|start|>
_MESSAGE = 200008 # <|message|>
_CALL = 200012 # <|call|>
_END = 200007 # <|end|>
_ASSISTANT = 173781 # "assistant"
# fmt: off
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_A_TOKENS: list[tuple[int, str]] = [
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
FORMAT_B_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(12606, "comment"),
(815, "ary"),
(316, " to"),
(28, "="),
(44580, "functions"),
(775, ".get"),
(23981, "_current"),
(170154, "_weather"),
(5701, " json"),
(_MESSAGE, "<|message|>"),
(10848, '{"'),
(7693, "location"),
(1243, '":'),
(392, ' "'),
(173844, "Tokyo"),
(18583, '"}'),
(_CALL, "<|call|>"),
]
# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
# Full analysis-then-tool-call as the model actually generates it.
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
(_CHANNEL, "<|channel|>"),
(35644, "analysis"),
(_MESSAGE, "<|message|>"),
(12845, "Let"),
(668, " me"),
(2411, " think"),
(1078, " about"),
(495, " this"),
(13, "."),
(_END, "<|end|>"),
# Model generates a new message header for the tool call:
(_START, "<|start|>"),
(_ASSISTANT, "assistant"),
*FORMAT_B_TOKENS,
]
# fmt: on
def _make_gen_responses(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse]:
"""Build GenerationResponse list from (token_id, text) pairs."""
responses: list[GenerationResponse] = []
for i, (tid, text) in enumerate(tokens):
is_last = i == len(tokens) - 1
responses.append(
GenerationResponse(
text=text,
token=tid,
finish_reason="stop" if is_last else None,
usage=None,
)
)
return responses
def _collect(
tokens: list[tuple[int, str]],
) -> list[GenerationResponse | ToolCallResponse]:
"""Feed tokens through parse_gpt_oss and collect all yielded responses."""
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(parse_gpt_oss(_gen()))
def _get_tool_call(
results: list[GenerationResponse | ToolCallResponse],
) -> ToolCallResponse:
"""Extract the single ToolCallResponse from results."""
tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
return tool_calls[0]
class TestParseGptOssRecipientPlacement:
"""Both Harmony recipient placements must produce identical tool calls."""
def test_format_a_yields_tool_call(self):
results = _collect(FORMAT_A_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_format_b_yields_tool_call(self):
results = _collect(FORMAT_B_TOKENS)
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert '"location"' in tc.tool_calls[0].arguments
assert "Tokyo" in tc.tool_calls[0].arguments
def test_both_formats_produce_identical_tool_calls(self):
tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
class TestParseGptOssThinkingThenToolCall:
"""Analysis (thinking) followed by a tool call must yield both."""
def test_thinking_then_tool_call(self):
results = _collect(THINKING_THEN_TOOL_TOKENS)
# Should have thinking tags + content + tool call
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
combined = "".join(text_parts)
assert "<think>" in combined
assert "</think>" in combined
assert "Let me think about this." in combined
# And the tool call
tc = _get_tool_call(results)
assert tc.tool_calls[0].name == "get_current_weather"
assert "Tokyo" in tc.tool_calls[0].arguments
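These tests build synthetic Harmony token streams and assert that parse_gpt_oss yields both the wrapped thinking text and a ToolCallResponse. A minimal downstream-consumer sketch using the same imports as the test file; the match/case dispatch itself is illustrative and not taken from the runner.

```python
# Consumer sketch for the two response types yielded by parse_gpt_oss.
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_gpt_oss


def consume(responses):
    """Split a parse_gpt_oss stream into plain text and tool-call names."""
    text, tool_names = "", []
    for item in parse_gpt_oss(responses):
        match item:
            case ToolCallResponse(tool_calls=calls):
                tool_names.extend(call.name for call in calls)
            case GenerationResponse():
                text += item.text
    return text, tool_names
```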

tests/eval_tool_calls.sh (new executable file)

@@ -0,0 +1,55 @@
#!/usr/bin/env bash
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
sleep 1
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done
for host; do
echo "Waiting for $host..." 1>&2
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
eval_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
echo "running eval for $model" 1>&2
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
>>"./bench/$commit/${model//\//--}-eval.json"
echo
done

uv.lock (generated)

@@ -447,6 +447,7 @@ name = "exo-bench"
version = "0.1.0"
source = { editable = "bench" }
dependencies = [
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -456,6 +457,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "jinja2", specifier = ">=3.1.0" },
{ name = "loguru", specifier = ">=0.7.3" },