Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 07:17:30 -05:00)

Compare commits: remove-nig...iroh, 1 commit (30021ea887)
@@ -200,7 +200,7 @@ class Module(dict):
     ) -> mx.MX_ARRAY_TREE:  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
         """Return the submodules that do not contain other modules."""

-    def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
+    def update(self, parameters: dict, strict: bool = ...) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
@@ -7,10 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from mlx.core import MX_ARRAY_TREE

 def tree_map(
-    fn: Callable[..., Any],
-    tree: Any,
-    *rest: Any,
-    is_leaf: Callable[..., bool] | None = ...,
+    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
 ) -> Any:
     """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
@@ -47,11 +44,11 @@ def tree_map(
     """

 def tree_map_with_path(
-    fn: Callable[..., Any],
+    fn: Callable,
     tree: Any,
     *rest: Any,
-    is_leaf: Callable[..., bool] | None = ...,
-    path: str | None = ...,
+    is_leaf: Optional[Callable] = ...,
+    path: Optional[Any] = ...,
 ) -> Any:
     """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
     returns a new collection with the results.
@@ -83,9 +80,9 @@ def tree_map_with_path(
 def tree_flatten(
     tree: Any,
     prefix: str = ...,
-    is_leaf: Callable[..., bool] | None = ...,
-    destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
-) -> list[tuple[str, Any]] | dict[str, Any]:
+    is_leaf: Optional[Callable] = ...,
+    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
+) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
     """Flattens a Python tree to a list of key, value tuples.

     The keys are using the dot notation to define trees of arbitrary depth and
@@ -121,7 +118,7 @@ def tree_flatten(
     the Python tree.
     """

-def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
+def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
     """Recreate a Python tree from its flat representation.

     .. code-block:: python
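The flatten/unflatten round trip these stubs describe looks like this in practice (a minimal sketch; assumes the functions are importable from mlx.utils as in current MLX releases):

# Sketch of the behavior documented above: keys use dot notation for depth.
from mlx.utils import tree_flatten, tree_unflatten

tree = {"layers": [{"w": 1}, {"w": 2}], "bias": 3}

flat = tree_flatten(tree)
print(flat)  # [('layers.0.w', 1), ('layers.1.w', 2), ('bias', 3)]

# tree_unflatten recreates the nested structure from the flat form.
assert tree_unflatten(flat) == tree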
Cargo.lock (generated): 32 lines changed
@@ -673,17 +673,6 @@ dependencies = [
  "syn 2.0.111",
 ]

-[[package]]
-name = "delegate"
-version = "0.13.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.111",
-]
-
 [[package]]
 name = "der"
 version = "0.7.10"
@@ -887,20 +876,16 @@ dependencies = [
name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
 "delegate",
 "env_logger",
 "extend",
 "futures-lite",
 "libp2p",
 "log",
 "networking",
 "pin-project",
 "pyo3",
 "pyo3-async-runtimes",
 "pyo3-log",
 "pyo3-stub-gen",
 "tokio",
 "util",
]

[[package]]

@@ -914,12 +899,6 @@ dependencies = [
 "syn 2.0.111",
]

[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "ff"
version = "0.13.1"
@@ -1028,10 +1007,7 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
 "fastrand",
 "futures-core",
 "futures-io",
 "parking",
 "pin-project-lite",
]

@@ -2759,17 +2735,13 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
 "delegate",
 "either",
 "extend",
 "futures-lite",
 "futures-timer",
 "keccak-const",
 "libp2p",
 "log",
 "tokio",
 "tracing-subscriber",
 "util",
]

[[package]]

@@ -4599,10 +4571,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"

[[package]]
name = "util"
version = "0.0.1"

[[package]]
name = "uuid"
version = "1.19.0"
@@ -3,7 +3,6 @@ resolver = "3"
members = [
    "rust/networking",
    "rust/exo_pyo3_bindings",
    "rust/util",
]

[workspace.package]

@@ -24,18 +23,18 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }

# Macro dependencies
extend = "1.2"
delegate = "0.13"
pin-project = "1"

# Utility dependencies
keccak-const = "0.2"

# Async dependencies
tokio = "1.46"
futures-lite = "2.6.1"
futures = "0.3"
futures-timer = "3.0"

# Data structures
File diff suppressed because it is too large.
@@ -1,47 +1,29 @@
# type: ignore
#!/usr/bin/env python3
"""Tool-calling eval for exo's OpenAI-compatible API.

Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary

Start exo with a model first, then run:
  uv run python tool_call_eval.py --model <model-id>
  uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
  uv run python tool_call_eval.py --model <model-id> --repeat 3
  uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""

# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations

import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode

from harness import (
    ExoClient,
    ExoHttpError,
    add_common_instance_args,
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer

# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0

# Monkey-patch for transformers 5.x compatibility:
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location,
# which was moved in transformers 5.0.0rc2.
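The patch body itself falls outside this hunk; the shape of such a shim is roughly the following (hypothetical sketch only; both module paths are assumptions, not taken from the diff):

# Hypothetical compatibility shim: alias bytes_to_unicode back to the legacy
# module path that tokenization_kimi.py imports from. Paths are assumptions.
import transformers.models.gpt2.tokenization_gpt2 as legacy_gpt2

if not hasattr(legacy_gpt2, "bytes_to_unicode"):
    # assumed new home of bytes_to_unicode in transformers 5.x
    from transformers.models.gpt2.tokenization_gpt2_fast import bytes_to_unicode
    legacy_gpt2.bytes_to_unicode = bytes_to_unicode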
@@ -121,6 +103,154 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
    return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)


class ExoHttpError(RuntimeError):
    def __init__(self, status: int, reason: str, body_preview: str):
        super().__init__(f"HTTP {status} {reason}: {body_preview}")
        self.status = status


class ExoClient:
    def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
        self.host = host
        self.port = port
        self.timeout_s = timeout_s

    def request_json(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        body: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        if not path.startswith("/"):
            path = "/" + path
        if params:
            path = path + "?" + urlencode(params)

        conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
        try:
            payload: bytes | None = None
            hdrs: dict[str, str] = {"Accept": "application/json"}

            if body is not None:
                payload = json.dumps(body).encode("utf-8")
                hdrs["Content-Type"] = "application/json"
            if headers:
                hdrs.update(headers)

            conn.request(method.upper(), path, body=payload, headers=hdrs)
            resp = conn.getresponse()
            raw = resp.read()
            text = raw.decode("utf-8", errors="replace") if raw else ""

            if resp.status >= 400:
                raise ExoHttpError(resp.status, resp.reason, text[:300])

            if not text:
                return None
            return json.loads(text)
        finally:
            conn.close()

    def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
        return self.request_json("POST", "/bench/chat/completions", body=payload)
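Taken together, the inlined client is driven like this (a usage sketch; host and port are the script's own defaults, and the model id is a placeholder):

# Usage sketch of the client defined above.
client = ExoClient(host="localhost", port=52415)

state = client.request_json("GET", "/state")        # decoded JSON dict
models = client.request_json("GET", "/models") or {}

resp = client.post_bench_chat_completions({
    "model": "<model-id>",                           # placeholder
    "messages": [{"role": "user", "content": "hi"}],
})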
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
    if len(instance) != 1:
        raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")

    tag = next(iter(instance))
    inner = instance[tag]
    if not isinstance(inner, dict):
        raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
    return inner


def instance_id_from_instance(instance: dict[str, Any]) -> str:
    inner = unwrap_instance(instance)
    return str(inner["instanceId"])


def nodes_used_in_instance(instance: dict[str, Any]) -> int:
    inner = unwrap_instance(instance)
    return len(inner["shardAssignments"]["nodeToRunner"])
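These helpers assume the "tagged union" encoding used throughout the /state payload: a dict with exactly one key (the variant tag) wrapping the actual fields. A sketch (the tag name is illustrative, not taken from the diff):

# Hypothetical instance payload: one tag key wrapping the real fields.
instance = {
    "MlxRingInstance": {  # illustrative tag name
        "instanceId": "inst-1",
        "shardAssignments": {
            "nodeToRunner": {"node-a": "runner-1", "node-b": "runner-2"},
            "runnerToShard": {"runner-1": {}, "runner-2": {}},
        },
    }
}

assert unwrap_instance(instance)["instanceId"] == "inst-1"
assert instance_id_from_instance(instance) == "inst-1"
assert nodes_used_in_instance(instance) == 2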
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
    inner = unwrap_instance(instance)
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
    return list(runner_to_shard.keys())


def runner_ready(runner: dict[str, Any]) -> bool:
    return "RunnerReady" in runner


def runner_failed(runner: dict[str, Any]) -> bool:
    return "RunnerFailed" in runner


def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
    if "RunnerFailed" in runner:
        return runner["RunnerFailed"].get("errorMessage")
    return None


def wait_for_instance_ready(
    client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
    start_time = time.time()
    instance_existed = False
    while time.time() - start_time < timeout:
        state = client.request_json("GET", "/state")
        instances = state.get("instances", {})

        if instance_id not in instances:
            if instance_existed:
                # Instance was deleted after being created - likely due to runner failure
                raise RuntimeError(
                    f"Instance {instance_id} was deleted (runner may have failed)"
                )
            time.sleep(0.1)
            continue

        instance_existed = True
        instance = instances[instance_id]
        runner_ids = runner_ids_from_instance(instance)
        runners = state.get("runners", {})

        # Check for failed runners first
        for rid in runner_ids:
            runner = runners.get(rid, {})
            if runner_failed(runner):
                error_msg = get_runner_failed_message(runner) or "Unknown error"
                raise RuntimeError(f"Runner {rid} failed: {error_msg}")

        if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
            return

        time.sleep(0.1)

    raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")


def wait_for_instance_gone(
    client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            client.request_json("GET", f"/instance/{instance_id}")
            time.sleep(0.4)
        except ExoHttpError as e:
            if e.status == 404:
                return

    raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")


def format_peak_memory(b: float) -> str:
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if b < 1024.0:
@@ -139,6 +269,184 @@ def parse_int_list(values: list[str]) -> list[int]:
    return items

def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
    models = client.request_json("GET", "/models") or {}
    data = models.get("data") or []

    for m in data:
        if (m.get("name") or "").lower() == model_arg.lower():
            short_id = str(m["name"])
            full_id = str(m.get("hugging_face_id") or m["name"])
            return short_id, full_id

    for m in data:
        if m.get("hugging_face_id") == model_arg:
            short_id = str(m["name"])
            full_id = str(m["hugging_face_id"])
            return short_id, full_id

    raise ValueError(f"Model not found in /models: {model_arg}")

def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")
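The deletion pass above frees space smallest-first and stops as soon as the requirement is met; a standalone illustration with made-up sizes:

# Illustration of the smallest-first freeing loop above (sizes invented).
GiB = 1024**3
need, avail = 10 * GiB, 6 * GiB
completed = [("m-large", 8 * GiB), ("m-small", 1 * GiB), ("m-mid", 3 * GiB)]

for name, size in sorted(completed, key=lambda x: x[1]):
    avail += size  # pretend we deleted `name`
    if avail >= need:
        break      # deletes m-small, then m-mid; m-large survives

assert avail == 10 * GiB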
def placement_filter(instance_meta: str, wanted: str) -> bool:
    s = (instance_meta or "").lower()
    if wanted == "both":
        return ("ring" in s) or ("jaccl" in s)
    return wanted in s


def sharding_filter(sharding: str, wanted: str) -> bool:
    s = (sharding or "").lower()
    if wanted == "both":
        return ("pipeline" in s) or ("tensor" in s)
    return wanted in s
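Both filters are case-insensitive substring checks, with "both" widening to either known value; for example (the input strings here are illustrative, real values come from /instance/previews):

# Illustrative inputs only.
assert placement_filter("MlxRingInstance", "ring")
assert placement_filter("MlxJacclInstance", "both")
assert not placement_filter("MlxRingInstance", "jaccl")
assert sharding_filter("PipelineParallel", "both")
assert not sharding_filter("TensorParallel", "pipeline")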
def run_one_completion(
    client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -230,12 +538,76 @@ class PromptSizer:
        return content, tok


def fetch_and_filter_placements(
    client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
    previews_resp = client.request_json(
        "GET", "/instance/previews", params={"model_id": full_model_id}
    )
    previews = previews_resp.get("previews") or []

    selected: list[dict[str, Any]] = []
    for p in previews:
        if p.get("error") is not None:
            continue
        if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
            continue
        if not sharding_filter(str(p.get("sharding", "")), args.sharding):
            continue

        instance = p.get("instance")
        if not isinstance(instance, dict):
            continue

        n = nodes_used_in_instance(instance)
        # Skip single-node tensor/jaccl placements, as they are pointless when a
        # single-node pipeline ring placement exists
        if n == 1 and (
            (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
            or (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
        ):
            continue

        if (
            args.skip_pipeline_jaccl
            and (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
            and (
                args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
            )
        ):
            continue

        if (
            args.skip_tensor_ring
            and (
                args.instance_meta == "both"
                and "ring" in p.get("instance_meta", "").lower()
            )
            and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
        ):
            continue

        if args.min_nodes <= n <= args.max_nodes:
            selected.append(p)

    return selected


def main() -> int:
    ap = argparse.ArgumentParser(
        prog="exo-bench",
        description="Benchmark exo model throughput across placement previews.",
    )
    add_common_instance_args(ap)
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="Model short id or huggingface id")
    ap.add_argument(
        "--pp",
        nargs="+",
@@ -248,6 +620,34 @@ def main() -> int:
        required=True,
        help="Generation lengths (ints). Accepts commas.",
    )
    ap.add_argument(
        "--max-nodes",
        type=int,
        default=4,
        help="Only consider placements using <= this many nodes.",
    )
    ap.add_argument(
        "--min-nodes",
        type=int,
        default=1,
        help="Only consider placements using >= this many nodes.",
    )
    ap.add_argument(
        "--instance-meta", choices=["ring", "jaccl", "both"], default="both"
    )
    ap.add_argument(
        "--sharding", choices=["pipeline", "tensor", "both"], default="both"
    )
    ap.add_argument(
        "--skip-pipeline-jaccl",
        action="store_true",
        help="Skip pipeline+jaccl placements, as they are often pointless.",
    )
    ap.add_argument(
        "--skip-tensor-ring",
        action="store_true",
        help="Skip tensor+ring placements, as they are very slow.",
    )
    ap.add_argument(
        "--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
    )
@@ -257,6 +657,9 @@ def main() -> int:
        default=0,
        help="Warmup runs per placement (uses first pp/tg).",
    )
    ap.add_argument(
        "--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
    )
    ap.add_argument(
        "--json-out",
        default="bench/results.json",
@@ -271,6 +674,17 @@ def main() -> int:
        action="store_true",
        help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
    )
    ap.add_argument(
        "--settle-timeout",
        type=float,
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for the benchmark model.",
    )
    args = ap.parse_args()

    pp_list = parse_int_list(args.pp)
@@ -305,10 +719,24 @@ def main() -> int:
        logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
        raise

-    selected = settle_and_fetch_placements(
-        client, full_model_id, args, settle_timeout=args.settle_timeout
-    )
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )

    selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected and settle_deadline:
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not selected and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.warning(
                f"No valid placements yet (cluster may still be settling). "
                f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected:
        logger.error("No valid placements matched your filters.")
        return 1
@@ -332,6 +760,16 @@ def main() -> int:
    if args.dry_run:
        return 0

    logger.info("Planning phase: checking downloads...")
    run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:
bench/harness.py: 327 lines (file deleted)
@@ -1,327 +0,0 @@
# type: ignore
from __future__ import annotations

import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode

from loguru import logger

_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0


class ExoHttpError(RuntimeError):
    def __init__(self, status: int, reason: str, body_preview: str):
        super().__init__(f"HTTP {status} {reason}: {body_preview}")
        self.status = status


class ExoClient:
    def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
        self.host = host
        self.port = port
        self.timeout_s = timeout_s

    def request_json(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        body: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        if not path.startswith("/"):
            path = "/" + path
        if params:
            path = path + "?" + urlencode(params)

        conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
        try:
            payload: bytes | None = None
            hdrs: dict[str, str] = {"Accept": "application/json"}

            if body is not None:
                payload = json.dumps(body).encode("utf-8")
                hdrs["Content-Type"] = "application/json"
            if headers:
                hdrs.update(headers)

            conn.request(method.upper(), path, body=payload, headers=hdrs)
            resp = conn.getresponse()
            raw = resp.read()
            text = raw.decode("utf-8", errors="replace") if raw else ""

            if resp.status >= 400:
                raise ExoHttpError(resp.status, resp.reason, text[:300])

            if not text:
                return None
            return json.loads(text)
        finally:
            conn.close()

    def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
        return self.request_json("POST", "/bench/chat/completions", body=payload)


def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
    if len(instance) != 1:
        raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")

    tag = next(iter(instance))
    inner = instance[tag]
    if not isinstance(inner, dict):
        raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
    return inner


def instance_id_from_instance(instance: dict[str, Any]) -> str:
    inner = unwrap_instance(instance)
    return str(inner["instanceId"])


def nodes_used_in_instance(instance: dict[str, Any]) -> int:
    inner = unwrap_instance(instance)
    return len(inner["shardAssignments"]["nodeToRunner"])


def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
    inner = unwrap_instance(instance)
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
    return list(runner_to_shard.keys())


def runner_ready(runner: dict[str, Any]) -> bool:
    return "RunnerReady" in runner


def runner_failed(runner: dict[str, Any]) -> bool:
    return "RunnerFailed" in runner


def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
    if "RunnerFailed" in runner:
        return runner["RunnerFailed"].get("errorMessage")
    return None


def wait_for_instance_ready(
    client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
    start_time = time.time()
    instance_existed = False
    while time.time() - start_time < timeout:
        state = client.request_json("GET", "/state")
        instances = state.get("instances", {})

        if instance_id not in instances:
            if instance_existed:
                # Instance was deleted after being created - likely due to runner failure
                raise RuntimeError(
                    f"Instance {instance_id} was deleted (runner may have failed)"
                )
            time.sleep(0.1)
            continue

        instance_existed = True
        instance = instances[instance_id]
        runner_ids = runner_ids_from_instance(instance)
        runners = state.get("runners", {})

        # Check for failed runners first
        for rid in runner_ids:
            runner = runners.get(rid, {})
            if runner_failed(runner):
                error_msg = get_runner_failed_message(runner) or "Unknown error"
                raise RuntimeError(f"Runner {rid} failed: {error_msg}")

        if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
            return

        time.sleep(0.1)

    raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")


def wait_for_instance_gone(
    client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            client.request_json("GET", f"/instance/{instance_id}")
            time.sleep(0.4)
        except ExoHttpError as e:
            if e.status == 404:
                return
            raise

    raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")


def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
    models = client.request_json("GET", "/models") or {}
    data = models.get("data") or []

    for m in data:
        if (m.get("name") or "").lower() == model_arg.lower():
            short_id = str(m["name"])
            full_id = str(m.get("hugging_face_id") or m["name"])
            return short_id, full_id

    for m in data:
        if m.get("hugging_face_id") == model_arg:
            short_id = str(m["name"])
            full_id = str(m["hugging_face_id"])
            return short_id, full_id

    raise ValueError(f"Model not found in /models: {model_arg}")


def placement_filter(instance_meta: str, wanted: str) -> bool:
    s = (instance_meta or "").lower()
    if wanted == "both":
        return ("ring" in s) or ("jaccl" in s)
    return wanted in s


def sharding_filter(sharding: str, wanted: str) -> bool:
    s = (sharding or "").lower()
    if wanted == "both":
        return ("pipeline" in s) or ("tensor" in s)
    return wanted in s


def fetch_and_filter_placements(
    client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
    previews_resp = client.request_json(
        "GET", "/instance/previews", params={"model_id": full_model_id}
    )
    previews = previews_resp.get("previews") or []

    selected: list[dict[str, Any]] = []
    for p in previews:
        if p.get("error") is not None:
            continue
        if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
            continue
        if not sharding_filter(str(p.get("sharding", "")), args.sharding):
            continue

        instance = p.get("instance")
        if not isinstance(instance, dict):
            continue

        n = nodes_used_in_instance(instance)
        # Skip single-node tensor/jaccl placements, as they are pointless when a
        # single-node pipeline ring placement exists
        if n == 1 and (
            (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
            or (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
        ):
            continue

        if (
            args.skip_pipeline_jaccl
            and (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
            and (
                args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
            )
        ):
            continue

        if (
            args.skip_tensor_ring
            and (
                args.instance_meta == "both"
                and "ring" in p.get("instance_meta", "").lower()
            )
            and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
        ):
            continue

        if args.min_nodes <= n <= args.max_nodes:
            selected.append(p)

    return selected


def settle_and_fetch_placements(
    client: ExoClient,
    full_model_id: str,
    args: argparse.Namespace,
    settle_timeout: float = 0,
) -> list[dict[str, Any]]:
    selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected and settle_timeout > 0:
        backoff = _SETTLE_INITIAL_BACKOFF_S
        deadline = time.monotonic() + settle_timeout
        while not selected and time.monotonic() < deadline:
            remaining = deadline - time.monotonic()
            logger.warning(
                f"No valid placements yet (cluster may still be settling). "
                f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            selected = fetch_and_filter_placements(client, full_model_id, args)

    return selected


def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="Model short id or huggingface id")
    ap.add_argument(
        "--max-nodes",
        type=int,
        default=4,
        help="Only consider placements using <= this many nodes.",
    )
    ap.add_argument(
        "--min-nodes",
        type=int,
        default=1,
        help="Only consider placements using >= this many nodes.",
    )
    ap.add_argument(
        "--instance-meta", choices=["ring", "jaccl", "both"], default="both"
    )
    ap.add_argument(
        "--sharding", choices=["pipeline", "tensor", "both"], default="both"
    )
    ap.add_argument(
        "--skip-pipeline-jaccl",
        action="store_true",
        help="Skip pipeline+jaccl placements, as they are often pointless.",
    )
    ap.add_argument(
        "--skip-tensor-ring",
        action="store_true",
        help="Skip tensor+ring placements, as they are very slow.",
    )
    ap.add_argument(
        "--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
    )
    ap.add_argument(
        "--settle-timeout",
        type=float,
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
@@ -4,7 +4,6 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
  "httpx>=0.27.0",
  "loguru>=0.7.3",
  "transformers>=5.0.0",
  "huggingface-hub>=0.33.4",
@@ -1,240 +0,0 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.

[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]

[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"

[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"

[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]

[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"

[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]

[tools.search_products.properties.query]
type = "string"
description = "Search query string"

[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"

[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"

# -- Should call a tool --

[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]

[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"

[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]

[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"

[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]

[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"

# -- Multi-turn: tool call then follow-up --

[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]

[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"

[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"

[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]

[scenarios.tool_result]
result = 491682

[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"

[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]

[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8

[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6

[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"

# -- Sequential tool calls --

[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]

[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."

[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."

[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }

[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'

[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]

[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."

[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."

[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }

[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'

[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]

[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."

[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."

[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }

[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'

[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."

[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }

[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'

# -- Should NOT call a tool --

[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false

[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."

[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false

[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"
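The loader for this file is not part of the diff; a plausible sketch of how the [tools.*] tables map onto OpenAI-style function tools (the helper name and file path are hypothetical):

# Hypothetical loader: turn each [tools.*] table into an OpenAI function tool.
import tomllib

def tools_from_toml(path: str) -> list[dict]:
    with open(path, "rb") as f:
        cfg = tomllib.load(f)
    return [
        {
            "type": "function",
            "function": {
                "name": name,
                "description": spec["description"],
                "parameters": {
                    "type": "object",
                    "properties": spec.get("properties", {}),
                    "required": spec.get("required", []),
                },
            },
        }
        for name, spec in cfg.get("tools", {}).items()
    ]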
@@ -74,6 +74,7 @@
      perSystem =
        { config, self', inputs', pkgs, lib, system, ... }:
        let
          fenixToolchain = inputs'.fenix.packages.complete;
          # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
        in

@@ -158,7 +158,6 @@
            exo-test-env = testVenv;
          } // {
            exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
            exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
            exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
          };
@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 128666664960

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 185826705408

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 242986745856
@@ -5,7 +5,6 @@ edition = { workspace = true }
publish = false

[lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings"

@@ -27,7 +26,7 @@ networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
    # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
-    # "nightly", # enables better-supported GIL integration
+    "nightly", # enables better-supported GIL integration
    "experimental-async", # async support in #[pyfunction] & #[pymethods]
    #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
    #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)

@@ -44,19 +43,11 @@ pyo3-log = "0.13.2"

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }

# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures-lite = { workspace = true }

# utility dependencies
util = { workspace = true }
futures = { workspace = true }

# Tracing
log = { workspace = true }
env_logger = "0.11"

# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

@@ -1 +0,0 @@
TODO: do something here....
@@ -1,37 +1,22 @@
-//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
-//!
-
-use pin_project::pin_project;
+//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
 use pyo3::prelude::*;
 use std::{
     future::Future,
-    pin::Pin,
+    pin::{Pin, pin},
     task::{Context, Poll},
 };

-/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
-#[pin_project]
 #[repr(transparent)]
-pub(crate) struct AllowThreads<F>(#[pin] F);
-
-impl<F> AllowThreads<F>
-where
-    Self: Future,
-{
-    pub fn new(f: F) -> Self {
-        Self(f)
-    }
-}
+pub struct AllowThreads<F>(pub(crate) F);

 impl<F> Future for AllowThreads<F>
 where
-    F: Future + Send,
+    F: Future + Unpin + Send,
     F::Output: Send,
 {
     type Output = F::Output;

-    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let waker = cx.waker();
-        Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
+        Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
     }
 }
@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Generate a new ECDSA keypair.
    #[staticmethod]
    fn generate_ecdsa() -> Self {
        Self(Keypair::generate_ecdsa())
    }

    /// Generate a new Secp256k1 keypair.
    #[staticmethod]
    fn generate_secp256k1() -> Self {
        Self(Keypair::generate_secp256k1())
    }

    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
    #[staticmethod]
    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
    /// format (i.e. unencrypted) as defined in [RFC5208].
    ///
    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
    #[staticmethod]
    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
    /// structure as defined in [RFC5915].
    ///
    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
    #[staticmethod]
    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
    }

    #[staticmethod]
    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Encode a private key as protobuf structure.
    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self.0.to_protobuf_encoding().pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId`.
    fn to_peer_id(&self) -> PyPeerId {
        PyPeerId(self.0.public().to_peer_id())
    }

    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
    // #[gen_stub(skip)]
    // #[new]
    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
    //     Self::from_protobuf_encoding(bytes)
    // }
    //
    // #[gen_stub(skip)]
    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
    //     *self = Self::from_protobuf_encoding(state)?;
    //     Ok(())
    // }
    //
    // #[gen_stub(skip)]
    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
    //     self.to_protobuf_encoding(py)
    // }
    //
    // #[gen_stub(skip)]
    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
    //     Ok((self.to_protobuf_encoding(py)?,))
    // }
}

/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
    /// Generates a random peer ID from a cryptographically secure PRNG.
    ///
    /// This is useful for randomly walking on a DHT, or for testing purposes.
    #[staticmethod]
    fn random() -> Self {
        Self(PeerId::random())
    }

    /// Parses a `PeerId` from bytes.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
    }

    /// Returns a raw bytes representation of this `PeerId`.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_bytes();
        PyBytes::new(py, &bytes)
    }

    /// Returns a base-58 encoded string of this `PeerId`.
    fn to_base58(&self) -> String {
        self.0.to_base58()
    }

    fn __repr__(&self) -> String {
        format!("PeerId({})", self.to_base58())
    }

    fn __str__(&self) -> String {
        self.to_base58()
    }
}

pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    m.add_class::<PyPeerId>()?;

    Ok(())
}
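For context, the bindings this file generated were consumed from Python roughly as follows (a sketch inferred from the method names above; the exact import path depends on how ident_submodule is mounted and is an assumption):

# Sketch of driving the (now removed) Keypair/PeerId bindings from Python.
from exo_pyo3_bindings import Keypair  # import path is an assumption

kp = Keypair.generate_ed25519()        # staticmethod constructor
peer_id = kp.to_peer_id()              # PeerId derived from the public key
print(peer_id.to_base58())             # base-58 string form

# Round-trip a key through its protobuf encoding.
blob = kp.to_protobuf_encoding()       # -> bytes
kp2 = Keypair.from_protobuf_encoding(blob)
assert kp2.to_peer_id().to_base58() == peer_id.to_base58()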
@@ -1,170 +1,5 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
mod allow_threading;
|
||||
mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::ident_submodule;
|
||||
use crate::networking::networking_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = ResultExt)]
|
||||
impl<T, E> Result<T, E>
|
||||
where
|
||||
E: ToString,
|
||||
{
|
||||
fn pyerr(self) -> PyResult<T> {
|
||||
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
{
|
||||
AllowThreads::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Future> FutureExt for T {}
|
||||
|
||||
#[ext(pub, name = PyErrExt)]
|
||||
impl PyErr {
|
||||
fn receiver_channel_closed() -> Self {
|
||||
PyConnectionError::new_err("Receiver channel closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
match self {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
// write error back to python
|
||||
e.write_unraisable(py, None);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
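
// Illustrative sketch (not in the original): how the extension methods above are
// meant to compose at the Python boundary. `pyerr()` maps any `ToString` error to
// a `PyRuntimeError`, while `send_py` maps a closed channel to a `PyConnectionError`.
#[allow(dead_code)]
async fn ext_usage_sketch(tx: tokio::sync::mpsc::Sender<u16>) -> pyo3::PyResult<()> {
    use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};

    let port: u16 = "4001".parse().pyerr()?; // ParseIntError -> PyRuntimeError
    tx.send_py(port).await?; // closed channel -> PyConnectionError
    Ok(())
}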

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // install logger
    pyo3_log::init();

    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
    // work with maturin, where the types generate correctly, in the right folder, without
    // too many importing issues...
    ident_submodule(m)?;
    networking_submodule(m)?;

    // top-level constructs
    // TODO: ...

    Ok(())
}
mod allow_threading;

define_stub_info_gatherer!(stub_info);

@@ -1,571 +0,0 @@
#![allow(
    clippy::multiple_inherent_impl,
    clippy::unnecessary_wraps,
    clippy::unused_self,
    clippy::needless_pass_by_value
)]

use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId};
use crate::pyclass;
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};

mod exception {
    use pyo3::types::PyTuple;
    use pyo3::{exceptions::PyException, prelude::*};
    use pyo3_stub_gen::derive::*;

    #[gen_stub_pyclass]
    #[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
    pub struct PyNoPeersSubscribedToTopicError {}

    impl PyNoPeersSubscribedToTopicError {
        const MSG: &'static str = "\
            No peers are currently subscribed to receive messages on this topic. \
            Wait for peers to subscribe or check your network connectivity.";

        /// Creates a new [`PyErr`] of this type.
        ///
        /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
        pub(crate) fn new_err() -> PyErr {
            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
        }
    }

    #[gen_stub_pymethods]
    #[pymethods]
    impl PyNoPeersSubscribedToTopicError {
        #[new]
        #[pyo3(signature = (*args))]
        #[allow(unused_variables)]
        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
            Self {}
        }

        fn __repr__(&self) -> String {
            format!("NoPeersSubscribedToTopicError(\"{}\")", Self::MSG)
        }

        fn __str__(&self) -> String {
            Self::MSG.to_string()
        }
    }

    #[gen_stub_pyclass]
    #[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
    pub struct PyAllQueuesFullError {}

    impl PyAllQueuesFullError {
        const MSG: &'static str =
            "All libp2p peers are unresponsive, resend the message or reconnect.";

        /// Creates a new [`PyErr`] of this type.
        ///
        /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
        pub(crate) fn new_err() -> PyErr {
            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
        }
    }

    #[gen_stub_pymethods]
    #[pymethods]
    impl PyAllQueuesFullError {
        #[new]
        #[pyo3(signature = (*args))]
        #[allow(unused_variables)]
        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
            Self {}
        }

        fn __repr__(&self) -> String {
            format!("AllQueuesFullError(\"{}\")", Self::MSG)
        }

        fn __str__(&self) -> String {
            Self::MSG.to_string()
        }
    }
}

/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
    Connected = 0,
    Disconnected,
}

#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
    /// Whether this is a connection or disconnection event.
    #[pyo3(get)]
    update_type: PyConnectionUpdateType,

    /// Identity of the peer that we have connected to or disconnected from.
    #[pyo3(get)]
    peer_id: PyPeerId,

    /// Remote connection's IPv4 address.
    #[pyo3(get)]
    remote_ipv4: String,

    /// Remote connection's TCP port.
    #[pyo3(get)]
    remote_tcp_port: u16,
}

enum ToTask {
    GossipsubSubscribe {
        topic: String,
        result_tx: oneshot::Sender<PyResult<bool>>,
    },
    GossipsubUnsubscribe {
        topic: String,
        result_tx: oneshot::Sender<bool>,
    },
    GossipsubPublish {
        topic: String,
        data: Vec<u8>,
        result_tx: oneshot::Sender<PyResult<MessageId>>,
    },
}
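
// Illustrative sketch (not in the original): the request/response idiom behind
// `ToTask` — each command carries a oneshot `result_tx`, and the caller awaits
// the paired receiver for the outcome.
#[allow(dead_code)]
async fn request_subscribe_sketch(
    to_task_tx: &mpsc::Sender<ToTask>,
    topic: &str,
) -> Option<bool> {
    let (result_tx, result_rx) = oneshot::channel();
    to_task_tx
        .send(ToTask::GossipsubSubscribe { topic: topic.to_owned(), result_tx })
        .await
        .ok()?; // networking task gone -> None
    result_rx.await.ok()?.ok() // channel dropped or PyErr -> None
}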

#[allow(clippy::enum_glob_use)]
async fn networking_task(
    mut swarm: networking::swarm::Swarm,
    mut to_task_rx: mpsc::Receiver<ToTask>,
    connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
    gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
    use SwarmEvent::*;
    use ToTask::*;
    use networking::swarm::BehaviourEvent::*;

    log::info!("RUST: networking task started");

    loop {
        tokio::select! {
            message = to_task_rx.recv() => {
                // handle closed channel
                let Some(message) = message else {
                    log::info!("RUST: channel closed");
                    break;
                };

                // dispatch incoming messages
                match message {
                    GossipsubSubscribe { topic, result_tx } => {
                        // try to subscribe
                        let result = swarm.behaviour_mut()
                            .gossipsub.subscribe(&IdentTopic::new(topic));

                        // send response oneshot
                        if let Err(e) = result_tx.send(result.pyerr()) {
                            log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
                            continue;
                        }
                    }
                    GossipsubUnsubscribe { topic, result_tx } => {
                        // try to unsubscribe from the topic
                        let result = swarm.behaviour_mut()
                            .gossipsub.unsubscribe(&IdentTopic::new(topic));

                        // send response oneshot (or exit if connection closed)
                        if let Err(e) = result_tx.send(result) {
                            log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
                            continue;
                        }
                    }
                    GossipsubPublish { topic, data, result_tx } => {
                        // try to publish the data -> catch the NoPeersSubscribedToTopic error & convert it to the correct exception
                        let result = swarm.behaviour_mut().gossipsub.publish(
                            IdentTopic::new(topic), data);
                        let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
                            Err(exception::PyNoPeersSubscribedToTopicError::new_err())
                        } else if let Err(PublishError::AllQueuesFull(_)) = result {
                            Err(exception::PyAllQueuesFullError::new_err())
                        } else {
                            result.pyerr()
                        };

                        // send response oneshot (or exit if connection closed)
                        if let Err(e) = result_tx.send(pyresult) {
                            log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
                            continue;
                        }
                    }
                }
            }

            // architectural solution to this problem:
            // create a keep_alive behaviour whose job it is to dial peers discovered by mDNS (and drop them when expired)
            // -> it will emit TRUE connected/disconnected events consumable elsewhere
            //
            // gossipsub will feed off of dial attempts created by networking, and that will bootstrap its peers list;
            // then for actual communication it will dial those peers if need be
            swarm_event = swarm.select_next_some() => {
                match swarm_event {
                    Behaviour(Gossipsub(gossipsub::Event::Message {
                        message: Message {
                            topic,
                            data,
                            ..
                        },
                        ..
                    })) => {
                        // topic-ID is just the topic hash (since we used the identity hasher)
                        let message = (topic.into_string(), data);

                        // send incoming message to channel (or exit if connection closed)
                        if let Err(e) = gossipsub_message_tx.send(message).await {
                            log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
                            continue;
                        }
                    },
                    Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
                        // grab IPv4 string
                        let remote_ipv4 = match remote_ip {
                            IpAddr::V4(ip) => ip.to_string(),
                            IpAddr::V6(ip) => {
                                log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
                                continue;
                            }
                        };

                        // send connection event to channel (or exit if connection closed)
                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                            update_type: PyConnectionUpdateType::Connected,
                            peer_id: PyPeerId(peer_id),
                            remote_ipv4,
                            remote_tcp_port,
                        }).await {
                            log::error!("RUST: could not send connection update since channel already closed: {e}");
                            continue;
                        }
                    },
                    Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
                        // grab IPv4 string
                        let remote_ipv4 = match remote_ip {
                            IpAddr::V4(ip) => ip.to_string(),
                            IpAddr::V6(ip) => {
                                log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
                                continue;
                            }
                        };

                        // send disconnection event to channel (or exit if connection closed)
                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                            update_type: PyConnectionUpdateType::Disconnected,
                            peer_id: PyPeerId(peer_id),
                            remote_ipv4,
                            remote_tcp_port,
                        }).await {
                            log::error!("RUST: could not send connection update since channel already closed: {e}");
                            continue;
                        }
                    },
                    e => {
                        log::info!("RUST: other event {e:?}");
                    }
                }
            }
        }
    }

    log::info!("RUST: networking task stopped");
}

#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
    // channels
    to_task_tx: Option<mpsc::Sender<ToTask>>,
    connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
    gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}

impl Drop for PyNetworkingHandle {
    fn drop(&mut self) {
        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
        // to ensure that the networking task is done BEFORE exiting the clear function...
        // but this may require the GIL?? and it may not be safe to acquire the GIL here??
        self.to_task_tx = None; // Using Option<T> as a trick to force the channel to be dropped
    }
}
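
// Illustrative sketch (not in the original): the Option<Sender> drop trick above
// relies on tokio's mpsc semantics — once the last Sender is dropped, recv()
// yields None and the networking task's loop can exit cleanly.
#[allow(dead_code)]
async fn drop_closes_channel_sketch() {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<u8>(8);
    drop(tx); // last sender gone
    assert!(rx.recv().await.is_none()); // receiver observes closure
}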

#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
    fn new(
        to_task_tx: mpsc::Sender<ToTask>,
        connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
        gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
    ) -> Self {
        Self {
            to_task_tx: Some(to_task_tx),
            connection_update_rx: Mutex::new(connection_update_rx),
            gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
        }
    }

    const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
        self.to_task_tx
            .as_ref()
            .expect("The sender should only be None after de-initialization.")
    }
}

#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
    // NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
    // immediately beforehand to release the interpreter.
    // SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await

    // ---- Lifecycle management methods ----

    #[new]
    fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
        use pyo3_async_runtimes::tokio::get_runtime;

        // create communication channels
        let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
        let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
        let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);

        // get identity
        let identity = identity.borrow().0.clone();

        // create networking swarm (within tokio context!! or it crashes)
        let swarm = get_runtime()
            .block_on(async { create_swarm(identity) })
            .pyerr()?;

        // spawn tokio task running the networking logic
        get_runtime().spawn(async move {
            networking_task(
                swarm,
                to_task_rx,
                connection_update_tx,
                gossipsub_message_tx,
            )
            .await;
        });
        Ok(Self::new(
            to_task_tx,
            connection_update_rx,
            gossipsub_message_rx,
        ))
    }

    #[gen_stub(skip)]
    const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
        Ok(()) // This is needed purely so `__clear__` can work
    }

    #[gen_stub(skip)]
    fn __clear__(&mut self) {
        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
        // to ensure that the networking task is done BEFORE exiting the clear function...
        // but this may require the GIL?? and it may not be safe to acquire the GIL here??
        self.to_task_tx = None; // Using Option<T> as a trick to force the channel to be dropped
    }

    // ---- Connection update receiver methods ----

    /// Receives the next `ConnectionUpdate` from networking.
    async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
        self.connection_update_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_py()
            .allow_threads_py() // allow-threads-aware async call
            .await
    }

    /// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
    ///
    /// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
    /// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
    /// will sleep until a `ConnectionUpdate` is sent.
    async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
        self.connection_update_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_many_py(limit)
            .allow_threads_py() // allow-threads-aware async call
            .await
    }

    // TODO: right now this blocks the main thread if anything else is awaiting the channel
    // (because it's a mutex), so it's too dangerous to expose just yet. Figure out better
    // semantics for handling this, so things don't randomly block.
    // /// Tries to receive the next `ConnectionUpdate` from networking.
    // fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
    //     self.connection_update_rx.blocking_lock().try_recv_py()
    // }
    //
    // /// Checks if the `ConnectionUpdate` channel is empty.
    // fn connection_update_is_empty(&self) -> bool {
    //     self.connection_update_rx.blocking_lock().is_empty()
    // }
    //
    // /// Returns the number of `ConnectionUpdate`s in the channel.
    // fn connection_update_len(&self) -> usize {
    //     self.connection_update_rx.blocking_lock().len()
    // }

    // ---- Gossipsub management methods ----

    /// Subscribes to a `GossipSub` topic.
    ///
    /// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
    async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
        let (tx, rx) = oneshot::channel();

        // send off request to subscribe
        self.to_task_tx()
            .send_py(ToTask::GossipsubSubscribe {
                topic,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & return any errors
        rx.allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())?
    }

    /// Unsubscribes from a `GossipSub` topic.
    ///
    /// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
    async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
        let (tx, rx) = oneshot::channel();

        // send off request to unsubscribe
        self.to_task_tx()
            .send_py(ToTask::GossipsubUnsubscribe {
                topic,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & convert any errors
        rx.allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())
    }

    /// Publishes a message to a `GossipSub` topic.
    ///
    /// If no peers are found that subscribe to this topic, raises a `NoPeersSubscribedToTopicError` exception.
    async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
        let (tx, rx) = oneshot::channel();

        // send off request to publish
        let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
        self.to_task_tx()
            .send_py(ToTask::GossipsubPublish {
                topic,
                data,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & return any errors => ignore the MessageId for now!!!
        let _ = rx
            .allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())??;
        Ok(())
    }

    // ---- Gossipsub message receiver methods ----

    /// Receives the next message from the `GossipSub` network.
    async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
        self.gossipsub_message_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_py()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .map(|(t, d)| (t, d.pybytes()))
    }

    /// Receives at most `limit` messages from the `GossipSub` network and returns them.
    ///
    /// For `limit = 0`, an empty collection of messages will be returned immediately.
    /// For `limit > 0`, if there are no messages in the channel's queue this method
    /// will sleep until a message is sent.
    async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
        Ok(self
            .gossipsub_message_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_many_py(limit)
            .allow_threads_py() // allow-threads-aware async call
            .await?
            .into_iter()
            .map(|(t, d)| (t, d.pybytes()))
            .collect())
    }

    // TODO: right now this blocks the main thread if anything else is awaiting the channel
    // (because it's a mutex), so it's too dangerous to expose just yet. Figure out better
    // semantics for handling this, so things don't randomly block.
    // /// Tries to receive the next message from the `GossipSub` network.
    // fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
    //     Ok(self
    //         .gossipsub_message_rx
    //         .blocking_lock()
    //         .try_recv_py()?
    //         .map(|(t, d)| (t, d.pybytes())))
    // }
    //
    // /// Checks if the `GossipSub` message channel is empty.
    // fn gossipsub_is_empty(&self) -> bool {
    //     self.gossipsub_message_rx.blocking_lock().is_empty()
    // }
    //
    // /// Returns the number of `GossipSub` messages in the channel.
    // fn gossipsub_len(&self) -> usize {
    //     self.gossipsub_message_rx.blocking_lock().len()
    // }
}

pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
    m.add_class::<exception::PyAllQueuesFullError>()?;

    m.add_class::<PyConnectionUpdateType>()?;
    m.add_class::<PyConnectionUpdate>()?;
    m.add_class::<PyNetworkingHandle>()?;

    Ok(())
}

@@ -1,54 +0,0 @@
#[cfg(test)]
mod tests {
    use core::mem::drop;
    use core::option::Option::Some;
    use core::time::Duration;
    use tokio;
    use tokio::sync::mpsc;

    #[tokio::test]
    async fn test_drop_channel() {
        struct Ping;

        let (tx, mut rx) = mpsc::channel::<Ping>(10);

        let _ = tokio::spawn(async move {
            println!("TASK: entered");

            loop {
                tokio::select! {
                    result = rx.recv() => {
                        match result {
                            Some(_) => {
                                println!("TASK: pinged");
                            }
                            None => {
                                println!("TASK: closing channel");
                                break;
                            }
                        }
                    }
                    _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
                        println!("TASK: heartbeat");
                    }
                }
            }

            println!("TASK: exited");
        });

        let tx2 = tx.clone();

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;

        tx.send(Ping).await.expect("Should not fail");
        drop(tx);

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;

        tx2.send(Ping).await.expect("Should not fail");
        drop(tx2);

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
    }
}

@@ -1,34 +0,0 @@
import asyncio

import pytest
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError


@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
    print("PYTHON: starting handle")
    h = NetworkingHandle(Keypair.generate_ed25519())

    ct = asyncio.create_task(_await_cons(h))
    mt = asyncio.create_task(_await_msg(h))

    # sleep for 4 ticks
    for _ in range(4):
        await asyncio.sleep(1)

    try:
        await h.gossipsub_publish("topic", b"something or other")
    except NoPeersSubscribedToTopicError as e:
        print("caught it", e)


async def _await_cons(h: NetworkingHandle) -> None:
    while True:
        c = await h.connection_update_recv()
        print(f"PYTHON: connection update: {c}")


async def _await_msg(h: NetworkingHandle) -> None:
    while True:
        m = await h.gossipsub_recv()
        print(f"PYTHON: message: {m}")

@@ -5,7 +5,6 @@ edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "networking"
path = "src/lib.rs"

@@ -13,20 +12,14 @@ path = "src/lib.rs"
workspace = true

[dependencies]
# datastructures
either = { workspace = true }

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures-lite = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
util = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }

@@ -1,7 +1,3 @@
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;

@@ -10,65 +6,4 @@ async fn main() {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
        .try_init();

    // Configure swarm
    let mut swarm =
        swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");

    // Create a Gossipsub topic & subscribe
    let topic = gossipsub::IdentTopic::new("test-net");
    swarm
        .behaviour_mut()
        .gossipsub
        .subscribe(&topic)
        .expect("Subscribing to topic failed");

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();
    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            // on gossipsub outgoing
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.next() => match event {
                // on gossipsub incoming
                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                }))) => println!(
                    "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                    String::from_utf8_lossy(&message.data),
                ),

                // on discovery
                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e))) => match e {
                    discovery::Event::ConnectionEstablished {
                        peer_id, connection_id, remote_ip, remote_tcp_port
                    } => {
                        println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
                    }
                    discovery::Event::ConnectionClosed {
                        peer_id, connection_id, remote_ip, remote_tcp_port
                    } => {
                        eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
                    }
                }

                // ignore outgoing errors: those are normal
                e @ Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }

                // otherwise log any other event
                e => { log::info!("Other event {e:?}"); }
            }
        }
    }
}

@@ -1,44 +0,0 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> an NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to PF_RAW or PF_PACKET: https://scapy.disruptivelabs.in/networking/socket-interface

https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X, so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing

https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP

https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^DriverKit, system extensions & kexts for macOS

----

To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic, but modprobe is an option for manual loading), and the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on Linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for macOS

https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making a socket interface

https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on macOS via low-level interactions

--------------------------------

https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => Intel's equivalent of Thunderbolt Bridge

---------------------------------

https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
--> fake ethernet devices on macOS -> omg??? we can detect the Thunderbolt Bridge, then bind to it, then re-expose it as a fake ethernet device??
--> ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
--> https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c

@@ -1,382 +0,0 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
    CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
    ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
    THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;

const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);

mod managed {
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{identity, mdns, ping};
    use std::io;
    use std::time::Duration;

    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
    const PING_INTERVAL: Duration = Duration::from_millis(2_500);

    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        mdns: mdns::tokio::Behaviour,
        ping: ping::Behaviour,
    }

    impl Behaviour {
        pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
            Ok(Self {
                mdns: mdns_behaviour(keypair)?,
                ping: ping_behaviour(),
            })
        }
    }

    fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
        use mdns::{Config, tokio};

        // mDNS config => enable IPv6
        let mdns_config = Config {
            ttl: MDNS_RECORD_TTL,
            query_interval: MDNS_QUERY_INTERVAL,

            // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with IPv6?? figure out how to make it work
            ..Default::default()
        };

        let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
        Ok(mdns_behaviour?)
    }

    fn ping_behaviour() -> ping::Behaviour {
        ping::Behaviour::new(
            ping::Config::new()
                .with_timeout(PING_TIMEOUT)
                .with_interval(PING_INTERVAL),
        )
    }
}

/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
    ConnectionEstablished {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
    ConnectionClosed {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
}

/// Discovery behaviour that wraps mDNS to produce truly discovered, durable peer connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
///    to the swarm.
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
///    immediately, and expired but connected peers are disconnected from immediately.
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
///    connected peers are disconnected from.
pub struct Behaviour {
    // state-tracking for managed behaviours & mDNS-discovered peers
    managed: managed::Behaviour,
    mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,

    retry_delay: Delay, // retry interval

    // pending events to emit => waker-backed deque to control polling
    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}

impl Behaviour {
    pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
        Ok(Self {
            managed: managed::Behaviour::new(keypair)?,
            mdns_discovered: HashMap::new(),
            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
            pending_events: WakerDeque::new(),
        })
    }

    fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
        self.pending_events.push_back(ToSwarm::Dial {
            opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
        })
    }

    fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
        // push front to make this IMMEDIATE
        self.pending_events.push_front(ToSwarm::CloseConnection {
            peer_id,
            connection: CloseConnection::One(connection),
        })
    }

    fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            self.dial(p, ma.clone()); // always connect

            // get peer's multi-addresses or insert if missing
            let Some(mas) = self.mdns_discovered.get_mut(&p) else {
                self.mdns_discovered.insert(p, BTreeSet::from([ma]));
                continue;
            };

            // multiaddress should never already be present - else something has gone wrong
            let is_new_addr = mas.insert(ma);
            assert!(is_new_addr, "cannot discover a discovered peer");
        }
    }

    fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            // at this point, we *must* have the peer
            let mas = self
                .mdns_discovered
                .get_mut(&p)
                .expect("nonexistent peer cannot expire");

            // at this point, we *must* have the multiaddress
            let was_present = mas.remove(&ma);
            assert!(was_present, "nonexistent multiaddress cannot expire");

            // if empty, remove the peer-id entirely
            if mas.is_empty() {
                self.mdns_discovered.remove(&p);
            }
        }
    }

    fn on_connection_established(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out connected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }

    fn on_connection_closed(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out disconnected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }
}

impl NetworkBehaviour for Behaviour {
    type ConnectionHandler =
        ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
    type ToSwarm = Event;

    // simply delegate to the underlying mDNS behaviour

    delegate! {
        to self.managed {
            fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
            fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
        }
    }

    fn handle_established_inbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        local_addr: &Multiaddr,
        remote_addr: &Multiaddr,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_inbound_connection(
                connection_id,
                peer,
                local_addr,
                remote_addr,
            )?,
        ))
    }

    #[allow(clippy::needless_question_mark)]
    fn handle_established_outbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        addr: &Multiaddr,
        role_override: Endpoint,
        port_use: PortUse,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_outbound_connection(
                connection_id,
                peer,
                addr,
                role_override,
                port_use,
            )?,
        ))
    }

    fn on_connection_handler_event(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        event: THandlerOutEvent<Self>,
    ) {
        match event {
            Either::Left(ev) => libp2p::core::util::unreachable(ev),
            Either::Right(ev) => {
                self.managed
                    .on_connection_handler_event(peer_id, connection_id, ev)
            }
        }
    }

    // hook into these methods to drive the behaviour

    fn on_swarm_event(&mut self, event: FromSwarm) {
        self.managed.on_swarm_event(event); // let mDNS handle swarm events

        // handle swarm events to update internal state:
        match event {
            FromSwarm::ConnectionEstablished(ConnectionEstablished {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };

                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection established event which is filtered correctly
                    self.on_connection_established(peer_id, connection_id, ip, port)
                }
            }
            FromSwarm::ConnectionClosed(ConnectionClosed {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };

                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection closed event which is filtered correctly
                    self.on_connection_closed(peer_id, connection_id, ip, port)
                }
            }

            // since we are running a TCP/IP transport layer, we are assuming that
            // no address changes can occur, hence encountering one is a fatal error
            FromSwarm::AddressChange(a) => {
                unreachable!("unhandlable: address change encountered: {:?}", a)
            }
            _ => {}
        }
    }

    fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        // delegate to managed behaviours for any work they need to perform
        match self.managed.poll(cx) {
            Poll::Ready(ToSwarm::GenerateEvent(e)) => {
                match e {
                    // handle discovered and expired events from mDNS
                    managed::BehaviourEvent::Mdns(e) => match e.clone() {
                        mdns::Event::Discovered(peers) => {
                            self.handle_mdns_discovered(peers);
                        }
                        mdns::Event::Expired(peers) => {
                            self.handle_mdns_expired(peers);
                        }
                    },

                    // handle ping events => if error then disconnect
                    managed::BehaviourEvent::Ping(e) => {
                        if e.result.is_err() {
                            self.close_connection(e.peer, e.connection)
                        }
                    }
                }

                // since we just consumed an event, we should immediately wake just in case
                // there are more events to come where that came from
                cx.waker().wake_by_ref();
            }

            // forward any other mDNS event to the swarm or its connection handler(s)
            Poll::Ready(e) => {
                return Poll::Ready(
                    e.map_out(|_| unreachable!("events returning to swarm already handled"))
                        .map_in(Either::Right),
                );
            }

            Poll::Pending => {}
        }

        // retry connecting to all mDNS peers periodically (fails safely if already connected)
        if self.retry_delay.poll(cx).is_ready() {
            for (p, mas) in self.mdns_discovered.clone() {
                for ma in mas {
                    self.dial(p, ma)
                }
            }
            self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
        }

        // send out any pending events from our own service
        if let Some(e) = self.pending_events.pop_front(cx) {
            return Poll::Ready(e.map_in(Either::Left));
        }

        // wait for pending events
        Poll::Pending
    }
}
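
// Illustrative sketch (not in the original): the periodic-retry idiom used in
// `poll()` above. `futures_timer::Delay` is polled inline via the
// `futures_lite::FutureExt::poll` helper; once it fires, the work is done and
// the timer is re-armed for the next interval.
#[allow(dead_code)]
fn poll_retry_sketch(delay: &mut Delay, cx: &mut Context<'_>) {
    if delay.poll(cx).is_ready() {
        // ... redial every discovered-but-unconnected peer here ...
        delay.reset(RETRY_CONNECT_INTERVAL); // arm the next tick
    }
}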

@@ -1,44 +0,0 @@
//! TODO: crate documentation
//!
//! this is here as placeholder documentation
//!
//!
pub mod discovery;
pub mod swarm;

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
    use std::error::Error;

    pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
    pub type AnyResult<T> = Result<T, AnyError>;
}

/// Namespace for crate-wide extension traits/methods.
pub(crate) mod ext {
    use extend::ext;
    use libp2p::Multiaddr;
    use libp2p::multiaddr::Protocol;
    use std::net::IpAddr;

    #[ext(pub, name = MultiaddrExt)]
    impl Multiaddr {
        /// If the multiaddress corresponds to a TCP address, extracts it.
        fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
            let mut ps = self.into_iter();
            let ip = if let Some(p) = ps.next() {
                match p {
                    Protocol::Ip4(ip) => IpAddr::V4(ip),
                    Protocol::Ip6(ip) => IpAddr::V6(ip),
                    _ => return None,
                }
            } else {
                return None;
            };
            let Some(Protocol::Tcp(port)) = ps.next() else {
                return None;
            };
            Some((ip, port))
        }
    }
}
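
// Illustrative sketch (not in the original): exercising `try_to_tcp_addr` on a
// literal multiaddress of the `/ip4/<addr>/tcp/<port>` shape it expects.
#[allow(dead_code)]
fn try_to_tcp_addr_sketch() {
    use crate::ext::MultiaddrExt as _;
    use libp2p::Multiaddr;
    use std::net::IpAddr;

    let ma: Multiaddr = "/ip4/192.168.1.2/tcp/4001".parse().expect("valid multiaddr");
    let (ip, port) = ma.try_to_tcp_addr().expect("starts with ip4/tcp");
    assert_eq!(ip, IpAddr::from([192, 168, 1, 2]));
    assert_eq!(port, 4001);
}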

@@ -1,143 +0,0 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};

pub type Swarm = libp2p::Swarm<Behaviour>;

/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// This is all very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";

/// Create and configure a swarm which listens on all interfaces on an OS-assigned port.
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
    let mut swarm = SwarmBuilder::with_existing_identity(keypair)
        .with_tokio()
        .with_other_transport(tcp_transport)?
        .with_behaviour(Behaviour::new)?
        .build();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
    Ok(swarm)
}
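
// Illustrative sketch (not in the original): driving the swarm created above.
// The real driver lives in the bindings' `networking_task`; this only shows the
// minimal event-loop shape (assumes a tokio runtime).
#[allow(dead_code)]
async fn drive_swarm_sketch() -> alias::AnyResult<()> {
    use libp2p::futures::StreamExt as _;

    let mut swarm = create_swarm(identity::Keypair::generate_ed25519())?;
    loop {
        let event = swarm.select_next_some().await;
        println!("swarm event: {event:?}");
    }
}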

mod transport {
    use crate::alias;
    use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
    use futures_lite::{AsyncRead, AsyncWrite};
    use keccak_const::Sha3_256;
    use libp2p::core::muxing;
    use libp2p::core::transport::Boxed;
    use libp2p::pnet::{PnetError, PnetOutput};
    use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
    use std::{env, sync::LazyLock};

    /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
    /// See [`pnet_upgrade`] for more.
    static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
        let builder = Sha3_256::new().update(b"exo_discovery_network");

        if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
            let bytes = var.into_bytes();
            builder.update(&bytes)
        } else {
            builder.update(NETWORK_VERSION)
        }
        .finalize()
    });

    /// Make the Swarm run on a private network, so as not to clash with public libp2p nodes and
    /// also different-versioned instances of this same network.
    /// This is implemented as an additional "upgrade" on top of existing [`libp2p::Transport`] layers.
    async fn pnet_upgrade<TSocket>(
        socket: TSocket,
        _: impl Sized,
    ) -> Result<PnetOutput<TSocket>, PnetError>
    where
        TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    {
        use pnet::{PnetConfig, PreSharedKey};
        PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
            .handshake(socket)
            .await
    }

    /// TCP/IP transport layer configuration.
    pub fn tcp_transport(
        keypair: &identity::Keypair,
    ) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
        use libp2p::{
            core::upgrade::Version,
            tcp::{Config, tokio},
        };

        // `TCP_NODELAY` enabled => avoid latency
        let tcp_config = Config::default().nodelay(true);

        // V1 + lazy flushing => 0-RTT negotiation
        let upgrade_version = Version::V1Lazy;

        // Noise is faster than TLS + we don't care much for security
        let noise_config = noise::Config::new(keypair)?;

        // Use default Yamux config for multiplexing
        let yamux_config = yamux::Config::default();

        // Create new Tokio-driven TCP/IP transport layer
        let base_transport = tokio::Transport::new(tcp_config)
            .and_then(pnet_upgrade)
            .upgrade(upgrade_version)
            .authenticate(noise_config)
            .multiplex(yamux_config);

        // Return boxed transport (to flatten the complex type)
        Ok(base_transport.boxed())
    }
}

mod behaviour {
    use crate::{alias, discovery};
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{gossipsub, identity};

    /// Behaviour of the Swarm which composes all desired behaviours:
    /// right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        pub discovery: discovery::Behaviour,
        pub gossipsub: gossipsub::Behaviour,
    }

    impl Behaviour {
        pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
            Ok(Self {
                discovery: discovery::Behaviour::new(keypair)?,
                gossipsub: gossipsub_behaviour(keypair),
            })
        }
    }

    fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
        use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};

        // build a gossipsub network behaviour
        // => signed message authenticity + strict validation mode means the message-ID is
        //    automatically provided by gossipsub without needing a custom message-ID function
        gossipsub::Behaviour::new(
            MessageAuthenticity::Signed(keypair.clone()),
            ConfigBuilder::default()
                .max_transmit_size(1024 * 1024)
                .validation_mode(ValidationMode::Strict)
                .build()
                .expect("the configuration should always be valid"),
        )
        .expect("creating gossipsub behaviour should always work")
    }
}

@@ -1,7 +0,0 @@
// maybe this will hold tests in the future...??

#[cfg(test)]
mod tests {
    #[test]
    fn does_nothing() {}
}

@@ -1,10 +1,11 @@
{ inputs, ... }:
{
  perSystem =
    { inputs', pkgs, lib, ... }:
    { config, self', inputs', pkgs, lib, ... }:
    let
      # Fenix nightly toolchain with all components
      rustToolchain = inputs'.fenix.packages.stable.withComponents [
      fenixPkgs = inputs'.fenix.packages;
      rustToolchain = fenixPkgs.complete.withComponents [
        "cargo"
        "rustc"
        "clippy"

2
rust/rust-toolchain.toml
Normal file
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"

@@ -1,15 +0,0 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "util"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]

@@ -1 +0,0 @@
pub mod wakerdeque;

@@ -1,55 +0,0 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};

/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
    waker: Option<Waker>,
    deque: VecDeque<T>,
}

impl<T: Debug> Debug for WakerDeque<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.deque.fmt(f)
    }
}

impl<T> WakerDeque<T> {
    pub fn new() -> Self {
        Self {
            waker: None,
            deque: VecDeque::new(),
        }
    }

    fn update(&mut self, cx: &mut Context<'_>) {
        self.waker = Some(cx.waker().clone());
    }

    fn wake(&mut self) {
        let Some(ref mut w) = self.waker else { return };
        w.wake_by_ref();
        self.waker = None;
    }

    pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
        self.update(cx);
        self.deque.pop_front()
    }

    pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
        self.update(cx);
        self.deque.pop_back()
    }

    pub fn push_front(&mut self, value: T) {
        self.wake();
        self.deque.push_front(value);
    }

    pub fn push_back(&mut self, value: T) {
        self.wake();
        self.deque.push_back(value);
    }
}
|
||||
@@ -4,13 +4,10 @@ from collections.abc import Sequence

 from mlx_lm.models.cache import (
     ArraysCache,
-    CacheList,
     KVCache,
     QuantizedKVCache,
     RotatingKVCache,
 )

 # This list contains one cache entry per transformer layer
-KVCacheType = Sequence[
-    KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
-]
+KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]

@@ -1,4 +1,3 @@
-import contextlib
 import multiprocessing as mp
 from dataclasses import dataclass, field
 from math import inf
@@ -133,8 +132,7 @@ class MpSender[T]:
     def close(self) -> None:
         if not self._state.closed.is_set():
             self._state.closed.set()
-            with contextlib.suppress(Exception):
-                self._state.buffer.put_nowait(_MpEndOfStream())
+            self._state.buffer.put(_MpEndOfStream())
             self._state.buffer.close()

     # == unique to Mp channels ==
@@ -206,8 +204,6 @@ class MpReceiver[T]:
     def close(self) -> None:
         if not self._state.closed.is_set():
             self._state.closed.set()
-            with contextlib.suppress(Exception):
-                self._state.buffer.put_nowait(_MpEndOfStream())
             self._state.buffer.close()

     # == unique to Mp channels ==

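The close() convention above, as a standalone sketch in plain multiprocessing terms (the _EndOfStream sentinel mirrors the diff's _MpEndOfStream; close_channel is a hypothetical helper, not the repo's API): the closing side wakes any blocked reader with a sentinel before releasing the queue.

import multiprocessing as mp

class _EndOfStream:
    """Sentinel telling the receiving side that no more items will arrive."""

def close_channel(q) -> None:
    # Best-effort: push the sentinel so a blocked get() wakes up, then
    # close the queue to release its feeder thread and pipe.
    try:
        q.put_nowait(_EndOfStream())
    except Exception:
        pass
    q.close()

q = mp.Queue()
close_channel(q)
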
@@ -5,7 +5,6 @@ import mlx.core as mx
 import psutil
 from mlx_lm.models.cache import (
     ArraysCache,
-    CacheList,
     KVCache,
     QuantizedKVCache,
     RotatingKVCache,
@@ -18,22 +17,10 @@ from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
 from exo.worker.runner.bootstrap import logger


-# Fraction of device memory above which LRU eviction kicks in.
-# Smaller machines need more aggressive eviction.
-def _default_memory_threshold() -> float:
-    total_gb = psutil.virtual_memory().total / (1024**3)
-    if total_gb >= 128:
-        return 0.85
-    if total_gb >= 64:
-        return 0.80
-    if total_gb >= 32:
-        return 0.75
-    return 0.70
-
-
+# Fraction of device memory above which LRU eviction kicks in
+_DEFAULT_MEMORY_THRESHOLD = 0.9
 _MEMORY_THRESHOLD = float(
-    os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
+    os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
 )

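For concreteness, a minimal standalone sketch of the threshold resolution above (resolve_threshold is a hypothetical helper; the real module computes _MEMORY_THRESHOLD once at import time): the EXO_MEMORY_THRESHOLD environment variable, when set, overrides the default fraction.

import os

_DEFAULT_MEMORY_THRESHOLD = 0.9

def resolve_threshold() -> float:
    # Env var wins when present; otherwise fall back to the 0.9 default.
    return float(os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD))

assert resolve_threshold() == 0.9
os.environ["EXO_MEMORY_THRESHOLD"] = "0.75"
assert resolve_threshold() == 0.75
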
@@ -77,7 +64,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:


 class KVPrefixCache:
-    def __init__(self, group: mx.distributed.Group | None):
+    def __init__(self, group: mx.distributed.Group | None = None):
         self.prompts: list[mx.array] = []  # mx array of tokens (ints)
         self.caches: list[KVCacheType] = []
         self._snapshots: list[list[CacheSnapshot] | None] = []
@@ -169,15 +156,15 @@ class KVPrefixCache:
         best_length = 0
         is_exact = False

-        # Find best cache match
+        # Find best cache
         for i, cached_prompt in enumerate(self.prompts):
             length = get_prefix_length(prompt_tokens, cached_prompt)
-            if length >= max_length - 1:
-                best_index, best_length = i, length
-                is_exact = True
-                break
             if length > best_length:
                 best_index, best_length = i, length
+            if length == max_length:
+                is_exact = True
+                best_index, best_length = i, length
+                break

         if best_index is None:
             return make_kv_cache(model), prompt_tokens, None
@@ -185,12 +172,11 @@ class KVPrefixCache:
         # For exact match: trim to max_length-1 so remaining has the last token
         # For partial match: trim to best_length, remaining has suffix to prefill
        # This ensures stream_generate always has at least one token to start with
-        has_ssm = has_non_kv_caches(self.caches[best_index])
-        target = (max_length - 1) if is_exact and not has_ssm else best_length
+        target = (max_length - 1) if is_exact else best_length
         restore_pos, restore_snap = self._get_snapshot(best_index, target)

         # No usable snapshot — need fresh cache
-        if restore_snap is None and has_ssm:
+        if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
             return make_kv_cache(model), prompt_tokens, None

         prompt_cache = deepcopy(self.caches[best_index])
@@ -271,21 +257,10 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
     return mx.array(prompt_tokens)


-def _entry_length(
-    c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
-) -> int:
-    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
-    if hasattr(c, "offset"):
-        return c.offset
-    # For CacheList
-    if hasattr(c, "size"):
-        return int(c.size())  # type: ignore
-    return 0
-
-
 def cache_length(cache: KVCacheType) -> int:
     """Get the number of tokens in a KV cache."""
-    return max(_entry_length(c) for c in cache)
+    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
+    return max(getattr(c, "offset", 0) for c in cache)


 def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:

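Only the signature of get_prefix_length survives in this hunk; a minimal sketch of the common-prefix computation it presumably performs (an assumption, not the repo's verbatim body):

import mlx.core as mx

def get_prefix_length_sketch(prompt: mx.array, cached_prompt: mx.array) -> int:
    """Count how many leading tokens the two prompts share."""
    n = min(prompt.shape[0], cached_prompt.shape[0])
    if n == 0:
        return 0
    matches = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    # cumprod stays 1 only while every preceding token matched
    return int(mx.cumprod(matches).sum())

assert get_prefix_length_sketch(mx.array([1, 2, 3, 4]), mx.array([1, 2, 9])) == 2
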
@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger

 generation_stream = mx.new_stream(mx.default_device())

-_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
+_MIN_PREFIX_HIT_TO_UPDATE = 1000


 def prefill(
@@ -57,7 +57,7 @@ def prefill(
     sampler: Callable[[mx.array], mx.array],
     prompt_tokens: mx.array,
     cache: KVCacheType,
-    group: mx.distributed.Group | None,
+    group: mx.distributed.Group | None = None,
 ) -> tuple[float, int, list[CacheSnapshot]]:
     """Prefill the KV cache with prompt tokens.

@@ -133,7 +133,7 @@ def prefill(
 def warmup_inference(
     model: Model,
     tokenizer: TokenizerWrapper,
-    group: mx.distributed.Group | None,
+    group: mx.distributed.Group | None = None,
 ) -> int:
     content = "Prompt to warm up the inference engine. Repeat this."

@@ -255,8 +255,8 @@ def mlx_generate(
     tokenizer: TokenizerWrapper,
     task: TextGenerationTaskParams,
     prompt: str,
-    kv_prefix_cache: KVPrefixCache | None,
-    group: mx.distributed.Group | None,
+    kv_prefix_cache: KVPrefixCache | None = None,
+    group: mx.distributed.Group | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -436,14 +436,9 @@ def mlx_generate(
                 full_prompt_tokens = mx.concatenate(
                     [all_prompt_tokens, generated_tokens_array]
                 )
-                hit_ratio = (
-                    prefix_hit_length / len(all_prompt_tokens)
-                    if len(all_prompt_tokens) > 0
-                    else 0.0
-                )
                 if (
                     matched_index is not None
-                    and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
+                    and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
                 ):
                     kv_prefix_cache.update_kv_cache(
                         matched_index,

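To make the new in-place update rule concrete, a small sketch (should_update_cache is a hypothetical helper mirroring the condition in mlx_generate): an entry is rewritten only when the generation actually matched a stored entry and the shared prefix covered at least _MIN_PREFIX_HIT_TO_UPDATE tokens.

_MIN_PREFIX_HIT_TO_UPDATE = 1000

def should_update_cache(matched_index: int | None, prefix_hit_length: int) -> bool:
    # Require a real match and a long enough hit before touching the entry.
    return matched_index is not None and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE

assert should_update_cache(0, 1500)
assert not should_update_cache(0, 999)
assert not should_update_cache(None, 5000)
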
@@ -292,8 +292,6 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
     elif "glm" in model_id_lower:
         # For GLM-4.5 and older
         return [151336, 151329, 151338]
-    elif "gpt-oss" in model_id_lower:
-        return [200002, 200012]
     return None

@@ -11,7 +11,6 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
     HarmonyEncodingName,
-    HarmonyError,  # pyright: ignore[reportUnknownVariableType]
     Role,
     StreamableParser,
     load_harmony_encoding,
@@ -589,11 +588,7 @@ def parse_gpt_oss(

     for response in responses:
         assert isinstance(response, GenerationResponse)
-        try:
-            stream.process(response.token)
-        except HarmonyError:
-            logger.error("Encountered critical Harmony Error, returning early")
-            return
+        stream.process(response.token)

         delta = stream.last_content_delta
         ch = stream.current_channel

@@ -103,7 +103,7 @@ class RunnerSupervisor:
         self._event_sender.close()
         self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
         self._cancel_sender.close()
-        self.runner_process.join(5)
+        self.runner_process.join(1)
         if not self.runner_process.is_alive():
             logger.info("Runner process successfully terminated")
             return

@@ -123,12 +123,7 @@ def run_gpt_oss_pipeline_device(
     generated_text = ""

     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
-        prompt=prompt,
-        kv_prefix_cache=None,
-        group=group,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:
@@ -199,8 +194,6 @@ def run_gpt_oss_tensor_parallel_device(
         tokenizer=tokenizer,
         task=task,
         prompt=prompt,
-        kv_prefix_cache=None,
-        group=group,
     ):
         generated_text += response.text
         if response.finish_reason is not None:

@@ -88,12 +88,12 @@ class TestKVPrefix:
         return tokenizer

     def test_starts_empty(self, mock_tokenizer):
-        cache = KVPrefixCache(None)
+        cache = KVPrefixCache()
         assert len(cache.prompts) == 0
         assert len(cache.caches) == 0

     def test_clear_empties_cache(self, mock_tokenizer):
-        cache = KVPrefixCache(None)
+        cache = KVPrefixCache()
         cache.prompts.append(mx.array([1, 2, 3]))
         cache.caches.append([KVCache()])
         cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
         assert len(cache.caches) == 0

     def test_clear_on_empty_cache(self, mock_tokenizer):
-        cache = KVPrefixCache(None)
+        cache = KVPrefixCache()
         cache.clear()
         assert len(cache.prompts) == 0

@@ -142,9 +142,7 @@ class TestKVPrefixCacheWithModel:
         tokens = encode_prompt(tokenizer, prompt)
         cache = make_kv_cache(model)

-        _, _, snapshots = prefill(
-            model, tokenizer, make_sampler(0.0), tokens, cache, group=None
-        )
+        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

         # Cache should now hold the prompt tokens minus one
         assert cache_length(cache) == len(tokens) - 1
@@ -163,11 +161,9 @@ class TestKVPrefixCacheWithModel:
         tokens = encode_prompt(tokenizer, prompt)
         cache = make_kv_cache(model)

-        _, _, snapshots = prefill(
-            model, tokenizer, make_sampler(0.0), tokens, cache, group=None
-        )
+        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

         assert len(kv_prefix_cache.prompts) == 1
@@ -180,11 +176,9 @@ class TestKVPrefixCacheWithModel:
         )
         assert matched_index == 0

-        # Exact match returns last token(s) — for models with SSM/rotating caches,
-        # snapshot availability constrains how far back we can trim, so remaining
-        # may be 1 or 2 tokens depending on the model.
-        assert len(remaining_tokens) >= 1
-        assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
+        # Exact match returns only last token
+        assert len(remaining_tokens) == 1
+        assert mx.array_equal(remaining_tokens, tokens[-1:])

     def test_add_and_get_prefix_match(self, model_and_tokenizer):
         """get_kv_cache with a longer prompt sharing prefix should return partial match."""
@@ -200,10 +194,10 @@ class TestKVPrefixCacheWithModel:
         cache = make_kv_cache(model)

         _, _, snapshots = prefill(
-            model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
+            model, tokenizer, make_sampler(0.0), short_tokens, cache
         )

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)

         # Query with longer prompt that shares the chat template prefix
@@ -244,11 +238,9 @@ class TestKVPrefixCacheWithModel:
         tokens = encode_prompt(tokenizer, prompt)
         cache = make_kv_cache(model)

-        _, _, snapshots = prefill(
-            model, tokenizer, make_sampler(0.0), tokens, cache, group=None
-        )
+        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

         stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -284,11 +276,9 @@ class TestKVPrefixCacheWithModel:
         tokens = encode_prompt(tokenizer, prompt)
         cache = make_kv_cache(model)

-        _, _, snapshots = prefill(
-            model, tokenizer, make_sampler(0.0), tokens, cache, group=None
-        )
+        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

         stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -311,7 +301,7 @@ class TestKVPrefixCacheWithModel:
         """mlx_generate should save the cache after generation completes."""
         model, tokenizer = model_and_tokenizer

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         task = TextGenerationTaskParams(
             model=DEFAULT_GPT_OSS_MODEL_ID,
             input=[InputMessage(role="user", content="Hello")],
@@ -328,7 +318,6 @@ class TestKVPrefixCacheWithModel:
             task=task,
             prompt=prompt,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             generated_tokens += 1

@@ -342,7 +331,7 @@ class TestKVPrefixCacheWithModel:
         """Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
         model, tokenizer = model_and_tokenizer

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         task = TextGenerationTaskParams(
             model=DEFAULT_GPT_OSS_MODEL_ID,
             input=[InputMessage(role="user", content="Reuse test")],
@@ -358,7 +347,6 @@ class TestKVPrefixCacheWithModel:
             task=task,
             prompt=prompt,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             pass

@@ -380,7 +368,7 @@ class TestKVPrefixCacheWithModel:
         """With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
         model, tokenizer = model_and_tokenizer

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()

         # Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
         base_text = "The quick brown fox jumps over the lazy dog. "
@@ -407,7 +395,6 @@ class TestKVPrefixCacheWithModel:
             task=task1,
             prompt=prompt1,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             pass
         first_gen_time = time.perf_counter() - t0
@@ -440,7 +427,6 @@ class TestKVPrefixCacheWithModel:
             task=task2,
             prompt=prompt2,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             pass
         second_gen_time = time.perf_counter() - t0
@@ -461,7 +447,7 @@ class TestKVPrefixCacheWithModel:
         """After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
         model, tokenizer = model_and_tokenizer

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()
         task = TextGenerationTaskParams(
             model=DEFAULT_GPT_OSS_MODEL_ID,
             input=[InputMessage(role="user", content="Immutable test")],
@@ -476,7 +462,6 @@ class TestKVPrefixCacheWithModel:
             task=task,
             prompt=prompt,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             pass

@@ -489,7 +474,6 @@ class TestKVPrefixCacheWithModel:
             task=task,
             prompt=prompt,
             kv_prefix_cache=kv_prefix_cache,
-            group=None,
         ):
             pass

@@ -500,7 +484,7 @@ class TestKVPrefixCacheWithModel:
         """Under memory pressure, adding a new cache entry evicts the least recently used one."""
         model, tokenizer = model_and_tokenizer

-        kv_prefix_cache = KVPrefixCache(None)
+        kv_prefix_cache = KVPrefixCache()

         # Add three cache entries with different prompts
         prompts = ["First entry", "Second entry", "Third entry"]
@@ -513,7 +497,7 @@ class TestKVPrefixCacheWithModel:
             prompt = apply_chat_template(tokenizer, task)
             tokens = encode_prompt(tokenizer, prompt)
             cache = make_kv_cache(model)
-            prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
+            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
             kv_prefix_cache.add_kv_cache(tokens, cache)
             # Stagger _last_used so LRU order is deterministic
             kv_prefix_cache._last_used[i] = float(i)
@@ -538,7 +522,7 @@ class TestKVPrefixCacheWithModel:
             prompt = apply_chat_template(tokenizer, task)
             tokens = encode_prompt(tokenizer, prompt)
             cache = make_kv_cache(model)
-            prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
+            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
             kv_prefix_cache.add_kv_cache(tokens, cache)

         # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)

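As background for the eviction test above, a minimal sketch of the LRU bookkeeping it exercises (standalone and hypothetical; the real logic lives inside KVPrefixCache and its _last_used map):

import time

_last_used: dict[int, float] = {0: 0.0, 1: 1.0, 2: 2.0}

def touch(index: int) -> None:
    # Called on every cache hit so the entry moves to the back of the LRU order.
    _last_used[index] = time.monotonic()

def evict_lru() -> int:
    # Drop the entry with the smallest last-used stamp.
    victim = min(_last_used, key=_last_used.__getitem__)
    del _last_used[victim]
    return victim

assert evict_lru() == 0  # entry 0 was least recently used
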
@@ -1,297 +0,0 @@
-import copy
-import gc
-import importlib
-import json
-import shutil
-import tempfile
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, cast
-
-import mlx.core as mx
-import mlx.nn as nn
-import pytest
-from mlx.utils import tree_flatten, tree_unflatten
-from mlx_lm.tokenizer_utils import TokenizerWrapper
-
-from exo.shared.types.common import ModelId
-from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
-from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.cache import KVPrefixCache
-from exo.worker.engines.mlx.generator.generate import mlx_generate
-from exo.worker.engines.mlx.utils_mlx import (
-    apply_chat_template,
-    load_tokenizer_for_model_id,
-)
-
-HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
-
-# ── Config reduction ──────────────────────────────────────────────────────── #
-
-_REDUCE = {
-    "num_hidden_layers": 4,
-    "hidden_size": 256,
-    "num_attention_heads": 4,
-    "num_key_value_heads": 4,
-    "intermediate_size": 512,
-    "moe_intermediate_size": 128,
-    "num_experts": 4,
-    "num_experts_per_tok": 2,
-    "n_routed_experts": 4,
-    "num_local_experts": 4,
-    "num_nextn_predict_layers": 0,
-    "first_k_dense_replace": 0,
-    "linear_num_key_heads": 2,
-    "linear_num_value_heads": 2,
-    "num_attention_groups": 4,
-}
-
-
-def _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:
-    result = dict(cfg)
-    for key, val in _REDUCE.items():
-        if key in result:
-            result[key] = val
-    return result
-
-
-def _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:
-    result = _reduce_dict(cfg)
-    n_layers = cast(int, result.get("num_hidden_layers", 4))
-
-    if "text_config" in result and isinstance(result["text_config"], dict):
-        result["text_config"] = _reduce_dict(
-            cast(dict[str, Any], result["text_config"])
-        )
-        tc: dict[str, Any] = result["text_config"]
-        if "num_nextn_predict_layers" in tc:
-            tc["num_nextn_predict_layers"] = 0
-
-    if "layer_types" in result and isinstance(result["layer_types"], list):
-        result["layer_types"] = result["layer_types"][:n_layers]
-
-    if "attention_other_setting" in result and isinstance(
-        result["attention_other_setting"], dict
-    ):
-        aos: dict[str, Any] = dict(
-            cast(dict[str, Any], result["attention_other_setting"])
-        )
-        if "num_attention_heads" in aos:
-            aos["num_attention_heads"] = result.get("num_attention_heads", 4)
-        if "num_attention_groups" in aos:
-            aos["num_attention_groups"] = result.get(
-                "num_attention_groups", cast(int, aos["num_attention_groups"])
-            )
-        result["attention_other_setting"] = aos
-
-    if "moe_layers_enum" in result and isinstance(result["moe_layers_enum"], str):
-        indices = [int(x) for x in result["moe_layers_enum"].split(",") if x.strip()]
-        valid = [i for i in indices if i < n_layers]
-        result["moe_layers_enum"] = ",".join(str(i) for i in valid) if valid else ""
-
-    return result
-
-
-# ── Helpers ───────────────────────────────────────────────────────────────── #
-
-
-def _find_snapshot(hub_name: str) -> Path | None:
-    model_dir = HF_CACHE / f"models--mlx-community--{hub_name}"
-    snaps = model_dir / "snapshots"
-    if not snaps.exists():
-        return None
-    children = sorted(snaps.iterdir())
-    return children[0] if children else None
-
-
-def _copy_tokenizer(src: Path, dst: Path) -> None:
-    for f in src.iterdir():
-        name = f.name
-        if (
-            "tokeniz" in name.lower()
-            or "tiktoken" in name.lower()
-            or name.startswith("vocab")
-            or name.endswith(".jinja")
-            or "tool_declaration" in name
-        ) and f.is_file():
-            shutil.copy2(f, dst / name)
-
-
-def _build_model(module_name: str, cfg: dict[str, Any]) -> Model:
-    mod = importlib.import_module(f"mlx_lm.models.{module_name}")
-    args = mod.ModelArgs.from_dict(cfg)  # pyright: ignore[reportAny]
-    model: nn.Module = mod.Model(args)  # pyright: ignore[reportAny]
-    flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))
-    random_weights = [
-        (k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat
-    ]
-    model.update(cast(dict[str, Any], tree_unflatten(random_weights)))
-    mx.eval(model.parameters())
-    return cast(Model, model)
-
-
-def _collect_tokens(
-    model: Model,
-    tokenizer: TokenizerWrapper,
-    task: TextGenerationTaskParams,
-    prompt: str,
-    kv_prefix_cache: KVPrefixCache | None,
-) -> list[int]:
-    tokens: list[int] = []
-    for resp in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
-        prompt=prompt,
-        kv_prefix_cache=kv_prefix_cache,
-        group=None,
-    ):
-        tokens.append(resp.token)
-        if resp.finish_reason is not None:
-            break
-    return tokens
-
-
-# ── Architecture definitions ──────────────────────────────────────────────── #
-
-
-@dataclass(frozen=True)
-class ArchSpec:
-    name: str
-    hub_name: str
-    module: str
-    tokenizer_hub: str | None = None  # fallback for models without bundled tokenizer
-
-
-ARCHITECTURES: list[ArchSpec] = [
-    ArchSpec("llama", "Llama-3.2-1B-Instruct-4bit", "llama"),
-    ArchSpec("glm_moe_dsa", "GLM-5-MXFP4-Q8", "glm_moe_dsa"),
-    ArchSpec(
-        "glm4_moe", "GLM-4.5-Air-8bit", "glm4_moe", tokenizer_hub="GLM-4.7-8bit-gs32"
-    ),
-    ArchSpec(
-        "glm4_moe_lite",
-        "GLM-4.7-Flash-8bit",
-        "glm4_moe_lite",
-        tokenizer_hub="GLM-4.7-8bit-gs32",
-    ),
-    ArchSpec("glm4_moe_47", "GLM-4.7-8bit-gs32", "glm4_moe"),
-    ArchSpec("qwen3", "Qwen3-4B-Instruct-2507-4bit", "qwen3"),
-    ArchSpec("qwen3_moe", "Qwen3-30B-A3B-4bit", "qwen3_moe"),
-    ArchSpec("qwen3_next", "Qwen3-Next-80B-A3B-Thinking-4bit", "qwen3_next"),
-    ArchSpec("minimax", "MiniMax-M2.1-3bit", "minimax"),
-    ArchSpec("gpt_oss", "gpt-oss-20b-MXFP4-Q8", "gpt_oss"),
-    ArchSpec("step3p5", "Step-3.5-Flash-4bit", "step3p5"),
-    ArchSpec("kimi_k25", "Kimi-K2.5", "kimi_k25"),
-]
-
-
-def _arch_available(spec: ArchSpec) -> bool:
-    snap = _find_snapshot(spec.hub_name)
-    if snap is None:
-        return False
-    if spec.tokenizer_hub is not None:
-        return _find_snapshot(spec.tokenizer_hub) is not None
-    return True
-
-
-def _make_task() -> TextGenerationTaskParams:
-    return TextGenerationTaskParams(
-        model=ModelId("test"),
-        input=[
-            InputMessage(
-                role="user",
-                content="Use the calculator to compute 1847 * 263 + 5921",
-            )
-        ],
-        max_output_tokens=20,
-        temperature=0.0,
-        tools=[
-            {
-                "type": "function",
-                "function": {
-                    "name": "calculate",
-                    "description": "Evaluate a mathematical expression",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {"expression": {"type": "string"}},
-                        "required": ["expression"],
-                    },
-                },
-            }
-        ],
-    )
-
-
-# ── Test class ────────────────────────────────────────────────────────────── #
-
-
-@pytest.mark.slow
-class TestPrefixCacheArchitectures:
-    """Verify prefix cache produces identical output to fresh generation for every architecture."""
-
-    @pytest.fixture(autouse=True)
-    def _cleanup(self):
-        yield
-        mx.clear_cache()
-        gc.collect()
-
-    @pytest.mark.parametrize(
-        "spec",
-        ARCHITECTURES,
-        ids=[a.name for a in ARCHITECTURES],
-    )
-    def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:
-        if not _arch_available(spec):
-            pytest.skip(f"Model {spec.hub_name} not cached locally")
-
-        snapshot = _find_snapshot(spec.hub_name)
-        assert snapshot is not None
-
-        tmpdir = Path(tempfile.mkdtemp(prefix=f"exo_test_{spec.name}_"))
-        try:
-            # Build reduced config
-            with open(snapshot / "config.json") as f:
-                cfg = cast(dict[str, Any], json.load(f))
-            reduced = _reduce_config(copy.deepcopy(cfg))
-            (tmpdir / "config.json").write_text(json.dumps(reduced))
-
-            # Copy tokenizer
-            tok_src = snapshot
-            if spec.tokenizer_hub is not None:
-                alt = _find_snapshot(spec.tokenizer_hub)
-                if alt is not None:
-                    tok_src = alt
-            _copy_tokenizer(tok_src, tmpdir)
-
-            # Load tokenizer and model
-            model_id = ModelId(f"mlx-community/{spec.hub_name}")
-            tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)
-            mx.random.seed(0)
-            model = _build_model(spec.module, reduced)
-
-            task = _make_task()
-            prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)
-
-            # Run 1: fresh
-            mx.random.seed(42)
-            fresh = _collect_tokens(model, tokenizer, task, prompt, None)
-            assert len(fresh) > 0, "Fresh generation produced no tokens"
-
-            # Run 2: populate cache
-            kv = KVPrefixCache(None)
-            mx.random.seed(42)
-            populate = _collect_tokens(model, tokenizer, task, prompt, kv)
-
-            # Run 3: exact cache hit
-            mx.random.seed(42)
-            cached = _collect_tokens(model, tokenizer, task, prompt, kv)
-
-            assert fresh == populate, (
-                f"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}"
-            )
-            assert fresh == cached, (
-                f"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}"
-            )
-        finally:
-            shutil.rmtree(tmpdir, ignore_errors=True)
@@ -343,16 +343,8 @@ async def test_kimi_tokenizer_specifically():
 @pytest.mark.asyncio
 async def test_glm_tokenizer_specifically():
     """Test GLM tokenizer with its specific EOS tokens."""
-
-    def contains(card: ModelCard, x: str):
-        return x in card.model_id.lower()
-
     glm_model_cards = [
-        card
-        for card in await get_model_cards()
-        if contains(card, "glm")
-        and not contains(card, "-5")
-        and not contains(card, "4.7")
+        card for card in await get_model_cards() if "glm" in card.model_id.lower()
     ]

     if not glm_model_cards:

@@ -1,162 +0,0 @@
-from collections.abc import Generator
-
-from exo.shared.types.worker.runner_response import (
-    GenerationResponse,
-    ToolCallResponse,
-)
-from exo.worker.runner.runner import parse_gpt_oss
-
-# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
-# These are stable since they come from the model's vocabulary.
-_CHANNEL = 200005  # <|channel|>
-_START = 200006  # <|start|>
-_MESSAGE = 200008  # <|message|>
-_CALL = 200012  # <|call|>
-_END = 200007  # <|end|>
-_ASSISTANT = 173781  # "assistant"
-
-# fmt: off
-# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
-FORMAT_A_TOKENS: list[tuple[int, str]] = [
-    (316, " to"),
-    (28, "="),
-    (44580, "functions"),
-    (775, ".get"),
-    (23981, "_current"),
-    (170154, "_weather"),
-    (_CHANNEL, "<|channel|>"),
-    (12606, "comment"),
-    (815, "ary"),
-    (5701, " json"),
-    (_MESSAGE, "<|message|>"),
-    (10848, '{"'),
-    (7693, "location"),
-    (1243, '":'),
-    (392, ' "'),
-    (173844, "Tokyo"),
-    (18583, '"}'),
-    (_CALL, "<|call|>"),
-]
-
-# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
-FORMAT_B_TOKENS: list[tuple[int, str]] = [
-    (_CHANNEL, "<|channel|>"),
-    (12606, "comment"),
-    (815, "ary"),
-    (316, " to"),
-    (28, "="),
-    (44580, "functions"),
-    (775, ".get"),
-    (23981, "_current"),
-    (170154, "_weather"),
-    (5701, " json"),
-    (_MESSAGE, "<|message|>"),
-    (10848, '{"'),
-    (7693, "location"),
-    (1243, '":'),
-    (392, ' "'),
-    (173844, "Tokyo"),
-    (18583, '"}'),
-    (_CALL, "<|call|>"),
-]
-
-# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
-# Full analysis-then-tool-call as the model actually generates it.
-THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
-    (_CHANNEL, "<|channel|>"),
-    (35644, "analysis"),
-    (_MESSAGE, "<|message|>"),
-    (12845, "Let"),
-    (668, " me"),
-    (2411, " think"),
-    (1078, " about"),
-    (495, " this"),
-    (13, "."),
-    (_END, "<|end|>"),
-    # Model generates a new message header for the tool call:
-    (_START, "<|start|>"),
-    (_ASSISTANT, "assistant"),
-    *FORMAT_B_TOKENS,
-]
-# fmt: on
-
-
-def _make_gen_responses(
-    tokens: list[tuple[int, str]],
-) -> list[GenerationResponse]:
-    """Build GenerationResponse list from (token_id, text) pairs."""
-    responses: list[GenerationResponse] = []
-    for i, (tid, text) in enumerate(tokens):
-        is_last = i == len(tokens) - 1
-        responses.append(
-            GenerationResponse(
-                text=text,
-                token=tid,
-                finish_reason="stop" if is_last else None,
-                usage=None,
-            )
-        )
-    return responses
-
-
-def _collect(
-    tokens: list[tuple[int, str]],
-) -> list[GenerationResponse | ToolCallResponse]:
-    """Feed tokens through parse_gpt_oss and collect all yielded responses."""
-
-    def _gen() -> Generator[GenerationResponse, None, None]:
-        yield from _make_gen_responses(tokens)
-
-    return list(parse_gpt_oss(_gen()))
-
-
-def _get_tool_call(
-    results: list[GenerationResponse | ToolCallResponse],
-) -> ToolCallResponse:
-    """Extract the single ToolCallResponse from results."""
-    tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
-    assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
-    return tool_calls[0]
-
-
-class TestParseGptOssRecipientPlacement:
-    """Both Harmony recipient placements must produce identical tool calls."""
-
-    def test_format_a_yields_tool_call(self):
-        results = _collect(FORMAT_A_TOKENS)
-        tc = _get_tool_call(results)
-        assert tc.tool_calls[0].name == "get_current_weather"
-        assert '"location"' in tc.tool_calls[0].arguments
-        assert "Tokyo" in tc.tool_calls[0].arguments
-
-    def test_format_b_yields_tool_call(self):
-        results = _collect(FORMAT_B_TOKENS)
-        tc = _get_tool_call(results)
-        assert tc.tool_calls[0].name == "get_current_weather"
-        assert '"location"' in tc.tool_calls[0].arguments
-        assert "Tokyo" in tc.tool_calls[0].arguments
-
-    def test_both_formats_produce_identical_tool_calls(self):
-        tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
-        tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
-        assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
-        assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
-
-
-class TestParseGptOssThinkingThenToolCall:
-    """Analysis (thinking) followed by a tool call must yield both."""
-
-    def test_thinking_then_tool_call(self):
-        results = _collect(THINKING_THEN_TOOL_TOKENS)
-
-        # Should have thinking tags + content + tool call
-        text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
-        combined = "".join(text_parts)
-        assert "<think>" in combined
-        assert "</think>" in combined
-        assert "Let me think about this." in combined
-
-        # And the tool call
-        tc = _get_tool_call(results)
-        assert tc.tool_calls[0].name == "get_current_weather"
-        assert "Tokyo" in tc.tool_calls[0].arguments
@@ -1,55 +0,0 @@
-#!/usr/bin/env bash
-
-[ $# -lt 1 ] && {
-    echo "Usage: $0 host1 [host2 ...]"
-    exit 1
-}
-
-[ -z "$(git status --porcelain)" ] || {
-    echo "Uncommitted changes"
-    exit 1
-}
-
-commit=$(git rev-parse HEAD)
-git fetch -q origin
-git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
-    echo "Not pushed to origin"
-    exit 1
-}
-hosts=("$@")
-cleanup() {
-    for host in "${hosts[@]}"; do
-        ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
-    done
-    sleep 1
-    jobs -pr | xargs -r kill 2>/dev/null || true
-}
-trap 'cleanup' EXIT INT TERM
-
-for host; do
-    ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
-        "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
-done
-wait
-for host; do
-    ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
-        "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
-done
-
-for host; do
-    echo "Waiting for $host..." 1>&2
-    until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
-done
-
-echo "Waiting 30s for cluster setup" 1>&2
-sleep 30
-echo "EXO loaded" 1>&2
-eval_runner="${hosts[0]}"
-mkdir -p "./bench/$commit"
-nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
-    echo "running eval for $model" 1>&2
-    ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
-        "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
-        >>"./bench/$commit/${model//\//--}-eval.json"
-    echo
-done
uv.lock (generated)
@@ -447,7 +447,6 @@ name = "exo-bench"
 version = "0.1.0"
 source = { editable = "bench" }
 dependencies = [
-    { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -457,7 +456,6 @@ dependencies = [

 [package.metadata]
 requires-dist = [
-    { name = "httpx", specifier = ">=0.27.0" },
     { name = "huggingface-hub", specifier = ">=0.33.4" },
     { name = "jinja2", specifier = ">=3.1.0" },
     { name = "loguru", specifier = ">=0.7.3" },