mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
Compare commits
1 Commits
leo/fix-to
...
fix/mp-rec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faac462d9a |
@@ -1,46 +0,0 @@
|
||||
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .deepseek_v32 import Model as DSV32Model
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
index_head_dim: int
|
||||
index_n_heads: int
|
||||
index_topk: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: int
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_parameters: Dict[str, Any]
|
||||
attention_bias: bool
|
||||
rope_scaling: Dict[str, Any] | None
|
||||
rope_theta: float | None
|
||||
|
||||
class Model(DSV32Model):
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
123
Cargo.lock
generated
123
Cargo.lock
generated
@@ -141,6 +141,12 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
@@ -298,6 +304,19 @@ version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bimap"
|
||||
version = "0.6.3"
|
||||
@@ -497,6 +516,15 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -718,6 +746,29 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc_version",
|
||||
"syn 2.0.111",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
"extend",
|
||||
"futures",
|
||||
"impl-trait-for-tuples",
|
||||
"libp2p",
|
||||
"log",
|
||||
"networking",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"thiserror 2.0.17",
|
||||
"thread_local",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
@@ -1584,6 +1640,17 @@ dependencies = [
|
||||
"xmltree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "impl-trait-for-tuples"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
@@ -1762,6 +1829,12 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "libp2p"
|
||||
version = "0.56.0"
|
||||
@@ -2751,13 +2824,16 @@ name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
@@ -2842,6 +2918,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -3192,14 +3279,28 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"either",
|
||||
"hashbrown 0.16.1",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"memoffset",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"rust_decimal",
|
||||
"smallvec",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
@@ -3640,6 +3741,16 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust_decimal"
|
||||
version = "1.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
@@ -4504,12 +4615,24 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "1.3.0"
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -26,21 +26,49 @@ opt-level = 3
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
11
README.md
11
README.md
@@ -72,23 +72,16 @@ There are two ways to run exo:
|
||||
|
||||
### Run from Source (macOS)
|
||||
|
||||
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
|
||||
|
||||
```bash
|
||||
nix run .#exo
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
```
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard)
|
||||
|
||||
|
||||
```bash
|
||||
brew install uv macmon node
|
||||
```
|
||||
|
||||
@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
|
||||
return
|
||||
}
|
||||
process.terminationHandler = nil
|
||||
status = .stopped
|
||||
|
||||
guard process.isRunning else {
|
||||
self.process = nil
|
||||
return
|
||||
if process.isRunning {
|
||||
process.terminate()
|
||||
}
|
||||
|
||||
let proc = process
|
||||
self.process = nil
|
||||
|
||||
Task.detached {
|
||||
proc.interrupt()
|
||||
|
||||
for _ in 0..<50 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
proc.terminate()
|
||||
}
|
||||
|
||||
for _ in 0..<30 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
kill(proc.processIdentifier, SIGKILL)
|
||||
}
|
||||
}
|
||||
status = .stopped
|
||||
}
|
||||
|
||||
func restart() {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,29 +4,26 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from harness import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
add_common_instance_args,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Backoff constants for cluster settling retry
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
@@ -106,6 +103,154 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
@@ -124,6 +269,184 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
return items
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("name").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
@@ -215,12 +538,76 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
add_common_instance_args(ap)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
@@ -233,6 +620,34 @@ def main() -> int:
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
@@ -242,6 +657,9 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
@@ -256,6 +674,17 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -290,10 +719,24 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
@@ -317,6 +760,16 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
327
bench/harness.py
327
bench/harness.py
@@ -1,327 +0,0 @@
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
raise
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if (m.get("name") or "").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def settle_and_fetch_placements(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
args: argparse.Namespace,
|
||||
settle_timeout: float = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
@@ -4,7 +4,6 @@ version = "0.1.0"
|
||||
description = "Benchmarking tool for exo distributed inference"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"loguru>=0.7.3",
|
||||
"transformers>=5.0.0",
|
||||
"huggingface-hub>=0.33.4",
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
# Tool definitions — each becomes an OpenAI function tool.
|
||||
# All scenarios get all tools unless they specify a `tools` list.
|
||||
|
||||
[tools.get_current_weather]
|
||||
description = "Get the current weather in a given location"
|
||||
required = ["location"]
|
||||
|
||||
[tools.get_current_weather.properties.location]
|
||||
type = "string"
|
||||
description = "City and state, e.g. San Francisco, CA"
|
||||
|
||||
[tools.get_current_weather.properties.unit]
|
||||
type = "string"
|
||||
enum = ["celsius", "fahrenheit"]
|
||||
description = "Temperature unit"
|
||||
|
||||
[tools.calculate]
|
||||
description = "Evaluate a mathematical expression and return the numeric result"
|
||||
required = ["expression"]
|
||||
|
||||
[tools.calculate.properties.expression]
|
||||
type = "string"
|
||||
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
|
||||
|
||||
[tools.search_products]
|
||||
description = "Search for products in a catalog by query, category, and price"
|
||||
required = ["query"]
|
||||
|
||||
[tools.search_products.properties.query]
|
||||
type = "string"
|
||||
description = "Search query string"
|
||||
|
||||
[tools.search_products.properties.category]
|
||||
type = "string"
|
||||
enum = ["electronics", "clothing", "food", "books"]
|
||||
description = "Product category to filter by"
|
||||
|
||||
[tools.search_products.properties.max_price]
|
||||
type = "number"
|
||||
description = "Maximum price in USD"
|
||||
|
||||
# -- Should call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_simple"
|
||||
description = "Basic weather query -> get_current_weather"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather like in Tokyo right now?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_simple"
|
||||
description = "Math question -> calculate"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 3847 * 926 + 17293"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_with_filters"
|
||||
description = "Product search with category and price filter"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Find me electronics under $50"
|
||||
|
||||
# -- Multi-turn: tool call then follow-up --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_multi_turn"
|
||||
description = "Weather query -> tool result -> natural language summary"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
temperature = "18C"
|
||||
condition = "partly cloudy"
|
||||
humidity = "65%"
|
||||
wind = "12 km/h NW"
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Paris?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_multi_turn"
|
||||
description = "Math query -> tool result -> model reports the answer"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
result = 491682
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 1847 * 263 + 5921"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_multi_turn"
|
||||
description = "Search query -> tool result -> model summarizes products"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Hands-On Machine Learning"
|
||||
price = 45.99
|
||||
rating = 4.8
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Deep Learning with Python"
|
||||
price = 39.99
|
||||
rating = 4.6
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Search for books about machine learning"
|
||||
|
||||
# -- Sequential tool calls --
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_same"
|
||||
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare the weather in Tokyo and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check both cities. Let me start with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_1"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_1"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_different"
|
||||
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll handle both. Let me check Berlin's weather first."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_2"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Berlin" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_2"
|
||||
content = '{"temperature": "12C", "condition": "rainy"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_three"
|
||||
description = "Two prior thinking+tool calls -> results -> model must make a third"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare weather in Tokyo, Paris, and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check all three cities. Starting with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_3"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_3"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "Got Tokyo. Now checking Paris."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_4"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Paris" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_4"
|
||||
content = '{"temperature": "18C", "condition": "cloudy"}'
|
||||
|
||||
# -- Should NOT call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_joke"
|
||||
description = "Joke request should NOT trigger any tool"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Tell me a funny joke about cats."
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_factual"
|
||||
description = "Factual question answerable from training data"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What is the capital of Japan?"
|
||||
@@ -103,7 +103,7 @@
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking_toggle") && caps.includes("text");
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
|
||||
@@ -59,14 +59,13 @@
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -177,90 +176,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size.toUpperCase()}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -310,7 +311,7 @@
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition()
|
||||
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
|
||||
@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "auto"
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1536"
|
||||
| "1536x1024";
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -337,7 +336,7 @@ export interface EditingImage {
|
||||
}
|
||||
|
||||
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "auto",
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
|
||||
@@ -115,7 +115,7 @@
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
|
||||
10
nix/mlx.nix
10
nix/mlx.nix
@@ -41,16 +41,16 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260218+14841977"; in
|
||||
version = let v = "0.30.6"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
|
||||
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.7",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -64,7 +64,6 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
|
||||
@@ -58,21 +58,6 @@
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
@@ -89,25 +74,14 @@
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
@@ -158,7 +132,6 @@
|
||||
exo-test-env = testVenv;
|
||||
} // {
|
||||
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "4bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "8bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "5bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 790517400864
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405478939008
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2.5"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "3bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
|
||||
@@ -25,17 +25,17 @@ workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
@@ -45,6 +45,8 @@ pyo3-log = "0.13.2"
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
@@ -52,11 +54,24 @@ tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
#tracing-subscriber = "0.3"
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -6,7 +6,7 @@ use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
pin::{Pin, pin},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
@@ -33,6 +33,8 @@ where
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
@@ -24,6 +25,7 @@ use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -34,10 +36,14 @@ pub(crate) mod r#const {
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
@@ -45,6 +51,7 @@ pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
@@ -55,7 +62,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +98,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -168,6 +175,24 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
|
||||
@@ -11,9 +11,9 @@ use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt a
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{exceptions::PyException, prelude::*};
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
@@ -155,6 +155,7 @@ async fn networking_task(
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
@@ -484,7 +485,7 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
|
||||
@@ -19,6 +19,8 @@ either = { workspace = true }
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
@@ -27,6 +29,11 @@ futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
@@ -34,4 +41,4 @@ keccak-const = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
@@ -24,8 +24,8 @@ use libp2p::{
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,19 @@
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
@@ -42,3 +54,11 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,6 @@ class DownloadCoordinator:
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool = False
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
@@ -63,8 +62,6 @@ class DownloadCoordinator:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
if self.offline:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
@@ -110,17 +107,13 @@ class DownloadCoordinator:
|
||||
self._last_progress_time[model_id] = current_time()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info(
|
||||
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
|
||||
)
|
||||
if not self.offline:
|
||||
self._test_internet_connection()
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
if not self.offline:
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
def _test_internet_connection(self) -> None:
|
||||
try:
|
||||
@@ -209,20 +202,6 @@ class DownloadCoordinator:
|
||||
)
|
||||
return
|
||||
|
||||
if self.offline:
|
||||
logger.warning(
|
||||
f"Offline mode: model {model_id} is not fully available locally, cannot download"
|
||||
)
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=f"Model files not found locally in offline mode: {model_id}",
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
|
||||
@@ -448,13 +448,12 @@ async def download_file_with_retry(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress, skip_internet
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
@@ -488,14 +487,10 @@ async def _download_file(
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
if skip_internet:
|
||||
return target_path
|
||||
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
@@ -515,11 +510,6 @@ async def _download_file(
|
||||
)
|
||||
return target_path
|
||||
|
||||
if skip_internet:
|
||||
raise FileNotFoundError(
|
||||
f"File {path} not found locally and cannot download in offline mode"
|
||||
)
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -824,7 +814,6 @@ async def download_shard(
|
||||
file, curr_bytes, total_bytes, is_renamed
|
||||
),
|
||||
on_connection_lost=on_connection_lost,
|
||||
skip_internet=skip_internet,
|
||||
)
|
||||
|
||||
if not skip_download:
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
"""Tests for offline/air-gapped mode."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
download_file_with_retry,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.worker.downloads import FileListEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestDownloadFileOffline:
|
||||
"""Tests for _download_file with skip_internet=True."""
|
||||
|
||||
async def test_returns_local_file_without_http_verification(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists locally, return it immediately
|
||||
without making any HTTP calls (no file_meta verification)."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"model weights data")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_raises_file_not_found_for_missing_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file does NOT exist locally,
|
||||
raise FileNotFoundError instead of attempting download."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="offline mode"):
|
||||
await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"missing_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
async def test_returns_local_file_in_subdirectory(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists in a subdirectory,
|
||||
return it without HTTP calls."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
subdir = target_dir / "transformer"
|
||||
await aios.makedirs(subdir, exist_ok=True)
|
||||
|
||||
local_file = subdir / "diffusion_pytorch_model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"weights")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"transformer/diffusion_pytorch_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
|
||||
class TestDownloadFileWithRetryOffline:
|
||||
"""Tests for download_file_with_retry with skip_internet=True."""
|
||||
|
||||
async def test_propagates_skip_internet_to_download_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Verify skip_internet is passed through to _download_file."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "config.json"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b'{"model_type": "qwen2"}')
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_file_not_found_does_not_retry(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""FileNotFoundError from offline mode should not trigger retries."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"nonexistent.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
|
||||
class TestFetchFileListOffline:
|
||||
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
|
||||
|
||||
async def test_uses_cached_file_list(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and cache file exists, use it without network."""
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
cache_dir = temp_models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
cached_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=200),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
assert result == cached_list
|
||||
mock_fetch.assert_not_called()
|
||||
|
||||
async def test_falls_back_to_local_directory_scan(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and no cache but local files exist,
|
||||
build file list from local directory."""
|
||||
import json
|
||||
|
||||
model_dir = temp_models_dir / model_id.normalize()
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(model_dir / "config.json", "w") as f:
|
||||
await f.write('{"model_type": "qwen2"}')
|
||||
|
||||
index_data = {
|
||||
"metadata": {},
|
||||
"weight_map": {"model.layers.0.weight": "model.safetensors"},
|
||||
}
|
||||
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
|
||||
await f.write(json.dumps(index_data))
|
||||
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
|
||||
await f.write(b"x" * 500)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
mock_fetch.assert_not_called()
|
||||
paths = {entry.path for entry in result}
|
||||
assert "config.json" in paths
|
||||
assert "model.safetensors" in paths
|
||||
|
||||
async def test_raises_when_no_cache_and_no_local_files(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and neither cache nor local files exist,
|
||||
raise FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="No internet"):
|
||||
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
|
||||
@@ -39,7 +39,6 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -69,7 +68,6 @@ class Node:
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
offline=args.offline,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -134,13 +132,10 @@ class Node:
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
args.offline,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -152,6 +147,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -225,7 +222,6 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
offline=self.offline,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -264,9 +260,6 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
@@ -289,7 +282,6 @@ class Args(CamelCaseModel):
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -337,11 +329,6 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -85,7 +85,6 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
ImageSize,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -101,7 +100,6 @@ from exo.shared.types.api import (
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
normalize_image_size,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -753,11 +751,9 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1013,13 +1009,12 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"stream": False,
|
||||
"partial_images": 0,
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1040,7 +1035,7 @@ class API:
|
||||
prompt: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: ImageSize,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
input_fidelity: Literal["low", "high"],
|
||||
stream: bool,
|
||||
@@ -1110,7 +1105,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str | None = Form(None),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
@@ -1136,7 +1131,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=normalize_image_size(size),
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=stream_bool,
|
||||
@@ -1172,7 +1167,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str | None = Form(None),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
@@ -1192,7 +1187,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=normalize_image_size(size),
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=False,
|
||||
|
||||
@@ -44,8 +44,7 @@ async def _refresh_card_cache():
|
||||
async for toml_file in path.rglob("*.toml"):
|
||||
try:
|
||||
card = await ModelCard.load_from_path(toml_file)
|
||||
if card.model_id not in _card_cache:
|
||||
_card_cache[card.model_id] = card
|
||||
_card_cache[card.model_id] = card
|
||||
except (ValidationError, TOMLKitError):
|
||||
pass
|
||||
|
||||
@@ -183,7 +182,6 @@ class ConfigData(BaseModel):
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["GlmMoeDsaForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal, get_args
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -262,27 +262,6 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
ImageSize = Literal[
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
]
|
||||
|
||||
|
||||
def normalize_image_size(v: object) -> ImageSize:
|
||||
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
|
||||
if v is None:
|
||||
return "auto"
|
||||
if v not in get_args(ImageSize):
|
||||
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
|
||||
return v # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
@@ -302,7 +281,7 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: ImageSize = "auto"
|
||||
size: str | None = "1024x1024"
|
||||
stream: bool | None = False
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
@@ -310,11 +289,6 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
|
||||
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
|
||||
bench: bool = True
|
||||
@@ -331,18 +305,13 @@ class ImageEditsTaskParams(BaseModel):
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: ImageSize = "auto"
|
||||
size: str | None = "1024x1024"
|
||||
image_strength: float | None = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
if name == "image_data":
|
||||
|
||||
@@ -4,13 +4,10 @@ from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
|
||||
@@ -204,6 +204,10 @@ class MpReceiver[T]:
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
self._state.closed.set()
|
||||
try: # noqa: SIM105
|
||||
self._state.buffer.put_nowait(_MpEndOfStream())
|
||||
except Exception:
|
||||
pass
|
||||
self._state.buffer.close()
|
||||
|
||||
# == unique to Mp channels ==
|
||||
|
||||
@@ -14,7 +14,6 @@ from exo.shared.types.api import (
|
||||
ImageEditsTaskParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ImageSize,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
@@ -24,9 +23,9 @@ from exo.shared.types.worker.runner_response import (
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
def parse_size(size_str: ImageSize) -> tuple[int, int]:
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if size_str == "auto":
|
||||
if not size_str:
|
||||
return (1024, 1024)
|
||||
|
||||
try:
|
||||
@@ -110,9 +109,6 @@ def generate_image(
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
if task.size == "auto":
|
||||
with Image.open(image_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
|
||||
@@ -163,14 +163,11 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
|
||||
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
if self.is_prefill:
|
||||
mx.eval(output)
|
||||
if cache is not None:
|
||||
mx.eval(_cache.keys) # type: ignore
|
||||
mx.eval(cache.keys) # type: ignore
|
||||
|
||||
if not self.is_prefill:
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
@@ -310,9 +307,7 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
last = cache[-1] # type: ignore
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
return logits
|
||||
|
||||
@@ -338,9 +333,7 @@ def patch_tensor_model[T](model: T) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
|
||||
last = cache[-1] # pyright: ignore[reportAny]
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
|
||||
return logits
|
||||
|
||||
@@ -554,12 +547,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
@@ -590,18 +581,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -794,7 +779,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
@@ -907,7 +893,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
@@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
@@ -18,22 +17,10 @@ from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in.
|
||||
# Smaller machines need more aggressive eviction.
|
||||
def _default_memory_threshold() -> float:
|
||||
total_gb = psutil.virtual_memory().total / (1024**3)
|
||||
if total_gb >= 128:
|
||||
return 0.85
|
||||
if total_gb >= 64:
|
||||
return 0.80
|
||||
if total_gb >= 32:
|
||||
return 0.75
|
||||
return 0.70
|
||||
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
@@ -77,7 +64,7 @@ def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None):
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
@@ -169,15 +156,15 @@ class KVPrefixCache:
|
||||
best_length = 0
|
||||
is_exact = False
|
||||
|
||||
# Find best cache match
|
||||
# Find best cache
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(prompt_tokens, cached_prompt)
|
||||
if length >= max_length - 1:
|
||||
best_index, best_length = i, length
|
||||
is_exact = True
|
||||
break
|
||||
if length > best_length:
|
||||
best_index, best_length = i, length
|
||||
if length == max_length:
|
||||
is_exact = True
|
||||
best_index, best_length = i, length
|
||||
break
|
||||
|
||||
if best_index is None:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
@@ -185,12 +172,11 @@ class KVPrefixCache:
|
||||
# For exact match: trim to max_length-1 so remaining has the last token
|
||||
# For partial match: trim to best_length, remaining has suffix to prefill
|
||||
# This ensures stream_generate always has at least one token to start with
|
||||
has_ssm = has_non_kv_caches(self.caches[best_index])
|
||||
target = (max_length - 1) if is_exact and not has_ssm else best_length
|
||||
target = (max_length - 1) if is_exact else best_length
|
||||
restore_pos, restore_snap = self._get_snapshot(best_index, target)
|
||||
|
||||
# No usable snapshot — need fresh cache
|
||||
if restore_snap is None and has_ssm:
|
||||
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
prompt_cache = deepcopy(self.caches[best_index])
|
||||
@@ -271,21 +257,10 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
return mx.array(prompt_tokens)
|
||||
|
||||
|
||||
def _entry_length(
|
||||
c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
|
||||
) -> int:
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
if hasattr(c, "offset"):
|
||||
return c.offset
|
||||
# For CacheList
|
||||
if hasattr(c, "size"):
|
||||
return int(c.size()) # type: ignore
|
||||
return 0
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
return max(_entry_length(c) for c in cache)
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
return max(getattr(c, "offset", 0) for c in cache)
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
|
||||
@@ -48,7 +48,7 @@ from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
|
||||
|
||||
def prefill(
|
||||
@@ -57,7 +57,6 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -87,9 +86,6 @@ def prefill(
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
@@ -133,7 +129,7 @@ def prefill(
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -255,8 +251,8 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -309,9 +305,16 @@ def mlx_generate(
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
@@ -328,7 +331,6 @@ def mlx_generate(
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
|
||||
logger.info("Starting decode")
|
||||
mx_barrier(group)
|
||||
|
||||
for completion_tokens, out in enumerate(
|
||||
@@ -436,14 +438,9 @@ def mlx_generate(
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
hit_ratio = (
|
||||
prefix_hit_length / len(all_prompt_tokens)
|
||||
if len(all_prompt_tokens) > 0
|
||||
else 0.0
|
||||
)
|
||||
if (
|
||||
matched_index is not None
|
||||
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
|
||||
@@ -285,15 +285,11 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
|
||||
# For GLM-5 and GLM-4.7
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
elif "gpt-oss" in model_id_lower:
|
||||
return [200002, 200012]
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
@@ -589,11 +588,7 @@ def parse_gpt_oss(
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
try:
|
||||
stream.process(response.token)
|
||||
except HarmonyError:
|
||||
logger.error("Encountered critical Harmony Error, returning early")
|
||||
return
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
@@ -103,7 +103,7 @@ class RunnerSupervisor:
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
@@ -191,7 +191,7 @@ class RunnerSupervisor:
|
||||
logger.info("Checking runner's status")
|
||||
if self.runner_process.is_alive():
|
||||
logger.info("Runner was found to be alive, attempting to join process")
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
await to_thread.run_sync(self.runner_process.join, 1)
|
||||
rc = self.runner_process.exitcode
|
||||
logger.info(f"RunnerSupervisor exited with exit code {rc}")
|
||||
if rc == 0:
|
||||
|
||||
@@ -123,12 +123,7 @@ def run_gpt_oss_pipeline_device(
|
||||
generated_text = ""
|
||||
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
@@ -199,8 +194,6 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(None)
|
||||
cache = KVPrefixCache()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(None)
|
||||
cache = KVPrefixCache()
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(None)
|
||||
cache = KVPrefixCache()
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,9 +142,7 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
@@ -163,11 +161,9 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
@@ -180,11 +176,9 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns last token(s) — for models with SSM/rotating caches,
|
||||
# snapshot availability constrains how far back we can trim, so remaining
|
||||
# may be 1 or 2 tokens depending on the model.
|
||||
assert len(remaining_tokens) >= 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
@@ -200,10 +194,10 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
@@ -244,11 +238,9 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -284,11 +276,9 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -311,7 +301,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -328,7 +318,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
generated_tokens += 1
|
||||
|
||||
@@ -342,7 +331,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -358,7 +347,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -380,7 +368,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -407,7 +395,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task1,
|
||||
prompt=prompt1,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
@@ -440,7 +427,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task2,
|
||||
prompt=prompt2,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
second_gen_time = time.perf_counter() - t0
|
||||
@@ -461,7 +447,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -476,7 +462,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -489,7 +474,6 @@ class TestKVPrefixCacheWithModel:
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -500,7 +484,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -513,7 +497,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
@@ -538,7 +522,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
|
||||
@@ -343,16 +343,8 @@ async def test_kimi_tokenizer_specifically():
|
||||
@pytest.mark.asyncio
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
|
||||
def contains(card: ModelCard, x: str):
|
||||
return x in card.model_id.lower()
|
||||
|
||||
glm_model_cards = [
|
||||
card
|
||||
for card in await get_model_cards()
|
||||
if contains(card, "glm")
|
||||
and not contains(card, "-5")
|
||||
and not contains(card, "4.7")
|
||||
card for card in await get_model_cards() if "glm" in card.model_id.lower()
|
||||
]
|
||||
|
||||
if not glm_model_cards:
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.runner.runner import parse_gpt_oss
|
||||
|
||||
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
|
||||
# These are stable since they come from the model's vocabulary.
|
||||
_CHANNEL = 200005 # <|channel|>
|
||||
_START = 200006 # <|start|>
|
||||
_MESSAGE = 200008 # <|message|>
|
||||
_CALL = 200012 # <|call|>
|
||||
_END = 200007 # <|end|>
|
||||
_ASSISTANT = 173781 # "assistant"
|
||||
|
||||
# fmt: off
|
||||
# " to=functions.get_current_weather<|channel|>commentary json<|message|>{\"location\": \"Tokyo\"}<|call|>"
|
||||
FORMAT_A_TOKENS: list[tuple[int, str]] = [
|
||||
(316, " to"),
|
||||
(28, "="),
|
||||
(44580, "functions"),
|
||||
(775, ".get"),
|
||||
(23981, "_current"),
|
||||
(170154, "_weather"),
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(12606, "comment"),
|
||||
(815, "ary"),
|
||||
(5701, " json"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(10848, '{"'),
|
||||
(7693, "location"),
|
||||
(1243, '":'),
|
||||
(392, ' "'),
|
||||
(173844, "Tokyo"),
|
||||
(18583, '"}'),
|
||||
(_CALL, "<|call|>"),
|
||||
]
|
||||
|
||||
# "<|channel|>commentary to=functions.get_current_weather json<|message|>{\"location\": \"Tokyo\"}<|call|>"
|
||||
FORMAT_B_TOKENS: list[tuple[int, str]] = [
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(12606, "comment"),
|
||||
(815, "ary"),
|
||||
(316, " to"),
|
||||
(28, "="),
|
||||
(44580, "functions"),
|
||||
(775, ".get"),
|
||||
(23981, "_current"),
|
||||
(170154, "_weather"),
|
||||
(5701, " json"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(10848, '{"'),
|
||||
(7693, "location"),
|
||||
(1243, '":'),
|
||||
(392, ' "'),
|
||||
(173844, "Tokyo"),
|
||||
(18583, '"}'),
|
||||
(_CALL, "<|call|>"),
|
||||
]
|
||||
|
||||
# "<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ..."
|
||||
# Full analysis-then-tool-call as the model actually generates it.
|
||||
THINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [
|
||||
(_CHANNEL, "<|channel|>"),
|
||||
(35644, "analysis"),
|
||||
(_MESSAGE, "<|message|>"),
|
||||
(12845, "Let"),
|
||||
(668, " me"),
|
||||
(2411, " think"),
|
||||
(1078, " about"),
|
||||
(495, " this"),
|
||||
(13, "."),
|
||||
(_END, "<|end|>"),
|
||||
# Model generates a new message header for the tool call:
|
||||
(_START, "<|start|>"),
|
||||
(_ASSISTANT, "assistant"),
|
||||
*FORMAT_B_TOKENS,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _make_gen_responses(
|
||||
tokens: list[tuple[int, str]],
|
||||
) -> list[GenerationResponse]:
|
||||
"""Build GenerationResponse list from (token_id, text) pairs."""
|
||||
responses: list[GenerationResponse] = []
|
||||
for i, (tid, text) in enumerate(tokens):
|
||||
is_last = i == len(tokens) - 1
|
||||
responses.append(
|
||||
GenerationResponse(
|
||||
text=text,
|
||||
token=tid,
|
||||
finish_reason="stop" if is_last else None,
|
||||
usage=None,
|
||||
)
|
||||
)
|
||||
return responses
|
||||
|
||||
|
||||
def _collect(
|
||||
tokens: list[tuple[int, str]],
|
||||
) -> list[GenerationResponse | ToolCallResponse]:
|
||||
"""Feed tokens through parse_gpt_oss and collect all yielded responses."""
|
||||
|
||||
def _gen() -> Generator[GenerationResponse, None, None]:
|
||||
yield from _make_gen_responses(tokens)
|
||||
|
||||
return list(parse_gpt_oss(_gen()))
|
||||
|
||||
|
||||
def _get_tool_call(
|
||||
results: list[GenerationResponse | ToolCallResponse],
|
||||
) -> ToolCallResponse:
|
||||
"""Extract the single ToolCallResponse from results."""
|
||||
tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_calls) == 1, f"Expected 1 ToolCallResponse, got {len(tool_calls)}"
|
||||
return tool_calls[0]
|
||||
|
||||
|
||||
class TestParseGptOssRecipientPlacement:
|
||||
"""Both Harmony recipient placements must produce identical tool calls."""
|
||||
|
||||
def test_format_a_yields_tool_call(self):
|
||||
results = _collect(FORMAT_A_TOKENS)
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert '"location"' in tc.tool_calls[0].arguments
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
|
||||
def test_format_b_yields_tool_call(self):
|
||||
results = _collect(FORMAT_B_TOKENS)
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert '"location"' in tc.tool_calls[0].arguments
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
|
||||
def test_both_formats_produce_identical_tool_calls(self):
|
||||
tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))
|
||||
tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))
|
||||
assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name
|
||||
assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments
|
||||
|
||||
|
||||
class TestParseGptOssThinkingThenToolCall:
|
||||
"""Analysis (thinking) followed by a tool call must yield both."""
|
||||
|
||||
def test_thinking_then_tool_call(self):
|
||||
results = _collect(THINKING_THEN_TOOL_TOKENS)
|
||||
|
||||
# Should have thinking tags + content + tool call
|
||||
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
|
||||
combined = "".join(text_parts)
|
||||
assert "<think>" in combined
|
||||
assert "</think>" in combined
|
||||
assert "Let me think about this." in combined
|
||||
|
||||
# And the tool call
|
||||
tc = _get_tool_call(results)
|
||||
assert tc.tool_calls[0].name == "get_current_weather"
|
||||
assert "Tokyo" in tc.tool_calls[0].arguments
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ $# -lt 1 ] && {
|
||||
echo "Usage: $0 host1 [host2 ...]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
[ -z "$(git status --porcelain)" ] || {
|
||||
echo "Uncommitted changes"
|
||||
exit 1
|
||||
}
|
||||
|
||||
commit=$(git rev-parse HEAD)
|
||||
git fetch -q origin
|
||||
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
|
||||
echo "Not pushed to origin"
|
||||
exit 1
|
||||
}
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
sleep 1
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
}
|
||||
trap 'cleanup' EXIT INT TERM
|
||||
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
|
||||
done
|
||||
wait
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..." 1>&2
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
|
||||
echo "Waiting 30s for cluster setup" 1>&2
|
||||
sleep 30
|
||||
echo "EXO loaded" 1>&2
|
||||
eval_runner="${hosts[0]}"
|
||||
mkdir -p "./bench/$commit"
|
||||
nix run .#exo-get-all-models-on-cluster -- "$eval_runner" | while IFS= read -r model; do
|
||||
echo "running eval for $model" 1>&2
|
||||
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$eval_runner@$eval_runner" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout" \
|
||||
>>"./bench/$commit/${model//\//--}-eval.json"
|
||||
echo
|
||||
done
|
||||
@@ -1,691 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tool-calling eval for exo's OpenAI-compatible API.
|
||||
|
||||
Tests whether models correctly:
|
||||
- Trigger tool calls when appropriate
|
||||
- Return valid JSON arguments matching function schemas
|
||||
- Handle multi-turn tool use (call -> result -> final answer)
|
||||
- Avoid calling tools when unnecessary
|
||||
|
||||
Start exo with a model first, then run:
|
||||
uv run python tool_call_eval.py --model <model-id>
|
||||
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
|
||||
uv run python tool_call_eval.py --model <model-id> --repeat 3
|
||||
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature unit",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
CALCULATOR_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Evaluate a mathematical expression and return the numeric result",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "The math expression to evaluate, e.g. '2 + 3 * 4'",
|
||||
},
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_products",
|
||||
"description": "Search for products in a catalog by query, category, and price",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["electronics", "clothing", "food", "books"],
|
||||
"description": "Product category to filter by",
|
||||
},
|
||||
"max_price": {
|
||||
"type": "number",
|
||||
"description": "Maximum price in USD",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scenario:
|
||||
name: str
|
||||
description: str
|
||||
messages: list[dict[str, object]]
|
||||
tools: list[dict[str, object]]
|
||||
expect_tool_call: bool
|
||||
expected_function: str | None = None
|
||||
required_arg_keys: list[str] | None = None
|
||||
# For multi-turn: fake tool result to inject, then verify the follow-up.
|
||||
tool_result: str | None = None
|
||||
|
||||
|
||||
SCENARIOS = [
|
||||
# -- Should call a tool --------------------------------------------------
|
||||
Scenario(
|
||||
name="weather_simple",
|
||||
description="Basic weather query -> get_current_weather",
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather like in Tokyo right now?"}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
Scenario(
|
||||
name="calculator_simple",
|
||||
description="Math question -> calculate",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Use the calculator to compute 3847 * 926 + 17293",
|
||||
}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
),
|
||||
Scenario(
|
||||
name="search_with_filters",
|
||||
description="Product search with category and price filter",
|
||||
messages=[{"role": "user", "content": "Find me electronics under $50"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="search_products",
|
||||
required_arg_keys=["query"],
|
||||
),
|
||||
# -- Multi-turn: tool call then follow-up --------------------------------
|
||||
Scenario(
|
||||
name="weather_multi_turn",
|
||||
description="Weather query -> tool result -> natural language summary",
|
||||
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
tool_result=json.dumps(
|
||||
{
|
||||
"temperature": "18C",
|
||||
"condition": "partly cloudy",
|
||||
"humidity": "65%",
|
||||
"wind": "12 km/h NW",
|
||||
}
|
||||
),
|
||||
),
|
||||
Scenario(
|
||||
name="calculator_multi_turn",
|
||||
description="Math query -> tool result -> model reports the answer",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Use the calculator to compute 1847 * 263 + 5921",
|
||||
}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
tool_result=json.dumps({"result": 491682}),
|
||||
),
|
||||
Scenario(
|
||||
name="search_multi_turn",
|
||||
description="Search query -> tool result -> model summarizes products",
|
||||
messages=[
|
||||
{"role": "user", "content": "Search for books about machine learning"}
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="search_products",
|
||||
required_arg_keys=["query"],
|
||||
tool_result=json.dumps(
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"name": "Hands-On Machine Learning",
|
||||
"price": 45.99,
|
||||
"rating": 4.8,
|
||||
},
|
||||
{
|
||||
"name": "Deep Learning with Python",
|
||||
"price": 39.99,
|
||||
"rating": 4.6,
|
||||
},
|
||||
]
|
||||
}
|
||||
),
|
||||
),
|
||||
# -- Sequential tool calls: thinking + tool call, NO final answer ----------
|
||||
# This is the critical scenario for the Harmony recipient placement fix.
|
||||
#
|
||||
# When an assistant message has both thinking content and a tool_call,
|
||||
# AND there is no subsequent final-answer assistant message, the Jinja
|
||||
# template renders BOTH the analysis and the tool call:
|
||||
#
|
||||
# <|start|>assistant<|channel|>analysis<|message|>thinking...<|end|>
|
||||
# <|start|>assistant to=functions.X<|channel|>commentary json<|message|>...<|call|>
|
||||
#
|
||||
# The two consecutive assistant messages have INCONSISTENT start patterns
|
||||
# (one has <|channel|> immediately, the other has to= first).
|
||||
# This confuses the model when it needs to generate its own tool call.
|
||||
#
|
||||
# The reformat fix makes both start with <|start|>assistant<|channel|>,
|
||||
# only differing in the channel name (analysis vs commentary).
|
||||
Scenario(
|
||||
name="chained_tool_calls_same",
|
||||
description="Thinking + weather(Tokyo) -> result -> model must call weather(London)",
|
||||
messages=[
|
||||
{"role": "user", "content": "Compare the weather in Tokyo and London."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll check both cities. Let me start with Tokyo.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Tokyo"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
Scenario(
|
||||
name="chained_tool_calls_different",
|
||||
description="Thinking + weather(Berlin) -> result -> model must call calculator",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll handle both. Let me check Berlin's weather first.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Berlin"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_2",
|
||||
"content": json.dumps({"temperature": "12C", "condition": "rainy"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="calculate",
|
||||
required_arg_keys=["expression"],
|
||||
),
|
||||
Scenario(
|
||||
name="chained_tool_calls_three",
|
||||
description="Two prior thinking+tool calls -> results -> model must make a third",
|
||||
messages=[
|
||||
{"role": "user", "content": "Compare weather in Tokyo, Paris, and London."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll check all three cities. Starting with Tokyo.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Tokyo"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_3",
|
||||
"content": json.dumps({"temperature": "25C", "condition": "sunny"}),
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Got Tokyo. Now checking Paris.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_4",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": json.dumps({"location": "Paris"}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_4",
|
||||
"content": json.dumps({"temperature": "18C", "condition": "cloudy"}),
|
||||
},
|
||||
],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=True,
|
||||
expected_function="get_current_weather",
|
||||
required_arg_keys=["location"],
|
||||
),
|
||||
# -- Should NOT call a tool ----------------------------------------------
|
||||
Scenario(
|
||||
name="no_tool_joke",
|
||||
description="Joke request should NOT trigger any tool",
|
||||
messages=[{"role": "user", "content": "Tell me a funny joke about cats."}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=False,
|
||||
),
|
||||
Scenario(
|
||||
name="no_tool_factual",
|
||||
description="Factual question answerable from training data",
|
||||
messages=[{"role": "user", "content": "What is the capital of Japan?"}],
|
||||
tools=ALL_TOOLS,
|
||||
expect_tool_call=False,
|
||||
),
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
phase: str # "tool_call" or "follow_up"
|
||||
passed: bool
|
||||
checks: dict[str, bool] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
latency_ms: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evaluation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str | None]:
|
||||
"""Parse JSON arguments and check required keys exist."""
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
return False, f"Invalid JSON: {e}"
|
||||
if not isinstance(args, dict):
|
||||
return False, f"Expected dict, got {type(args).__name__}"
|
||||
missing = [k for k in required_keys if k not in args]
|
||||
if missing:
|
||||
return False, f"Missing keys: {missing}"
|
||||
return True, None
|
||||
|
||||
|
||||
def call_api(
|
||||
client: httpx.Client,
|
||||
base_url: str,
|
||||
model: str,
|
||||
messages: list[dict[str, object]],
|
||||
tools: list[dict[str, object]],
|
||||
timeout: float,
|
||||
) -> tuple[dict[str, object], float]:
|
||||
"""POST to /chat/completions, return (response_json, latency_ms)."""
|
||||
url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
body: dict[str, object] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
t0 = time.monotonic()
|
||||
resp = client.post(url, json=body, timeout=timeout)
|
||||
latency = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
return resp.json(), latency
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_scenario(
|
||||
client: httpx.Client,
|
||||
base_url: str,
|
||||
model: str,
|
||||
scenario: Scenario,
|
||||
timeout: float,
|
||||
verbose: bool,
|
||||
) -> list[ScenarioResult]:
|
||||
results: list[ScenarioResult] = []
|
||||
|
||||
# --- Phase 1: initial request ---
|
||||
try:
|
||||
data, latency = call_api(
|
||||
client, base_url, model, scenario.messages, scenario.tools, timeout
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="tool_call",
|
||||
passed=False,
|
||||
error=f"API error: {e}",
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if verbose:
|
||||
print(f" response: {json.dumps(data, indent=2)}")
|
||||
|
||||
choice = data["choices"][0]
|
||||
finish_reason = choice.get("finish_reason")
|
||||
message = choice.get("message", {})
|
||||
tool_calls = message.get("tool_calls")
|
||||
content = message.get("content")
|
||||
|
||||
checks: dict[str, bool] = {}
|
||||
|
||||
if scenario.expect_tool_call:
|
||||
checks["finish_reason_tool_calls"] = finish_reason == "tool_calls"
|
||||
checks["has_tool_call"] = isinstance(tool_calls, list) and len(tool_calls) > 0
|
||||
|
||||
args_err: str | None = None
|
||||
if checks["has_tool_call"]:
|
||||
tc = tool_calls[0]
|
||||
fn = tc.get("function", {})
|
||||
checks["correct_function"] = (
|
||||
scenario.expected_function is None
|
||||
or fn.get("name") == scenario.expected_function
|
||||
)
|
||||
if scenario.required_arg_keys:
|
||||
ok, args_err = validate_args(
|
||||
fn.get("arguments", ""), scenario.required_arg_keys
|
||||
)
|
||||
checks["valid_arguments"] = ok
|
||||
else:
|
||||
checks["valid_arguments"] = True
|
||||
else:
|
||||
checks["correct_function"] = False
|
||||
checks["valid_arguments"] = False
|
||||
args_err = "No tool call returned"
|
||||
|
||||
passed = all(checks.values())
|
||||
error = args_err if not passed else None
|
||||
else:
|
||||
checks["finish_reason_stop"] = finish_reason == "stop"
|
||||
checks["no_tool_call"] = tool_calls is None or len(tool_calls) == 0
|
||||
checks["has_content"] = isinstance(content, str) and len(content.strip()) > 0
|
||||
passed = all(checks.values())
|
||||
error = (
|
||||
None
|
||||
if passed
|
||||
else (
|
||||
f"finish_reason={finish_reason}, "
|
||||
f"tool_calls={'yes' if tool_calls else 'no'}, "
|
||||
f"content={'yes' if content else 'no'}"
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="tool_call",
|
||||
passed=passed,
|
||||
checks=checks,
|
||||
error=error,
|
||||
latency_ms=latency,
|
||||
)
|
||||
)
|
||||
|
||||
# --- Phase 2: multi-turn follow-up ---
|
||||
if scenario.tool_result is not None and checks.get("has_tool_call"):
|
||||
tc = tool_calls[0]
|
||||
fn = tc.get("function", {})
|
||||
follow_up_messages: list[dict[str, object]] = list(scenario.messages) + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.get("id", "call_0"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fn.get("name", ""),
|
||||
"arguments": fn.get("arguments", "{}"),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.get("id", "call_0"),
|
||||
"content": scenario.tool_result,
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
data2, latency2 = call_api(
|
||||
client,
|
||||
base_url,
|
||||
model,
|
||||
follow_up_messages,
|
||||
scenario.tools,
|
||||
timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="follow_up",
|
||||
passed=False,
|
||||
error=f"API error: {e}",
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if verbose:
|
||||
print(f" follow_up response: {json.dumps(data2, indent=2)}")
|
||||
|
||||
choice2 = data2["choices"][0]
|
||||
message2 = choice2.get("message", {})
|
||||
checks2: dict[str, bool] = {}
|
||||
checks2["finish_reason_stop"] = choice2.get("finish_reason") == "stop"
|
||||
tc2 = message2.get("tool_calls")
|
||||
checks2["no_tool_call"] = tc2 is None or len(tc2) == 0
|
||||
c2 = message2.get("content")
|
||||
checks2["has_content"] = isinstance(c2, str) and len(c2.strip()) > 0
|
||||
|
||||
passed2 = all(checks2.values())
|
||||
error2 = None
|
||||
if not passed2:
|
||||
error2 = (
|
||||
f"finish_reason={choice2.get('finish_reason')}, "
|
||||
f"tool_calls={'yes' if tc2 else 'no'}, "
|
||||
f"content={'yes' if c2 else 'no'}"
|
||||
)
|
||||
results.append(
|
||||
ScenarioResult(
|
||||
name=scenario.name,
|
||||
phase="follow_up",
|
||||
passed=passed2,
|
||||
checks=checks2,
|
||||
error=error2,
|
||||
latency_ms=latency2,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Tool-calling eval for exo")
|
||||
parser.add_argument("--model", required=True, help="Model ID to test")
|
||||
parser.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout", type=float, default=120, help="Per-request timeout (seconds)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repeat each scenario N times"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scenarios", nargs="*", help="Run only these scenarios (by name)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print full API responses"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
scenarios = SCENARIOS
|
||||
if args.scenarios:
|
||||
scenarios = [s for s in SCENARIOS if s.name in args.scenarios]
|
||||
if not scenarios:
|
||||
print(f"No matching scenarios. Available: {[s.name for s in SCENARIOS]}")
|
||||
sys.exit(1)
|
||||
|
||||
base_url = f"http://{args.host}:{args.port}/v1"
|
||||
total_runs = len(scenarios) * args.repeat
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Endpoint: {base_url}")
|
||||
print(f"Scenarios: {len(scenarios)} x {args.repeat} = {total_runs} runs")
|
||||
print("=" * 64)
|
||||
|
||||
all_results: list[ScenarioResult] = []
|
||||
|
||||
with httpx.Client() as client:
|
||||
for run_idx in range(args.repeat):
|
||||
if args.repeat > 1:
|
||||
print(f"\n--- Run {run_idx + 1}/{args.repeat} ---")
|
||||
|
||||
for scenario in scenarios:
|
||||
print(f"\n {scenario.name}: {scenario.description}")
|
||||
|
||||
results = run_scenario(
|
||||
client,
|
||||
base_url,
|
||||
args.model,
|
||||
scenario,
|
||||
args.timeout,
|
||||
args.verbose,
|
||||
)
|
||||
all_results.extend(results)
|
||||
|
||||
for r in results:
|
||||
status = "PASS" if r.passed else "FAIL"
|
||||
print(f" [{r.phase:>10}] {status} ({r.latency_ms:.0f}ms)")
|
||||
for check_name, check_ok in r.checks.items():
|
||||
mark = "+" if check_ok else "-"
|
||||
print(f" {mark} {check_name}")
|
||||
if r.error:
|
||||
print(f" ! {r.error}")
|
||||
|
||||
# --- Summary ---
|
||||
print(f"\n{'=' * 64}")
|
||||
|
||||
total = len(all_results)
|
||||
passed = sum(1 for r in all_results if r.passed)
|
||||
|
||||
tool_call_results = [r for r in all_results if r.phase == "tool_call"]
|
||||
follow_up_results = [r for r in all_results if r.phase == "follow_up"]
|
||||
tc_passed = sum(1 for r in tool_call_results if r.passed)
|
||||
fu_passed = sum(1 for r in follow_up_results if r.passed)
|
||||
avg_latency = sum(r.latency_ms for r in all_results) / total if total else 0
|
||||
|
||||
print(f"Total: {passed}/{total} passed ({100 * passed / total:.0f}%)")
|
||||
print(f"Tool call: {tc_passed}/{len(tool_call_results)} passed")
|
||||
if follow_up_results:
|
||||
print(f"Follow-up: {fu_passed}/{len(follow_up_results)} passed")
|
||||
print(f"Avg latency: {avg_latency:.0f}ms")
|
||||
|
||||
if passed < total:
|
||||
print("\nFailed:")
|
||||
for r in all_results:
|
||||
if not r.passed:
|
||||
print(f" - {r.name} [{r.phase}]: {r.error}")
|
||||
|
||||
sys.exit(0 if passed == total else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
50
uv.lock
generated
50
uv.lock
generated
@@ -377,8 +377,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -416,9 +416,9 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.5" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.7" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "msgspec", specifier = ">=0.19.0" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
@@ -447,7 +447,6 @@ name = "exo-bench"
|
||||
version = "0.1.0"
|
||||
source = { editable = "bench" }
|
||||
dependencies = [
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -457,7 +456,6 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "jinja2", specifier = ">=3.1.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
@@ -1022,8 +1020,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1050,12 +1048,18 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
|
||||
]
|
||||
@@ -1068,14 +1072,6 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.7.dev20260218+14841977"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.6"
|
||||
@@ -1102,20 +1098,30 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.7"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/0d/56542e2ae13ec6f542d3977d7cff89a205d4f6c5122e0ce23f33265f61c9/mlx_lm-0.30.7.tar.gz", hash = "sha256:e5f31ac58d9f2381f28e1ba639ff903e64f7cff1bdc245c0bc97f72264be329c", size = 275764, upload-time = "2026-02-12T18:41:11.86Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/17/a41c798a3d9cbdc47f39c6db5bba4c2cd199203ead26bf911cb03b644070/mlx_lm-0.30.7-py3-none-any.whl", hash = "sha256:17442a4bf01c4c2d3bca1e647712fe44f19890c3f1eadc8589d389e57b44b9bf", size = 386591, upload-time = "2026-02-12T18:41:10.236Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user