Compare commits


1 Commit

Author: Evan | SHA1: 45a9f32fc7 | Message: woahg | Date: 2026-02-17 11:44:23 +00:00
105 changed files with 1868 additions and 5300 deletions

View File

@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
def update(self, parameters: dict, strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

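The `update` docstring above describes merging a dict of dicts and lists into a module's parameter tree. A minimal usage sketch, assuming `mlx` is installed (the layer and values are illustrative, not taken from this diff):

```python
import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(2, 2)

# Replace the layer's parameters with an explicit dict of arrays;
# the keys mirror the module's parameter tree ("weight", "bias").
layer.update({"weight": mx.zeros((2, 2)), "bias": mx.ones((2,))})

print(layer.parameters()["weight"])  # all zeros after the update
```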
View File

@@ -7,10 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -47,11 +44,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable[..., Any],
fn: Callable,
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -83,9 +80,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -121,7 +118,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

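The stubs above cover `tree_map` and the `tree_flatten`/`tree_unflatten` pair. A short round-trip sketch using the dot-notation keys mentioned in the docstring, assuming `mlx` is installed:

```python
from mlx.utils import tree_flatten, tree_map, tree_unflatten

tree = {"layers": [{"w": 1.0}, {"w": 2.0}]}

# tree_map applies fn to every leaf and returns a tree of the same shape.
doubled = tree_map(lambda x: x * 2, tree)

# tree_flatten produces (dotted-key, value) pairs for trees of arbitrary depth...
flat = tree_flatten(doubled)  # [("layers.0.w", 2.0), ("layers.1.w", 4.0)]

# ...and tree_unflatten rebuilds the nested structure from them.
assert tree_unflatten(flat) == {"layers": [{"w": 2.0}, {"w": 4.0}]}
```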
View File

@@ -1,46 +0,0 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

Cargo.lock generated
View File

@@ -673,17 +673,6 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "delegate"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "der"
version = "0.7.10"
@@ -887,31 +876,16 @@ dependencies = [
name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"env_logger",
"extend",
"futures-lite",
"libp2p",
"log",
"networking",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"tokio",
"util",
]
[[package]]
name = "extend"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "311a6d2f1f9d60bff73d2c78a0af97ed27f79672f15c238192a5bbb64db56d00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
@@ -1747,12 +1721,6 @@ dependencies = [
"cpufeatures",
]
[[package]]
name = "keccak-const"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d8d8ce877200136358e0bbff3a77965875db3af755a11e1fa6b1b3e2df13ea"
[[package]]
name = "lalrpop-util"
version = "0.20.2"
@@ -2759,18 +2727,10 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"either",
"extend",
"futures-lite",
"futures-timer",
"keccak-const",
"libp2p",
"log",
"pin-project",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
@@ -4600,10 +4560,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "util"
version = "0.0.1"
[[package]]
name = "uuid"
version = "1.19.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
]
[workspace.package]
@@ -24,33 +23,18 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Macro dependencies
extend = "1.2"
delegate = "0.13"
# Utility dependencies
keccak-const = "0.2"
# Async dependencies
tokio = "1.46"
futures-lite = "2.6.1"
futures-timer = "3.0"
# Data structures
either = "1.15"
# Tracing/logging
log = "0.4"
# networking
libp2p = "0.56"
libp2p-tcp = "0.44"
[workspace.lints.rust]
static_mut_refs = "warn" # Or use "warn" instead of deny
incomplete_features = "allow"
static_mut_refs = "warn"
# Clippy's lint category level configurations;
# every member crate needs to inherit these by adding
@@ -71,64 +55,3 @@ perf = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
nursery = { level = "warn", priority = -1 }
cargo = { level = "warn", priority = -1 }
# Individual Clippy lints from the `restriction` category
arithmetic_side_effects = "warn"
as_conversions = "warn"
assertions_on_result_states = "warn"
clone_on_ref_ptr = "warn"
decimal_literal_representation = "warn"
default_union_representation = "warn"
deref_by_slicing = "warn"
disallowed_script_idents = "deny"
else_if_without_else = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exit = "deny"
expect_used = "warn"
float_cmp_const = "warn"
get_unwrap = "warn"
if_then_some_else_none = "warn"
impl_trait_in_params = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
let_underscore_must_use = "warn"
let_underscore_untyped = "warn"
lossy_float_literal = "warn"
mem_forget = "warn"
missing_inline_in_public_items = "warn"
multiple_inherent_impl = "warn"
multiple_unsafe_ops_per_block = "warn"
mutex_atomic = "warn"
non_zero_suggestions = "warn"
panic = "warn"
partial_pub_fields = "warn"
pattern_type_mismatch = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "warn"
redundant_type_annotations = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
same_name_method = "warn"
self_named_module_files = "deny"
semicolon_inside_block = "warn"
shadow_same = "warn"
shadow_unrelated = "warn"
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_to_string = "warn"
tests_outside_test_module = "warn"
todo = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unneeded_field_pattern = "warn"
unseparated_literal_suffix = "warn"
unused_result_ok = "warn"
unused_trait_names = "warn"
unwrap_used = "warn"
verbose_file_reads = "warn"

View File

@@ -1,5 +1,5 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Log namespace on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank = n-1. It should be rank != 0, then rank = 0, matching the auto_parallel implementation (see the sketch after this list). NOTE: we use a different convention to mlx-lm: our terminal rank is rank = n-1 whereas mlx-lm's is rank = 0, so I can see why this was changed wrongly.
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
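A tiny sketch of the ordering rule stated in the warmup item above; the world size and the actual warmup call are placeholders, not code from this commit:

```python
# Warm up every non-zero rank first, then rank 0 last, per the note above.
n = 4  # hypothetical world size
warmup_order = [rank for rank in range(n) if rank != 0] + [0]
print(warmup_order)  # [1, 2, 3, 0]
```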

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```
@@ -206,14 +199,14 @@ The app will ask for permission to modify system settings and install a new Netw
**Custom Namespace for Cluster Isolation:**
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `--namespace` cli arg:
- **Use cases**:
- Running multiple separate exo clusters on the same network
- Isolating development/testing clusters from production clusters
- Preventing accidental cluster joining
- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
- **Configuration**: Access this setting in the app's Advanced settings (or set the `--namespace` argument when running from source)
The namespace is logged on startup for debugging purposes.
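A minimal argparse sketch of the `--namespace` flag described above; the flag name comes from this diff, but the default value and the surrounding wiring in `exo/main.py` are assumptions for illustration only:

```python
# Illustrative only; not the actual argument handling in exo/main.py.
import argparse

parser = argparse.ArgumentParser(prog="exo")
parser.add_argument(
    "--namespace",
    default="default",  # assumed default, not taken from the diff
    help="Isolate this exo cluster from other clusters on the same network.",
)
args = parser.parse_args()
print(f"namespace: {args.namespace}")  # the namespace is logged on startup
```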
@@ -425,4 +418,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -82,6 +82,7 @@ final class ExoProcessController: ObservableObject {
let child = Process()
child.executableURL = executableURL
child.arguments = ["--namespace", computeNamespace()]
let exoHomeURL = Self.exoDirectoryURL
try? FileManager.default.createDirectory(
at: exoHomeURL, withIntermediateDirectories: true
@@ -126,37 +127,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {
@@ -242,7 +217,6 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
if !hfToken.isEmpty {
environment["HF_TOKEN"] = hfToken
}

View File

File diff suppressed because it is too large.

View File

@@ -1,47 +1,29 @@
# type: ignore
#!/usr/bin/env python3
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
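The docstring above lists the behaviours under test (triggering a tool call, returning valid JSON arguments, multi-turn use). A sketch of the kind of OpenAI-compatible request body one such scenario produces; the tool mirrors the `get_current_weather` definition in the scenarios TOML removed later in this diff, while the model id is a placeholder:

```python
# Hypothetical request body for the "weather_simple" scenario; field names follow
# the standard OpenAI chat-completions tools schema, not an exo-specific format.
request = {
    "model": "<model-id>",
    "messages": [
        {"role": "user", "content": "What's the weather like in Tokyo right now?"}
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "City and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ],
}
# A passing run is expected to return an assistant message whose
# tool_calls[0].function.name is "get_current_weather" and whose JSON
# arguments contain a "location" key.
```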
@@ -121,6 +103,154 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -139,6 +269,184 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -230,12 +538,76 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
add_common_instance_args(ap)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--pp",
nargs="+",
@@ -248,6 +620,34 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -257,6 +657,9 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -271,6 +674,17 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -305,10 +719,24 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -332,6 +760,16 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

View File

@@ -1,327 +0,0 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)

View File

@@ -4,7 +4,6 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

View File

@@ -1,240 +0,0 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"

View File

@@ -103,7 +103,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking_toggle") && caps.includes("text");
return caps.includes("thinking") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(

View File

@@ -59,14 +59,13 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1536",
"1536x1024",
"1024x1365",
"1365x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -177,90 +176,92 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{/if}
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -310,7 +311,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1536"
| "1536x1024";
| "1024x1365"
| "1365x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -337,7 +336,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "auto",
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,

View File

File diff suppressed because it is too large.

View File

@@ -74,6 +74,7 @@
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
@@ -114,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260218+14841977"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.7",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
@@ -133,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;
@@ -158,7 +132,6 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -26,11 +26,11 @@ networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
"abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
# "nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"experimental-async" # async support in #[pyfunction] & #[pymethods]
# "experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
# "py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
@@ -42,21 +42,18 @@ pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures-lite = { workspace = true }
# utility dependencies
util = { workspace = true }
futures-lite = "2.6.1"
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"
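For orientation, a minimal sketch of what the `experimental-async` feature enabled above gives us: an `async fn` marked `#[pyfunction]` surfaces to Python as an awaitable. The function and registration names below are illustrative, not taken from this diff.

use pyo3::prelude::*;

// Hypothetical coroutine; compiles only with the "experimental-async" pyo3
// feature enabled as in the dependency declaration above.
#[pyfunction]
async fn echo_later(msg: String) -> PyResult<String> {
    // No awaits are needed here; the async signature alone is enough for
    // Python to receive a coroutine object it can await.
    Ok(msg)
}

// Registration looks the same as for a synchronous function.
fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(echo_later, m)?)
}

From Python the result is awaited like any other coroutine.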

View File

@@ -2,119 +2,39 @@
# ruff: noqa: E501, F401
import builtins
import enum
import typing
@typing.final
class AllQueuesFullError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
Generate a new ed25519 keypair
"""
@staticmethod
def from_bytes(bytes: bytes) -> Keypair:
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Construct an Ed25519 keypair from secret key bytes
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
def to_bytes(self) -> bytes:
def to_protobuf_encoding(self) -> bytes:
r"""
Get the secret key bytes underlying the keypair
"""
def to_node_id(self) -> builtins.str:
r"""
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
Encode a private key to a protobuf structure.
"""
def to_string(self) -> builtins.str: ...
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate` is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message on the given topic to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
class PyPeer:
@staticmethod
def new(kp: Keypair, namespace: builtins.str) -> PyPeer: ...
async def subscribe(self, topic: builtins.str) -> None: ...
async def unsubscribe(self, topic: builtins.str) -> None: ...
async def send(self, topic: builtins.str, payload: bytes) -> None: ...
async def run(self) -> None: ...
async def recv(self) -> PySwarmEvent: ...
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...
class PySwarmEvent:
def downcast_discovered(self) -> typing.Optional[builtins.str]: ...
def downcast_expired(self) -> typing.Optional[builtins.str]: ...
def downcast_message(self) -> typing.Optional[tuple[builtins.str, builtins.str, bytes]]: ...

View File

@@ -1,37 +1,22 @@
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
use pyo3::prelude::*;
use std::{
future::Future,
pin::Pin,
pin::{Pin, pin},
task::{Context, Poll},
};
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
pub struct AllowThreads<F>(pub(crate) F);
impl<F> Future for AllowThreads<F>
where
F: Future + Send,
F: Future + Unpin + Send,
F::Output: Send,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
}
}
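For orientation, a minimal sketch of how the rewritten wrapper is meant to be used from async binding code elsewhere in this crate: stack-pin the future first so the new `Unpin` bound is satisfied, then await the wrapper. The helper below is illustrative; the receiver type and error message are assumptions.

use crate::allow_threading::AllowThreads;
use pyo3::exceptions::PyConnectionError;
use pyo3::prelude::*;
use std::pin::pin;
use tokio::sync::mpsc;

// Hypothetical helper: detach from the Python interpreter while the tokio
// receive future is pending, re-attaching only for each poll.
async fn recv_detached(rx: &mut mpsc::Receiver<Vec<u8>>) -> PyResult<Vec<u8>> {
    // pin! yields a Pin<&mut _>, which is Unpin, satisfying the new bound above.
    let fut = pin!(rx.recv());
    AllowThreads(fut)
        .await
        .ok_or_else(|| PyConnectionError::new_err("channel closed"))
}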

View File

@@ -1,47 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Construct an Ed25519 keypair from secret key bytes
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.expect("we only use ed25519 keys")
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}

View File

@@ -1,173 +1,42 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
pub(crate) mod allow_threading;
mod allow_threading;
mod ident;
mod networking;
pub(crate) mod networking;
pub(crate) mod take_once {
use std::sync::Mutex;
pub struct TakeOnce<T>(Mutex<Option<T>>);
impl<T> TakeOnce<T> {
pub fn new(t: T) -> Self {
Self(Mutex::new(Some(t)))
}
pub fn take(&self) -> Option<T> {
match self.0.try_lock() {
Ok(mut o) => o.take(),
Err(_) => None,
}
}
}
}
use pyo3::prelude::*;
use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::attach(|py| PyBytes::new(py, self).unbind())
}
}
#[ext(pub, name = ResultExt)]
impl<T, E> Result<T, E>
where
E: ToString,
{
fn pyerr(self) -> PyResult<T> {
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::attach(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
pub fn networking_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// setup runtime
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder.enable_all();
pyo3_async_runtimes::tokio::init(builder);
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
m.add_class::<PyKeypair>()?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
m.add_class::<networking::PyPeer>()?;
m.add_class::<networking::PyKeypair>()?;
Ok(())
}
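For orientation, one plausible way the now-public networking_module entry point gets wired into a #[pymodule] whose function name matches lib.name in Cargo.toml; sketch only, the actual wiring may live in a file not shown in this diff.

use pyo3::prelude::*;

// Hypothetical top-level module function; the module name reuses the
// #[pymodule(name = "exo_pyo3_bindings")] seen previously in this file.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // networking_module installs the logger, initializes the tokio runtime,
    // and registers PyPeer/PyKeypair on the module.
    networking_module(m)
}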

View File

@@ -1,291 +1,214 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
use crate::pyclass;
use futures_lite::StreamExt as _;
use libp2p::gossipsub::PublishError;
use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, Python, pymethods};
use pyo3_stub_gen::derive::{
gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pyclass_enum, gen_stub_pymethods,
use crate::allow_threading::AllowThreads;
use crate::take_once::TakeOnce;
use std::pin::pin;
use futures_lite::FutureExt;
use libp2p::{gossipsub::PublishError, identity::Keypair};
use networking::{FromSwarm, Peer, ToSwarm};
use pyo3::{
coroutine::CancelHandle,
exceptions::{PyConnectionError, PyRuntimeError, PyValueError},
prelude::*,
types::PyBytes,
};
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
pub struct PyNoPeersSubscribedToTopicError {}
impl PyNoPeersSubscribedToTopicError {
const MSG: &'static str = "\
No peers are currently subscribed to receive messages on this topic. \
Wait for peers to subscribe or check your network connectivity.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
pub struct PyAllQueuesFullError {}
impl PyAllQueuesFullError {
const MSG: &'static str =
"All libp2p peers are unresponsive, resend the message or reconnect.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
use pyo3_stub_gen::{
derive::{gen_methods_from_python, gen_stub_pyclass, gen_stub_pymethods},
inventory::submit,
};
use tokio::sync::{Mutex, mpsc};
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
#[pyclass(name = "Keypair", frozen)]
#[derive(Clone)]
pub struct PyKeypair(Keypair);
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: String,
#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
/// Generate a new ed25519 keypair
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: &Bound<'_, PyBytes>) -> Self {
let bytes = Vec::from(bytes.as_bytes());
Self(Keypair::from_protobuf_encoding(&bytes).expect("todo"))
}
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
/// Encode a private key to a protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
match self.0.to_protobuf_encoding() {
Ok(bytes) => Ok(PyBytes::new(py, &bytes)),
Err(e) => Err(PyValueError::new_err(e.to_string())),
}
}
fn to_string(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
struct PeerBuilder(
String,
Keypair,
mpsc::Sender<FromSwarm>,
mpsc::Receiver<ToSwarm>,
);
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
struct PyNetworkingHandle {
// channels
pub to_swarm: mpsc::Sender<ToSwarm>,
pub swarm: Mutex<Swarm>,
}
#[gen_stub_pyclass_complex_enum]
#[pyclass]
enum PyFromSwarm {
Connection {
peer_id: String,
connected: bool,
},
Message {
origin: String,
topic: String,
data: Py<PyBytes>,
},
}
impl From<FromSwarm> for PyFromSwarm {
fn from(value: FromSwarm) -> Self {
match value {
FromSwarm::Discovered { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: true,
},
FromSwarm::Expired { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: false,
},
FromSwarm::Message { from, topic, data } => Self::Message {
origin: from.to_base58(),
topic: topic,
data: data.pybytes(),
},
}
}
pub struct PyPeer {
peer: TakeOnce<PeerBuilder>,
to_swarm: mpsc::Sender<ToSwarm>,
from_swarm: Mutex<mpsc::Receiver<FromSwarm>>,
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
// create communication channels
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
let swarm = { create_swarm(identity, from_client).pyerr()? };
impl PyPeer {
#[staticmethod]
fn new(kp: PyKeypair, namespace: String) -> PyResult<Self> {
let (to_client, from_swarm) = mpsc::channel(1024);
let (to_swarm, from_client) = mpsc::channel(1024);
Ok(Self {
swarm: Mutex::new(swarm),
peer: TakeOnce::new(PeerBuilder(namespace, kp.0, to_client, from_client)),
to_swarm,
from_swarm: Mutex::new(from_swarm),
})
}
async fn recv(&self) -> PyResult<PyFromSwarm> {
self.swarm
.try_lock()
.expect("tried to recv from swarm twice concurrently")
.next()
.allow_threads_py()
.await
.ok_or(PyErr::receiver_channel_closed())
.map(Into::into)
#[gen_stub(skip)]
async fn run(&self, #[pyo3(cancel_handle)] mut cancel: CancelHandle) -> PyResult<()> {
let builder = self
.peer
.take()
.ok_or_else(|| PyRuntimeError::new_err("tried to run peer twice"))?;
let jh = pyo3_async_runtimes::tokio::get_runtime()
.spawn(async move {
let mut peer =
Peer::new(builder.0, builder.1, builder.2, builder.3).map_err(|_| {
PyConnectionError::new_err("peer failed to listen on default address")
})?;
peer.run()
.await
.map_err(|()| PyConnectionError::new_err("peer communication closed"))
})
.or(async {
cancel.cancelled().await;
Ok(Ok(()))
});
match AllowThreads(pin!(jh)).await {
Err(e) if e.is_cancelled() => Ok(()),
Err(e) if e.is_panic() => Err(PyRuntimeError::new_err(format!("tokio panic {e}"))),
Err(_) => unreachable!(),
Ok(res) => res,
}
}
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
async fn subscribe(&self, topic: String) -> PyResult<()> {
self.to_swarm
.send_py(ToSwarm::Subscribe {
topic,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.send(ToSwarm::Subscribe(topic))
.await
.map_err(|_| PyErr::receiver_channel_closed())?
.pyerr()
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
async fn unsubscribe(&self, topic: String) -> PyResult<()> {
self.to_swarm
.send(ToSwarm::Unsubscribe(topic))
.await
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
async fn send(&self, topic: String, payload: Py<PyBytes>) -> PyResult<()> {
// this function attaches to the python interpreter synchronously to avoid holding the GIL
let bytes = Python::attach(|py| Vec::from(payload.bind(py).as_bytes()));
self.to_swarm
.send(ToSwarm::Message(topic, bytes))
.await
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_swarm
.send_py(ToSwarm::Unsubscribe {
topic,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
#[gen_stub(skip)]
async fn recv(
&self,
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
) -> PyResult<PySwarmEvent> {
loop {
return match AllowThreads(pin!(
self.from_swarm
.try_lock()
.map_err(|_| PyRuntimeError::new_err("tried to recv twice"))?
.recv()
.or(async {
cancel.cancelled().await;
None
})
))
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message on the given topic to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_swarm
.send_py(ToSwarm::Publish {
topic,
data,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
.map_err(|e| match e {
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
PublishError::MessageTooLarge => PyNoPeersSubscribedToTopicError::new_err(),
e => PyRuntimeError::new_err(e.to_string()),
})?;
Ok(())
{
Some(FromSwarm::PublishError(p)) => match p {
PublishError::AllQueuesFull(_) => {
Err(PyConnectionError::new_err("swarm overloaded"))
}
PublishError::MessageTooLarge => {
Err(PyValueError::new_err("message too large"))
}
PublishError::NoPeersSubscribedToTopic => {
continue;
}
// TODO(evan): logs here
_ => continue,
},
None => Err(PyRuntimeError::new_err("swarm communication closed")),
Some(fs) => Ok(PySwarmEvent(fs)),
};
}
}
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
// Manually submit the run()/recv() stubs because the CancelHandle parameter is poorly understood by the stub generator
submit! {
gen_methods_from_python! {
r#"
class PyPeer:
async def run(self): ...
async def recv(self) -> PySwarmEvent: ...
"#
}
}
#[gen_stub_pyclass]
#[pyclass]
pub struct PySwarmEvent(FromSwarm);
#[gen_stub_pymethods]
#[pymethods]
impl PySwarmEvent {
// probably a better way to do this, but...
fn downcast_discovered(&self) -> Option<String> {
if let FromSwarm::Discovered(peer_id) = self.0 {
Some(peer_id.to_base58())
} else {
None
}
}
fn downcast_expired(&self) -> Option<String> {
if let FromSwarm::Expired(peer_id) = self.0 {
Some(peer_id.to_base58())
} else {
None
}
}
fn downcast_message<'py>(
&self,
py: Python<'py>,
) -> Option<(String, String, Bound<'py, PyBytes>)> {
if let FromSwarm::Message(peer_id, topic, data) = &self.0 {
Some((peer_id.to_base58(), topic.clone(), PyBytes::new(py, data)))
} else {
None
}
}
}
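For orientation, the downcast helpers above are mutually exclusive: exactly one of them returns Some for each FromSwarm variant that recv() lets through. A small sketch, callable only from inside this module since the PySwarmEvent field and methods are private.

use libp2p::PeerId;
use networking::FromSwarm;

// Hypothetical check of the downcast helpers; PeerId::random() just produces
// a throwaway peer id for illustration.
fn downcast_roundtrip() {
    let pid = PeerId::random();
    let expected = pid.to_base58();
    let ev = PySwarmEvent(FromSwarm::Discovered(pid));
    assert_eq!(ev.downcast_discovered(), Some(expected));
    assert!(ev.downcast_expired().is_none());
    // downcast_message additionally needs a Python token to build the PyBytes
    // payload, so it is exercised from an attached interpreter context instead.
}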

View File

@@ -1,54 +0,0 @@
#[cfg(test)]
mod tests {
use core::mem::drop;
use core::option::Option::Some;
use core::time::Duration;
use tokio;
use tokio::sync::mpsc;
#[tokio::test]
async fn test_drop_channel() {
struct Ping;
let (tx, mut rx) = mpsc::channel::<Ping>(10);
let _ = tokio::spawn(async move {
println!("TASK: entered");
loop {
tokio::select! {
result = rx.recv() => {
match result {
Some(_) => {
println!("TASK: pinged");
}
None => {
println!("TASK: closing channel");
break;
}
}
}
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
println!("TASK: heartbeat");
}
}
}
println!("TASK: exited");
});
let tx2 = tx.clone();
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx.send(Ping).await.expect("Should not fail");
drop(tx);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx2.send(Ping).await.expect("Should not fail");
drop(tx2);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
}
}

View File

@@ -13,26 +13,14 @@ path = "src/lib.rs"
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -1,8 +1,6 @@
use futures_lite::StreamExt;
use libp2p::identity;
use networking::swarm;
use networking::swarm::{FromSwarm, ToSwarm};
use tokio::sync::{mpsc, oneshot};
use networking::{self, FromSwarm, ToSwarm};
use tokio::sync::mpsc;
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -13,50 +11,52 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (to_swarm, from_client) = mpsc::channel(20);
// Configure swarm
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
.expect("Swarm creation failed");
let (to_client, mut from_swarm) = mpsc::channel(20);
let (to_swarm, from_client) = mpsc::channel(20);
let mut peer = networking::Peer::new(
"chatroom!".to_string(),
identity::Keypair::generate_ed25519(),
to_client,
from_client,
)
.expect("listen error");
// Create a Gossipsub topic & subscribe
let (tx, rx) = oneshot::channel();
_ = to_swarm
.send(ToSwarm::Subscribe {
topic: "test-net".to_string(),
result_sender: tx,
})
.await
.expect("should send");
let mut fused = futures_lite::future::fuse(rx);
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
let jh = tokio::spawn(async move { peer.run().await });
_ = to_swarm
.send(ToSwarm::Subscribe("chatting".to_string()))
.await;
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
let (tx, rx) = oneshot::channel();
if let Err(e) = to_swarm.send(swarm::ToSwarm::Publish { topic: "test-net".to_string(), data: line.as_bytes().to_vec(), result_sender: tx }).await {
println!("Send error: {e:?}");
return
};
if let Err(e) = rx.await {
println!("Publish error: {e:?}");
}
},
event = swarm.next() => match event {
_ = to_swarm.send(ToSwarm::Message("chatting".to_string(), line.into_bytes())).await;
}
event = from_swarm.recv() => match event {
// on gossipsub incoming
Some(FromSwarm::Discovered { peer_id }) => { println!("\n\nconnected to {peer_id}\n\n") },
Some(FromSwarm::Expired { peer_id }) => { println!("\n\ndisconnected from {peer_id}\n\n") },
Some(FromSwarm::Message { from, topic, data }) => { println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data)) },
None => {},
},
f = &mut fused => {assert!(f.expect("should recv").expect("should subscribe"))},
Some(FromSwarm::Message(peer_id,_, data)) => println!(
"\n\nGot message: '{}' from peer: {peer_id}\n\n",
String::from_utf8_lossy(&data),
),
// on discovery
Some(FromSwarm::Discovered(peer_id)) => {
println!("\n\nConnected to: {peer_id}\n\n");
}
Some(FromSwarm::Expired(peer_id)) => {
println!("\n\nDisconnected from: {peer_id}\n\n");
}
Some(FromSwarm::PublishError(e)) => eprintln!("\n\nError {e:?}\n\n"),
None => break,
}
}
}
_ = jh.await;
}

View File

@@ -1,44 +0,0 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS
----
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions
--------------------------------
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
---------------------------------
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
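For orientation, a minimal sketch of what "binding an NDRV socket" amounts to in practice, assuming the Darwin constants below (AF_NDRV = 27, IFNAMSIZ = 16) and the sockaddr layout from <net/ndrv.h>; illustrative and untested, written against the libc crate.

use std::io;

const AF_NDRV: libc::c_int = 27; // from <sys/socket.h> on Darwin (assumption)
const IFNAMSIZ: usize = 16;

// Mirrors struct sockaddr_ndrv from <net/ndrv.h>: length, family, interface name.
#[repr(C)]
struct SockaddrNdrv {
    snd_len: u8,
    snd_family: u8,
    snd_name: [u8; IFNAMSIZ],
}

// Open a PF_NDRV socket and bind it to an interface (e.g. a Thunderbolt
// bridge member); returns the raw fd on success.
fn open_ndrv(ifname: &str) -> io::Result<libc::c_int> {
    let mut sa = SockaddrNdrv {
        snd_len: std::mem::size_of::<SockaddrNdrv>() as u8,
        snd_family: AF_NDRV as u8,
        snd_name: [0; IFNAMSIZ],
    };
    let name = ifname.as_bytes();
    let n = name.len().min(IFNAMSIZ - 1);
    sa.snd_name[..n].copy_from_slice(&name[..n]);
    unsafe {
        let fd = libc::socket(AF_NDRV, libc::SOCK_RAW, 0);
        if fd < 0 {
            return Err(io::Error::last_os_error());
        }
        let rc = libc::bind(
            fd,
            &sa as *const SockaddrNdrv as *const libc::sockaddr,
            std::mem::size_of::<SockaddrNdrv>() as libc::socklen_t,
        );
        if rc < 0 {
            let err = io::Error::last_os_error();
            libc::close(fd);
            return Err(err);
        }
        Ok(fd)
    }
}

Binding like this typically requires root, and per the notes above the socket only delivers packets matching the protocol demux descriptors registered afterwards.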

View File

@@ -1,382 +0,0 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
use std::io;
use std::time::Duration;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
ping: ping::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
mdns: mdns_behaviour(keypair)?,
ping: ping_behaviour(),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {
ttl: MDNS_RECORD_TTL,
query_interval: MDNS_QUERY_INTERVAL,
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(
ping::Config::new()
.with_timeout(PING_TIMEOUT)
.with_interval(PING_INTERVAL),
)
}
}
/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
ConnectionEstablished {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
ConnectionClosed {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
/// to the swarm.
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
/// immediately, and expired but connected peers are disconnected from immediately.
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
/// connected peers are disconnected from.
pub struct Behaviour {
// state-tracking for managed behaviors & mDNS-discovered peers
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
retry_delay: Delay, // retry interval
// pending events to emit => waker-backed deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
// get peer's multi-addresses or insert if missing
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
continue;
};
// multiaddress should never already be present - else something has gone wrong
let is_new_addr = mas.insert(ma);
assert!(is_new_addr, "cannot discover a discovered peer");
}
}
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
// at this point, we *must* have the peer
let mas = self
.mdns_discovered
.get_mut(&p)
.expect("nonexistent peer cannot expire");
// at this point, we *must* have the multiaddress
let was_present = mas.remove(&ma);
assert!(was_present, "nonexistent multiaddress cannot expire");
// if empty, remove the peer-id entirely
if mas.is_empty() {
self.mdns_discovered.remove(&p);
}
}
}
fn on_connection_established(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out connected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
fn on_connection_closed(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out disconnected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
}
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
delegate! {
to self.managed {
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
}
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)?,
))
}
#[allow(clippy::needless_question_mark)]
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
port_use: PortUse,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
port_use,
)?,
))
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => {
self.managed
.on_connection_handler_event(peer_id, connection_id, ev)
}
}
}
// hook into these methods to drive behavior
fn on_swarm_event(&mut self, event: FromSwarm) {
self.managed.on_swarm_event(event); // let mDNS handle swarm events
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
}
// since we are running TCP/IP transport layer, we are assuming that
// no address changes can occur, hence encountering one is a fatal error
FromSwarm::AddressChange(a) => {
unreachable!("unhandlable: address change encountered: {:?}", a)
}
_ => {}
}
}
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// delegate to managed behaviors for any behaviors they need to perform
match self.managed.poll(cx) {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
},
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
}
}
}
// since we just consumed an event, we should immediately wake just in case
// there are more events to come where that came from
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(
e.map_out(|_| unreachable!("events returning to swarm already handled"))
.map_in(Either::Right),
);
}
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}
// send out any pending events from our own service
if let Some(e) = self.pending_events.pop_front(cx) {
return Poll::Ready(e.map_in(Either::Left));
}
// wait for pending events
Poll::Pending
}
}

View File

@@ -1,44 +1,299 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
pub mod discovery;
pub mod swarm;
use libp2p::{
Multiaddr, PeerId,
futures::StreamExt,
gossipsub::{self, TopicHash},
identify,
identity::Keypair,
mdns,
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
};
use std::collections::HashMap;
use tokio::sync::mpsc;
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
#[derive(Debug)]
pub struct ListenError;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
pub enum FromSwarm {
PublishError(gossipsub::PublishError),
Discovered(PeerId),
Expired(PeerId),
Message(PeerId, String, Vec<u8>),
}
pub enum ToSwarm {
Message(String, Vec<u8>),
Subscribe(String),
Unsubscribe(String),
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use extend::ext;
use libp2p::Multiaddr;
use libp2p::multiaddr::Protocol;
use std::net::IpAddr;
pub struct Peer {
pub swarm: libp2p::Swarm<Behaviour>,
to_client: mpsc::Sender<FromSwarm>,
from_client: mpsc::Receiver<ToSwarm>,
namespace: String,
known_peers: HashMap<PeerId, Vec<Multiaddr>>,
}
impl Peer {
pub fn new(
namespace: String,
kp: Keypair,
to_client: mpsc::Sender<FromSwarm>,
from_client: mpsc::Receiver<ToSwarm>,
) -> Result<Self, ListenError> {
let mut swarm = libp2p::SwarmBuilder::with_existing_identity(kp)
.with_tokio()
.with_quic()
// TODO(evan) .with_bandwidth_metrics()
.with_behaviour(|kp| Behaviour::new(namespace.clone(), kp))
.expect("invalid swarm behaviour")
.build();
#[ext(pub, name = MultiaddrExt)]
impl Multiaddr {
/// If the multiaddress corresponds to a TCP address, extracts it
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
let mut ps = self.into_iter();
let ip = if let Some(p) = ps.next() {
match p {
Protocol::Ip4(ip) => IpAddr::V4(ip),
Protocol::Ip6(ip) => IpAddr::V6(ip),
_ => return None,
swarm
.listen_on("/ip6/::/udp/0/quic-v1".parse().expect("invalid multiaddr"))
.map_err(|_| ListenError)?;
swarm
.listen_on(
"/ip4/0.0.0.0/udp/0/quic-v1"
.parse()
.expect("invalid multiaddr"),
)
.map_err(|_| ListenError)?;
Ok(Self {
swarm,
to_client,
from_client,
namespace,
known_peers: HashMap::default(),
})
}
pub async fn run(&mut self) -> Result<(), ()> {
loop {
tokio::select! {
event = self.swarm.next() => self.handle_event(event.ok_or(())?).await?,
msg = self.from_client.recv() => self.handle_message(msg.ok_or(())?).await?,
}
}
}
async fn handle_message(&mut self, message: ToSwarm) -> Result<(), ()> {
match message {
ToSwarm::Message(topic, data) => {
if let Err(e) = self
.swarm
.behaviour_mut()
.gossipsub
.publish(TopicHash::from_raw(topic), data)
{
self.to_client
.send(FromSwarm::PublishError(e))
.await
.map_err(|_| ())?;
}
} else {
return None;
};
let Some(Protocol::Tcp(port)) = ps.next() else {
return None;
};
Some((ip, port))
}
ToSwarm::Subscribe(topic) => {
match self
.swarm
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic))
{
Ok(_) => {}
Err(gossipsub::SubscriptionError::NotAllowed) => {
unreachable!("subscription filter hit")
}
Err(gossipsub::SubscriptionError::PublishError(e)) => self
.to_client
.send(FromSwarm::PublishError(e))
.await
.map_err(|_| ())?,
}
}
ToSwarm::Unsubscribe(topic) => {
self.swarm
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
}
}
Ok(())
}
async fn handle_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
let SwarmEvent::Behaviour(event) = event else {
return Ok(());
};
match event {
BehaviourEvent::Gossipsub(gossipsub::Event::Message { message, .. }) => {
if let Some(source) = message.source {
self.to_client
.send(FromSwarm::Message(
source,
message.topic.into_string(),
message.data,
))
.await
.map_err(|_| ())?;
}
}
BehaviourEvent::Identify(identify::Event::Received { peer_id, info, .. }) => {
log::debug!(
"identify from {peer_id}: protocol_version='{}' agent_version='{}' (local namespace='{}')",
info.protocol_version,
info.agent_version,
self.namespace
);
if info.protocol_version == self.namespace {
self.passed_namespace(peer_id);
self.to_client
.send(FromSwarm::Discovered(peer_id))
.await
.map_err(|_| ())?;
} else {
self.failed_namespace(peer_id);
}
}
BehaviourEvent::Mdns(mdns::Event::Discovered(v)) => {
for (peer_id, addr) in v {
self.known_peers.entry(peer_id).or_default().push(addr);
}
for (peer_id, addrs) in &self.known_peers {
// dialopts handles rate limiting, we should check errors if we want to blacklist earlier
let _ = self
.swarm
.dial(DialOpts::peer_id(*peer_id).addresses(addrs.clone()).build());
}
}
BehaviourEvent::Mdns(mdns::Event::Expired(v)) => {
for (peer_id, addr) in v {
let addrs = self.known_peers.entry(peer_id).or_default();
addrs.retain(|a| *a != addr);
if addrs.is_empty() {
self.known_peers.remove(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(&peer_id);
self.to_client
.send(FromSwarm::Expired(peer_id))
.await
.map_err(|_| ())?;
}
}
}
_ => {}
}
Ok(())
}
fn passed_namespace(&mut self, peer_id: PeerId) {
self.swarm
.behaviour_mut()
.gossipsub
.remove_blacklisted_peer(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.add_explicit_peer(&peer_id);
}
fn failed_namespace(&mut self, peer_id: PeerId) {
self.swarm
.behaviour_mut()
.gossipsub
.blacklist_peer(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(&peer_id);
}
}
#[derive(NetworkBehaviour)]
pub struct Behaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
identify: identify::Behaviour,
}
impl Behaviour {
fn new(namespace: String, kp: &Keypair) -> Self {
let mdns = mdns::Behaviour::new(mdns::Config::default(), kp.public().to_peer_id())
.expect("mdns behaviour failed to build");
let identify =
identify::Behaviour::new(identify::Config::new_with_signed_peer_record(namespace, kp));
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(kp.clone()),
gossipsub::ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.validation_mode(gossipsub::ValidationMode::Strict)
.build()
.expect("invalid gossipsub configuration"),
)
.expect("gossipsub behaviour failed to build");
Self {
gossipsub,
mdns,
identify,
}
}
}
// TODO: more tests
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{Duration, timeout};
fn make_peer(namespace: &str) -> (Peer, mpsc::Receiver<FromSwarm>, mpsc::Sender<ToSwarm>) {
let kp = Keypair::generate_ed25519();
let (to_client_tx, to_client_rx) = mpsc::channel(64);
let (to_peer_tx, to_peer_rx) = mpsc::channel(64);
let peer = Peer::new(namespace.to_string(), kp, to_client_tx, to_peer_rx)
.expect("Peer::new should succeed in tests");
(peer, to_client_rx, to_peer_tx)
}
async fn next_listen_addr(peer: &mut Peer) -> Multiaddr {
loop {
match peer.swarm.next().await {
Some(SwarmEvent::NewListenAddr { address, .. }) => return address,
Some(_) => {}
None => panic!("swarm stream ended unexpectedly"),
}
}
}
#[tokio::test]
async fn subscribe_and_unsubscribe_do_not_error() {
let (mut peer, mut events_rx, commands_tx) = make_peer("ns-test");
// Drive the swarm just enough to get at least one listen address event,
// so the background run loop has something initialized.
let _addr = next_listen_addr(&mut peer).await;
// Run the peer loop in the background.
let handle = tokio::spawn(async move {
let _ = peer.run().await;
});
commands_tx
.send(ToSwarm::Subscribe("topic-a".to_string()))
.await
.unwrap();
commands_tx
.send(ToSwarm::Unsubscribe("topic-a".to_string()))
.await
.unwrap();
// We don't *require* any FromSwarm events here; this is mainly a
// smoke test that the message-handling path doesn't panic/hang.
// Still, poll briefly to ensure the task is alive.
let _ = timeout(Duration::from_millis(200), events_rx.recv()).await;
// Shut down: dropping the command sender closes the channel, causing run() to return Err.
drop(commands_tx);
let _ = handle.await;
}
}

View File

@@ -1,273 +0,0 @@
use std::task::Poll;
use crate::swarm::transport::tcp_transport;
use crate::{alias, discovery};
pub use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::Stream;
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::{mpsc, oneshot};
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// This is all very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
pub enum ToSwarm {
Unsubscribe {
topic: String,
result_sender: oneshot::Sender<bool>,
},
Subscribe {
topic: String,
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
},
Publish {
topic: String,
data: Vec<u8>,
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
},
}
pub enum FromSwarm {
Message {
from: PeerId,
topic: String,
data: Vec<u8>,
},
Discovered {
peer_id: PeerId,
},
Expired {
peer_id: PeerId,
},
}
#[pin_project::pin_project]
pub struct Swarm {
#[pin]
inner: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
}
impl Swarm {
fn on_message(&mut self, message: ToSwarm) {
match message {
ToSwarm::Subscribe {
topic,
result_sender,
} => {
// try to subscribe
let result = self
.inner
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic));
// send response oneshot
_ = result_sender.send(result)
}
ToSwarm::Unsubscribe {
topic,
result_sender,
} => {
// try to unsubscribe from the topic
let result = self
.inner
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
_ = result_sender.send(result)
}
ToSwarm::Publish {
topic,
data,
result_sender,
} => {
// try to publish the data; the caller catches NoPeersSubscribedToTopic and converts it to the appropriate error
let result = self
.inner
.behaviour_mut()
.gossipsub
.publish(gossipsub::IdentTopic::new(topic), data);
// send response oneshot (or exit if connection closed)
_ = result_sender.send(result)
}
}
}
}
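// Illustrative sketch only (caller-side names are assumed): each command pairs a
// payload with a oneshot sender, so the caller awaits the outcome like this:
//
//     let (result_sender, result_receiver) = oneshot::channel();
//     to_swarm_tx
//         .send(ToSwarm::Subscribe { topic: "events".to_string(), result_sender })
//         .await?;
//     let newly_subscribed = result_receiver.await??; // Result<bool, SubscriptionError>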
impl Stream for Swarm {
type Item = FromSwarm;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.from_client.poll_recv(cx) {
Poll::Ready(Some(msg)) => self.on_message(msg),
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(swarm_event)) => match swarm_event {
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message:
gossipsub::Message {
source: Some(peer_id),
topic,
data,
..
},
..
})) => Poll::Ready(Some(FromSwarm::Message {
from: peer_id,
topic: topic.into_string(),
data,
})),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionEstablished { peer_id, .. },
)) => Poll::Ready(Some(FromSwarm::Discovered { peer_id })),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionClosed { peer_id, .. },
)) => Poll::Ready(Some(FromSwarm::Expired { peer_id })),
_ => Poll::Pending,
},
}
}
}
/// Create and configure a swarm which listens on all interfaces on an OS-assigned port
pub fn create_swarm(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(Swarm {
inner: swarm,
from_client,
})
}
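// Illustrative sketch only (names assumed): the returned `Swarm` is a `Stream` of
// `FromSwarm` events, so a caller keeps one half of the command channel and polls
// the other side:
//
//     use futures_lite::StreamExt;
//     let (to_swarm_tx, from_client) = mpsc::channel(64);
//     let mut swarm = create_swarm(identity::Keypair::generate_ed25519(), from_client)?;
//     while let Some(event) = swarm.next().await {
//         match event {
//             FromSwarm::Discovered { peer_id } => { /* track peer */ }
//             FromSwarm::Expired { peer_id } => { /* forget peer */ }
//             FromSwarm::Message { from, topic, data } => { /* dispatch */ }
//         }
//     }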
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;
use libp2p::pnet::{PnetError, PnetOutput};
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
use std::{env, sync::LazyLock};
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
/// See [`pnet_upgrade`] for more.
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
let builder = Sha3_256::new().update(b"exo_discovery_network");
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
let bytes = var.into_bytes();
builder.update(&bytes)
} else {
builder.update(NETWORK_VERSION)
}
.finalize()
});
/// Make the Swarm run on a private network, so as not to clash with public libp2p nodes and
/// also with different-versioned instances of this same network.
/// This is implemented as an additional "upgrade" on top of existing [`libp2p::Transport`] layers.
async fn pnet_upgrade<TSocket>(
socket: TSocket,
_: impl Sized,
) -> Result<PnetOutput<TSocket>, PnetError>
where
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
use pnet::{PnetConfig, PreSharedKey};
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
.handshake(socket)
.await
}
/// TCP/IP transport layer configuration.
pub fn tcp_transport(
keypair: &identity::Keypair,
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
use libp2p::{
core::upgrade::Version,
tcp::{Config, tokio},
};
// `TCP_NODELAY` enabled => avoid latency
let tcp_config = Config::default().nodelay(true);
// V1 + lazy flushing => 0-RTT negotiation
let upgrade_version = Version::V1Lazy;
// Noise is faster than TLS + we don't care much for security
let noise_config = noise::Config::new(keypair)?;
// Use default Yamux config for multiplexing
let yamux_config = yamux::Config::default();
// Create new Tokio-driven TCP/IP transport layer
let base_transport = tokio::Transport::new(tcp_config)
.and_then(pnet_upgrade)
.upgrade(upgrade_version)
.authenticate(noise_config)
.multiplex(yamux_config);
// Return boxed transport (to flatten complex type)
Ok(base_transport.boxed())
}
}
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
#[derive(NetworkBehaviour)]
pub struct Behaviour {
pub discovery: discovery::Behaviour,
pub gossipsub: gossipsub::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
gossipsub: gossipsub_behaviour(keypair),
})
}
}
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
// build a gossipsub network behaviour
// => signed message authenticity + strict validation mode mean the message-ID is
// automatically provided by gossipsub without needing to provide a custom message-ID function
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),
)
.expect("creating gossipsub behavior should always work")
}
}

View File

@@ -1,7 +0,0 @@
// maybe this will hold tests in the future...??
#[cfg(test)]
mod tests {
#[test]
fn does_nothing() {}
}

View File

@@ -1,10 +1,11 @@
{ inputs, ... }:
{
perSystem =
{ inputs', pkgs, lib, ... }:
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
rustToolchain = inputs'.fenix.packages.stable.withComponents [
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"

View File

@@ -1,15 +0,0 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "util"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]

View File

@@ -1 +0,0 @@
pub mod wakerdeque;

View File

@@ -1,55 +0,0 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
waker: Option<Waker>,
deque: VecDeque<T>,
}
impl<T: Debug> Debug for WakerDeque<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.deque.fmt(f)
}
}
impl<T> WakerDeque<T> {
pub fn new() -> Self {
Self {
waker: None,
deque: VecDeque::new(),
}
}
fn update(&mut self, cx: &mut Context<'_>) {
self.waker = Some(cx.waker().clone());
}
fn wake(&mut self) {
let Some(ref mut w) = self.waker else { return };
w.wake_by_ref();
self.waker = None;
}
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_front()
}
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_back()
}
pub fn push_front(&mut self, value: T) {
self.wake();
self.deque.push_front(value);
}
pub fn push_back(&mut self, value: T) {
self.wake();
self.deque.push_back(value);
}
}
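// Illustrative sketch (types and names assumed): the intended pattern is to pop
// inside a poll function, which registers the current waker, while another code
// path pushes and thereby wakes the poller:
//
//     struct Inbox { queue: WakerDeque<Vec<u8>> }
//     impl Inbox {
//         fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Vec<u8>> {
//             match self.queue.pop_front(cx) {      // stores cx.waker() for later
//                 Some(frame) => Poll::Ready(frame),
//                 None => Poll::Pending,            // woken by the next push_back
//             }
//         }
//     }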

View File

@@ -14,7 +14,6 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -47,7 +46,6 @@ class DownloadCoordinator:
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -63,13 +61,8 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -81,7 +74,6 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -101,7 +93,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -110,17 +101,13 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
@@ -183,11 +170,7 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -201,7 +184,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -209,20 +191,6 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)
@@ -238,7 +206,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -252,7 +219,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -287,7 +253,6 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -330,18 +295,11 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
@@ -350,9 +308,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -448,13 +448,12 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress, skip_internet
model_id, revision, path, target_dir, on_progress
)
except HuggingFaceAuthenticationError:
raise
@@ -488,14 +487,10 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -515,11 +510,6 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -824,7 +814,6 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

View File

@@ -1,230 +0,0 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)

View File

@@ -1,4 +1,5 @@
import argparse
import importlib.metadata
import itertools
import multiprocessing as mp
import os
@@ -39,15 +40,14 @@ class Node:
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
node_id = NodeId(keypair.to_string())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
router = Router.create(keypair, namespace=args.namespace)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)
@@ -69,12 +69,11 @@ class Node:
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
if args.spawn_api:
if not args.no_api:
api = API(
node_id,
session_id,
@@ -134,13 +133,10 @@ class Node:
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -152,6 +148,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -225,7 +223,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -262,10 +259,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
logger.info(f"Namespace: {args.namespace}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
@@ -282,14 +276,13 @@ def main():
class Args(CamelCaseModel):
verbosity: int = 0
force_master: bool = False
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
verbosity: int
force_master: bool
no_api: bool
api_port: PositiveInt
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
namespace: str
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -319,14 +312,15 @@ class Args(CamelCaseModel):
)
parser.add_argument(
"--no-api",
action="store_false",
dest="spawn_api",
action="store_true",
help="Disable the API server for this node",
)
parser.add_argument(
"--api-port",
type=int,
dest="api_port",
default=52415,
help="Which port the API server will be available on",
)
parser.add_argument(
"--no-worker",
@@ -338,9 +332,9 @@ class Args(CamelCaseModel):
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
"--namespace",
default=importlib.metadata.version("exo"),
help="Set the EXO namespace to run multiple isolated clusters",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(

View File

@@ -85,7 +85,6 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
ImageSize,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -101,7 +100,6 @@ from exo.shared.types.api import (
TraceRankStats,
TraceResponse,
TraceStatsResponse,
normalize_image_size,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -753,11 +751,9 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"advanced_params": _ensure_seed(payload.advanced_params),
}
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
@@ -1013,13 +1009,12 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"stream": False,
"partial_images": 0,
"advanced_params": _ensure_seed(payload.advanced_params),
}
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
@@ -1040,7 +1035,7 @@ class API:
prompt: str,
model: ModelId,
n: int,
size: ImageSize,
size: str,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
@@ -1110,7 +1105,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str | None = Form(None),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
@@ -1136,7 +1131,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=normalize_image_size(size),
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
@@ -1172,7 +1167,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str | None = Form(None),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
@@ -1192,7 +1187,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=normalize_image_size(size),
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,

View File

@@ -396,7 +396,7 @@ class Master:
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
logger.trace(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
node_id = NodeId(keypair.to_string())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
sender_node_id = NodeId(f"{keypair.to_string()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -1,37 +1,9 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)
expired: bool

View File

@@ -1,5 +1,5 @@
from copy import copy
from itertools import count
from dataclasses import dataclass, field
from math import inf
from os import PathLike
from pathlib import Path
@@ -14,15 +14,14 @@ from anyio import (
)
from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyPeer,
)
from filelock import FileLock
from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
@@ -99,28 +98,32 @@ class TopicRouter[T: CamelCaseModel]:
)
@dataclass
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity))
_peer: PyPeer
topic_routers: dict[str, TopicRouter[CamelCaseModel]] = field(
init=False, default_factory=dict
)
networking_receiver: Receiver[tuple[str, bytes]] = field(init=False)
_tmp_networking_sender: Sender[tuple[str, bytes]] | None = field(init=False)
_tg: TaskGroup | None = None
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
def __post_init__(self):
self._tmp_networking_sender, self.networking_receiver = channel()
@classmethod
def create(cls, identity: Keypair, namespace: str) -> "Router":
return cls(_peer=PyPeer.new(identity, namespace))
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender
if send:
self._tmp_networking_sender = None
else:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
if self._tg is not None:
self._tg.start_soon(router.run)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
@@ -148,14 +151,18 @@ class Router:
async def run(self):
logger.debug("Starting Router")
try:
async def _peer_run():
await self._peer.run()
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
tg.start_soon(_peer_run)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
@@ -170,47 +177,58 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
await self._net.gossipsub_subscribe(topic)
await self._peer.subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
await self._net.gossipsub_unsubscribe(topic)
await self._peer.unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
try:
swarm_event = await self._peer.recv()
except ValueError:
logger.error("Message too large for gossipsub, dropped")
continue
except ConnectionError:
logger.error("All peer queues full, network overloaded")
continue
except RuntimeError:
break
cm = None
if (peer_id := swarm_event.downcast_discovered()) is not None:
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=False)
if (peer_id := swarm_event.downcast_expired()) is not None:
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=True)
if cm is not None:
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(cm)
continue
assert (msg := swarm_event.downcast_message()) is not None
_origin, topic, payload = msg
logger.debug(f"Received message on {topic} with payload {payload}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
await router.publish_bytes(payload)
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
except NoPeersSubscribedToTopicError:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
await self._peer.send(topic, data)
except RuntimeError:
break
def get_node_id_keypair(
@@ -235,12 +253,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_bytes(protobuf_encoded)
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_bytes())
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -218,6 +218,11 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -258,6 +263,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -44,8 +44,7 @@ async def _refresh_card_cache():
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
if card.model_id not in _card_cache:
_card_cache[card.model_id] = card
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
@@ -183,7 +182,6 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],

View File

@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
@@ -330,9 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
expired=False,
)
)

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().to_bytes())
queue.put(get_node_id_keypair().to_protobuf_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -1,9 +1,9 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal, get_args
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -262,27 +262,6 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
ImageSize = Literal[
"auto",
"512x512",
"768x768",
"1024x768",
"768x1024",
"1024x1024",
"1024x1536",
"1536x1024",
]
def normalize_image_size(v: object) -> ImageSize:
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
if v is None:
return "auto"
if v not in get_args(ImageSize):
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
return v # pyright: ignore[reportReturnType]
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
@@ -302,7 +281,7 @@ class ImageGenerationTaskParams(BaseModel):
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: ImageSize = "auto"
size: str | None = "1024x1024"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
@@ -310,11 +289,6 @@ class ImageGenerationTaskParams(BaseModel):
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
@@ -331,18 +305,13 @@ class ImageEditsTaskParams(BaseModel):
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: ImageSize = "auto"
size: str | None = "1024x1024"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
if name == "image_data":

View File

@@ -4,13 +4,10 @@ from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]

View File

@@ -26,7 +26,6 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -1,4 +1,3 @@
import contextlib
import multiprocessing as mp
from dataclasses import dataclass, field
from math import inf
@@ -133,8 +132,7 @@ class MpSender[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.put(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==
@@ -206,8 +204,6 @@ class MpReceiver[T]:
def close(self) -> None:
if not self._state.closed.is_set():
self._state.closed.set()
with contextlib.suppress(Exception):
self._state.buffer.put_nowait(_MpEndOfStream())
self._state.buffer.close()
# == unique to Mp channels ==

View File

@@ -14,7 +14,6 @@ from exo.shared.types.api import (
ImageEditsTaskParams,
ImageGenerationStats,
ImageGenerationTaskParams,
ImageSize,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
@@ -24,9 +23,9 @@ from exo.shared.types.worker.runner_response import (
from exo.worker.engines.image.distributed_model import DistributedImageModel
def parse_size(size_str: ImageSize) -> tuple[int, int]:
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if size_str == "auto":
if not size_str:
return (1024, 1024)
try:
@@ -110,9 +109,6 @@ def generate_image(
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
if task.size == "auto":
with Image.open(image_path) as img:
width, height = img.size
for image_num in range(num_images):
# Increment seed for each image to ensure unique results

View File

@@ -163,14 +163,11 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
if self.is_prefill:
mx.eval(output)
if cache is not None:
mx.eval(_cache.keys) # type: ignore
mx.eval(cache.keys) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[
@@ -310,9 +307,7 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
last = cache[-1] # type: ignore
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
return logits
@@ -338,9 +333,7 @@ def patch_tensor_model[T](model: T) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
last = cache[-1] # pyright: ignore[reportAny]
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
@@ -554,12 +547,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -590,18 +581,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -794,7 +779,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
@@ -907,7 +893,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
@@ -18,22 +17,10 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
@@ -77,7 +64,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None):
def __init__(self, group: mx.distributed.Group | None = None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
@@ -169,15 +156,15 @@ class KVPrefixCache:
best_length = 0
is_exact = False
# Find best cache match
# Find best cache
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(prompt_tokens, cached_prompt)
if length >= max_length - 1:
best_index, best_length = i, length
is_exact = True
break
if length > best_length:
best_index, best_length = i, length
if length == max_length:
is_exact = True
best_index, best_length = i, length
break
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
@@ -185,12 +172,11 @@ class KVPrefixCache:
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
has_ssm = has_non_kv_caches(self.caches[best_index])
target = (max_length - 1) if is_exact and not has_ssm else best_length
target = (max_length - 1) if is_exact else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
# No usable snapshot — need fresh cache
if restore_snap is None and has_ssm:
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
return make_kv_cache(model), prompt_tokens, None
prompt_cache = deepcopy(self.caches[best_index])
@@ -271,21 +257,10 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(prompt_tokens)
def _entry_length(
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
if hasattr(c, "offset"):
return c.offset
# For CacheList
if hasattr(c, "size"):
return int(c.size()) # type: ignore
return 0
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
return max(_entry_length(c) for c in cache)
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
return max(getattr(c, "offset", 0) for c in cache)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:

View File

@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
_MIN_PREFIX_HIT_TO_UPDATE = 1000
def prefill(
@@ -57,7 +57,6 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -87,9 +86,6 @@ def prefill(
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
@@ -133,7 +129,7 @@ def prefill(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None,
group: mx.distributed.Group | None = None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -255,8 +251,8 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -309,9 +305,16 @@ def mlx_generate(
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
@@ -328,7 +331,6 @@ def mlx_generate(
think_start = tokenizer.think_start
think_end = tokenizer.think_end
logger.info("Starting decode")
mx_barrier(group)
for completion_tokens, out in enumerate(
@@ -436,14 +438,9 @@ def mlx_generate(
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if (
matched_index is not None
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index,

View File

@@ -285,15 +285,11 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
# For GLM-5 and GLM-4.7
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None

View File

@@ -348,7 +348,7 @@ class Worker:
session=self.session_id,
event=event,
)
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
logger.trace(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe

View File

@@ -11,7 +11,6 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -589,11 +588,7 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel

View File

@@ -103,7 +103,7 @@ class RunnerSupervisor:
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
@@ -191,7 +191,7 @@ class RunnerSupervisor:
logger.info("Checking runner's status")
if self.runner_process.is_alive():
logger.info("Runner was found to be alive, attempting to join process")
await to_thread.run_sync(self.runner_process.join, 5)
await to_thread.run_sync(self.runner_process.join, 1)
rc = self.runner_process.exitcode
logger.info(f"RunnerSupervisor exited with exit code {rc}")
if rc == 0:
View File
@@ -123,12 +123,7 @@ def run_gpt_oss_pipeline_device(
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:
@@ -199,8 +194,6 @@ def run_gpt_oss_tensor_parallel_device(
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
):
generated_text += response.text
if response.finish_reason is not None:
View File
@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache(None)
cache = KVPrefixCache()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache(None)
cache = KVPrefixCache()
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache(None)
cache = KVPrefixCache()
cache.clear()
assert len(cache.prompts) == 0
@@ -142,9 +142,7 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
@@ -163,11 +161,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
@@ -180,11 +176,9 @@ class TestKVPrefixCacheWithModel:
)
assert matched_index == 0
# Exact match returns last token(s) — for models with SSM/rotating caches,
# snapshot availability constrains how far back we can trim, so remaining
# may be 1 or 2 tokens depending on the model.
assert len(remaining_tokens) >= 1
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
@@ -200,10 +194,10 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
model, tokenizer, make_sampler(0.0), short_tokens, cache
)
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
@@ -244,11 +238,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -284,11 +276,9 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -311,7 +301,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -328,7 +318,6 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
generated_tokens += 1
@@ -342,7 +331,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -358,7 +347,6 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -380,7 +368,7 @@ class TestKVPrefixCacheWithModel:
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -407,7 +395,6 @@ class TestKVPrefixCacheWithModel:
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
first_gen_time = time.perf_counter() - t0
@@ -440,7 +427,6 @@ class TestKVPrefixCacheWithModel:
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
second_gen_time = time.perf_counter() - t0
@@ -461,7 +447,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -476,7 +462,6 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -489,7 +474,6 @@ class TestKVPrefixCacheWithModel:
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
pass
@@ -500,7 +484,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(None)
kv_prefix_cache = KVPrefixCache()
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -513,7 +497,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -538,7 +522,7 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
View File
@@ -1,297 +0,0 @@
import copy
import gc
import importlib
import json
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
import pytest
from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
load_tokenizer_for_model_id,
)
HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
# ── Config reduction ──────────────────────────────────────────────────────── #
_REDUCE = {
"num_hidden_layers": 4,
"hidden_size": 256,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 512,
"moe_intermediate_size": 128,
"num_experts": 4,
"num_experts_per_tok": 2,
"n_routed_experts": 4,
"num_local_experts": 4,
"num_nextn_predict_layers": 0,
"first_k_dense_replace": 0,
"linear_num_key_heads": 2,
"linear_num_value_heads": 2,
"num_attention_groups": 4,
}
def _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:
result = dict(cfg)
for key, val in _REDUCE.items():
if key in result:
result[key] = val
return result
def _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:
result = _reduce_dict(cfg)
n_layers = cast(int, result.get("num_hidden_layers", 4))
if "text_config" in result and isinstance(result["text_config"], dict):
result["text_config"] = _reduce_dict(
cast(dict[str, Any], result["text_config"])
)
tc: dict[str, Any] = result["text_config"]
if "num_nextn_predict_layers" in tc:
tc["num_nextn_predict_layers"] = 0
if "layer_types" in result and isinstance(result["layer_types"], list):
result["layer_types"] = result["layer_types"][:n_layers]
if "attention_other_setting" in result and isinstance(
result["attention_other_setting"], dict
):
aos: dict[str, Any] = dict(
cast(dict[str, Any], result["attention_other_setting"])
)
if "num_attention_heads" in aos:
aos["num_attention_heads"] = result.get("num_attention_heads", 4)
if "num_attention_groups" in aos:
aos["num_attention_groups"] = result.get(
"num_attention_groups", cast(int, aos["num_attention_groups"])
)
result["attention_other_setting"] = aos
if "moe_layers_enum" in result and isinstance(result["moe_layers_enum"], str):
indices = [int(x) for x in result["moe_layers_enum"].split(",") if x.strip()]
valid = [i for i in indices if i < n_layers]
result["moe_layers_enum"] = ",".join(str(i) for i in valid) if valid else ""
return result
# ── Helpers ───────────────────────────────────────────────────────────────── #
def _find_snapshot(hub_name: str) -> Path | None:
model_dir = HF_CACHE / f"models--mlx-community--{hub_name}"
snaps = model_dir / "snapshots"
if not snaps.exists():
return None
children = sorted(snaps.iterdir())
return children[0] if children else None
def _copy_tokenizer(src: Path, dst: Path) -> None:
for f in src.iterdir():
name = f.name
if (
"tokeniz" in name.lower()
or "tiktoken" in name.lower()
or name.startswith("vocab")
or name.endswith(".jinja")
or "tool_declaration" in name
) and f.is_file():
shutil.copy2(f, dst / name)
def _build_model(module_name: str, cfg: dict[str, Any]) -> Model:
mod = importlib.import_module(f"mlx_lm.models.{module_name}")
args = mod.ModelArgs.from_dict(cfg) # pyright: ignore[reportAny]
model: nn.Module = mod.Model(args) # pyright: ignore[reportAny]
flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))
random_weights = [
(k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat
]
model.update(cast(dict[str, Any], tree_unflatten(random_weights)))
mx.eval(model.parameters())
return cast(Model, model)
def _collect_tokens(
model: Model,
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None,
) -> list[int]:
tokens: list[int] = []
for resp in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=None,
):
tokens.append(resp.token)
if resp.finish_reason is not None:
break
return tokens
# ── Architecture definitions ──────────────────────────────────────────────── #
@dataclass(frozen=True)
class ArchSpec:
name: str
hub_name: str
module: str
tokenizer_hub: str | None = None # fallback for models without bundled tokenizer
ARCHITECTURES: list[ArchSpec] = [
ArchSpec("llama", "Llama-3.2-1B-Instruct-4bit", "llama"),
ArchSpec("glm_moe_dsa", "GLM-5-MXFP4-Q8", "glm_moe_dsa"),
ArchSpec(
"glm4_moe", "GLM-4.5-Air-8bit", "glm4_moe", tokenizer_hub="GLM-4.7-8bit-gs32"
),
ArchSpec(
"glm4_moe_lite",
"GLM-4.7-Flash-8bit",
"glm4_moe_lite",
tokenizer_hub="GLM-4.7-8bit-gs32",
),
ArchSpec("glm4_moe_47", "GLM-4.7-8bit-gs32", "glm4_moe"),
ArchSpec("qwen3", "Qwen3-4B-Instruct-2507-4bit", "qwen3"),
ArchSpec("qwen3_moe", "Qwen3-30B-A3B-4bit", "qwen3_moe"),
ArchSpec("qwen3_next", "Qwen3-Next-80B-A3B-Thinking-4bit", "qwen3_next"),
ArchSpec("minimax", "MiniMax-M2.1-3bit", "minimax"),
ArchSpec("gpt_oss", "gpt-oss-20b-MXFP4-Q8", "gpt_oss"),
ArchSpec("step3p5", "Step-3.5-Flash-4bit", "step3p5"),
ArchSpec("kimi_k25", "Kimi-K2.5", "kimi_k25"),
]
def _arch_available(spec: ArchSpec) -> bool:
snap = _find_snapshot(spec.hub_name)
if snap is None:
return False
if spec.tokenizer_hub is not None:
return _find_snapshot(spec.tokenizer_hub) is not None
return True
def _make_task() -> TextGenerationTaskParams:
return TextGenerationTaskParams(
model=ModelId("test"),
input=[
InputMessage(
role="user",
content="Use the calculator to compute 1847 * 263 + 5921",
)
],
max_output_tokens=20,
temperature=0.0,
tools=[
{
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a mathematical expression",
"parameters": {
"type": "object",
"properties": {"expression": {"type": "string"}},
"required": ["expression"],
},
},
}
],
)
# ── Test class ────────────────────────────────────────────────────────────── #
@pytest.mark.slow
class TestPrefixCacheArchitectures:
"""Verify prefix cache produces identical output to fresh generation for every architecture."""
@pytest.fixture(autouse=True)
def _cleanup(self):
yield
mx.clear_cache()
gc.collect()
@pytest.mark.parametrize(
"spec",
ARCHITECTURES,
ids=[a.name for a in ARCHITECTURES],
)
def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:
if not _arch_available(spec):
pytest.skip(f"Model {spec.hub_name} not cached locally")
snapshot = _find_snapshot(spec.hub_name)
assert snapshot is not None
tmpdir = Path(tempfile.mkdtemp(prefix=f"exo_test_{spec.name}_"))
try:
# Build reduced config
with open(snapshot / "config.json") as f:
cfg = cast(dict[str, Any], json.load(f))
reduced = _reduce_config(copy.deepcopy(cfg))
(tmpdir / "config.json").write_text(json.dumps(reduced))
# Copy tokenizer
tok_src = snapshot
if spec.tokenizer_hub is not None:
alt = _find_snapshot(spec.tokenizer_hub)
if alt is not None:
tok_src = alt
_copy_tokenizer(tok_src, tmpdir)
# Load tokenizer and model
model_id = ModelId(f"mlx-community/{spec.hub_name}")
tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)
mx.random.seed(0)
model = _build_model(spec.module, reduced)
task = _make_task()
prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)
# Run 1: fresh
mx.random.seed(42)
fresh = _collect_tokens(model, tokenizer, task, prompt, None)
assert len(fresh) > 0, "Fresh generation produced no tokens"
# Run 2: populate cache
kv = KVPrefixCache(None)
mx.random.seed(42)
populate = _collect_tokens(model, tokenizer, task, prompt, kv)
# Run 3: exact cache hit
mx.random.seed(42)
cached = _collect_tokens(model, tokenizer, task, prompt, kv)
assert fresh == populate, (
f"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}"
)
assert fresh == cached, (
f"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}"
)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
View File
@@ -343,16 +343,8 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
def contains(card: ModelCard, x: str):
return x in card.model_id.lower()
glm_model_cards = [
card
for card in await get_model_cards()
if contains(card, "glm")
and not contains(card, "-5")
and not contains(card, "4.7")
card for card in await get_model_cards() if "glm" in card.model_id.lower()
]
if not glm_model_cards:
Some files were not shown because too many files have changed in this diff