mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 07:17:30 -05:00
Compare commits
7 Commits
session-id
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
229b21fa8c | ||
|
|
00175bb630 | ||
|
|
3ce7220713 | ||
|
|
9bf157d527 | ||
|
|
69f22fbc89 | ||
|
|
b394a1b665 | ||
|
|
33bacfa7d8 |
27
.github/workflows/pipeline.yml
vendored
27
.github/workflows/pipeline.yml
vendored
@@ -8,6 +8,33 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
typecheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .deepseek_v32 import Model as DSV32Model
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
index_head_dim: int
|
||||
index_n_heads: int
|
||||
index_topk: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: int
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_parameters: Dict[str, Any]
|
||||
attention_bias: bool
|
||||
rope_scaling: Dict[str, Any] | None
|
||||
rope_theta: float | None
|
||||
|
||||
class Model(DSV32Model):
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
123
Cargo.lock
generated
123
Cargo.lock
generated
@@ -141,6 +141,12 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
@@ -298,6 +304,19 @@ version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bimap"
|
||||
version = "0.6.3"
|
||||
@@ -497,6 +516,15 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -718,6 +746,29 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc_version",
|
||||
"syn 2.0.111",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
"extend",
|
||||
"futures",
|
||||
"impl-trait-for-tuples",
|
||||
"libp2p",
|
||||
"log",
|
||||
"networking",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"thiserror 2.0.17",
|
||||
"thread_local",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
@@ -1584,6 +1640,17 @@ dependencies = [
|
||||
"xmltree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "impl-trait-for-tuples"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
@@ -1762,6 +1829,12 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "libp2p"
|
||||
version = "0.56.0"
|
||||
@@ -2751,13 +2824,16 @@ name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
@@ -2842,6 +2918,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -3192,14 +3279,28 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"either",
|
||||
"hashbrown 0.16.1",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"memoffset",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"rust_decimal",
|
||||
"smallvec",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
@@ -3640,6 +3741,16 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust_decimal"
|
||||
version = "1.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
@@ -4504,12 +4615,24 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "1.3.0"
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -26,21 +26,49 @@ opt-level = 3
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
@@ -5,21 +5,21 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
|
||||
11
README.md
11
README.md
@@ -72,23 +72,16 @@ There are two ways to run exo:
|
||||
|
||||
### Run from Source (macOS)
|
||||
|
||||
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
|
||||
|
||||
```bash
|
||||
nix run .#exo
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
```
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard)
|
||||
|
||||
|
||||
```bash
|
||||
brew install uv macmon node
|
||||
```
|
||||
|
||||
@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
|
||||
return
|
||||
}
|
||||
process.terminationHandler = nil
|
||||
status = .stopped
|
||||
|
||||
guard process.isRunning else {
|
||||
self.process = nil
|
||||
return
|
||||
if process.isRunning {
|
||||
process.terminate()
|
||||
}
|
||||
|
||||
let proc = process
|
||||
self.process = nil
|
||||
|
||||
Task.detached {
|
||||
proc.interrupt()
|
||||
|
||||
for _ in 0..<50 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
proc.terminate()
|
||||
}
|
||||
|
||||
for _ in 0..<30 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
kill(proc.processIdentifier, SIGKILL)
|
||||
}
|
||||
}
|
||||
status = .stopped
|
||||
}
|
||||
|
||||
func restart() {
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# Canary benchmark manifest
|
||||
#
|
||||
# Lists the suite files to include. Each file defines benchmarks
|
||||
# with shared constraints, topology, and default args.
|
||||
include = [
|
||||
"single-m3-ultra.toml",
|
||||
]
|
||||
@@ -288,151 +288,6 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
@@ -680,11 +535,6 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -719,16 +569,13 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_deadline:
|
||||
if not selected and args.settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
deadline = time.monotonic() + args.settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
@@ -760,16 +607,6 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
# Single-node M3 Ultra benchmarks
|
||||
#
|
||||
# Shared constraints applied to ALL benchmarks in this file.
|
||||
constraints = [
|
||||
"All(MacOsBuild(=25D125))",
|
||||
"Hosts(=1)",
|
||||
"All(Chip(m3_ultra))",
|
||||
"All(GpuCores(=80))",
|
||||
]
|
||||
|
||||
[topology]
|
||||
type = "none"
|
||||
|
||||
# Default args merged into each benchmark's args (benchmark-level args win).
|
||||
[defaults]
|
||||
pp = [512, 2048, 8192, 16384]
|
||||
tg = 128
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-3bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-6bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
@@ -103,7 +103,7 @@
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking_toggle") && caps.includes("text");
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
@@ -265,7 +265,6 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
if (isEditOnlyWithoutImage) return;
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
@@ -290,11 +289,7 @@
|
||||
if (imageFile.preview) {
|
||||
editImage(content, imageFile.preview);
|
||||
}
|
||||
} else if (
|
||||
currentModel &&
|
||||
modelSupportsTextToImage(currentModel) &&
|
||||
content
|
||||
) {
|
||||
} else if (isImageModel() && content) {
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
|
||||
@@ -225,7 +225,6 @@
|
||||
}
|
||||
|
||||
function handleDeleteClick(messageId: string) {
|
||||
if (loading) return;
|
||||
deleteConfirmId = messageId;
|
||||
}
|
||||
|
||||
@@ -256,7 +255,7 @@
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message, i (message.id)}
|
||||
{#each messageList as message (message.id)}
|
||||
<div
|
||||
class="group flex {message.role === 'user'
|
||||
? 'justify-end'
|
||||
@@ -318,11 +317,9 @@
|
||||
<!-- Delete confirmation -->
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
|
||||
<p class="text-xs text-red-400 mb-3">
|
||||
{#if i === messageList.length - 1}
|
||||
Delete this message?
|
||||
{:else}
|
||||
Delete this message and all messages after it?
|
||||
{/if}
|
||||
Delete this message{message.role === "user"
|
||||
? " and all responses after it"
|
||||
: ""}?
|
||||
</p>
|
||||
<div class="flex gap-2 justify-end">
|
||||
<button
|
||||
@@ -754,13 +751,8 @@
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
onclick={() => handleDeleteClick(message.id)}
|
||||
disabled={loading}
|
||||
class="p-1.5 transition-colors rounded {loading
|
||||
? 'text-exo-light-gray/30 cursor-not-allowed'
|
||||
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
|
||||
title={loading
|
||||
? "Cannot delete while generating"
|
||||
: "Delete message"}
|
||||
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
|
||||
title="Delete message"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
|
||||
@@ -59,14 +59,13 @@
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -177,90 +176,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size.toUpperCase()}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -310,7 +311,7 @@
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition()
|
||||
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
|
||||
@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "auto"
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1536"
|
||||
| "1536x1024";
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -337,7 +336,7 @@ export interface EditingImage {
|
||||
}
|
||||
|
||||
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "auto",
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -115,7 +115,7 @@
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
|
||||
10
nix/mlx.nix
10
nix/mlx.nix
@@ -41,16 +41,16 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260218+14841977"; in
|
||||
version = let v = "0.30.6"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
|
||||
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.7",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -64,7 +64,6 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
@@ -133,7 +132,7 @@ markers = [
|
||||
env = [
|
||||
"EXO_TESTS=1"
|
||||
]
|
||||
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
|
||||
addopts = "-m 'not slow'"
|
||||
filterwarnings = [
|
||||
"ignore:builtin type Swig:DeprecationWarning",
|
||||
]
|
||||
|
||||
@@ -14,9 +14,7 @@
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
|
||||
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
|
||||
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
@@ -24,12 +22,6 @@
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
passthru = prev.exo-pyo3-bindings.passthru or { };
|
||||
postInstall = ''
|
||||
local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
|
||||
cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
|
||||
touch $siteDir/py.typed
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
@@ -37,47 +29,17 @@
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
} // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
# Use our pure Nix-built MLX with Metal support (macOS only)
|
||||
mlx = self'.packages.mlx;
|
||||
};
|
||||
|
||||
# Additional overlay for Linux-specific fixes (type checking env).
|
||||
# Native wheels have shared lib dependencies we don't need at type-check time.
|
||||
linuxOverlay = final: prev:
|
||||
let
|
||||
ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
|
||||
nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
|
||||
in
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
);
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
@@ -86,28 +48,16 @@
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
@@ -168,21 +118,6 @@
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
|
||||
# Hermetic basedpyright type checking
|
||||
typecheck = pkgs.runCommand "typecheck"
|
||||
{
|
||||
nativeBuildInputs = [
|
||||
testVenv
|
||||
pkgs.basedpyright
|
||||
];
|
||||
}
|
||||
''
|
||||
cd ${inputs.self}
|
||||
export HOME=$TMPDIR
|
||||
basedpyright --pythonpath ${testVenv}/bin/python
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "4bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "8bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "5bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 790517400864
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405478939008
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2.5"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "3bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
|
||||
@@ -25,17 +25,17 @@ workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
@@ -45,6 +45,8 @@ pyo3-log = "0.13.2"
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
@@ -52,11 +54,24 @@ tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
#tracing-subscriber = "0.3"
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -6,7 +6,7 @@ use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
pin::{Pin, pin},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
@@ -33,6 +33,8 @@ where
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
@@ -24,6 +25,7 @@ use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -34,10 +36,14 @@ pub(crate) mod r#const {
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
@@ -45,6 +51,7 @@ pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
@@ -55,7 +62,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +98,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -168,6 +175,24 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
|
||||
@@ -11,9 +11,9 @@ use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt a
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{exceptions::PyException, prelude::*};
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
@@ -155,6 +155,7 @@ async fn networking_task(
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
@@ -484,7 +485,7 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
|
||||
@@ -19,6 +19,8 @@ either = { workspace = true }
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
@@ -27,6 +29,11 @@ futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
@@ -34,4 +41,4 @@ keccak-const = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
@@ -24,8 +24,8 @@ use libp2p::{
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,19 @@
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
@@ -42,3 +54,11 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
@@ -13,7 +14,6 @@ from exo.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
@@ -21,7 +21,7 @@ from exo.shared.types.commands import (
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
@@ -45,8 +45,7 @@ class DownloadCoordinator:
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
offline: bool = False
|
||||
_system_id: SystemId = field(default_factory=SystemId)
|
||||
event_index_counter: Iterator[int]
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
@@ -62,13 +61,8 @@ class DownloadCoordinator:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
if self.offline:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
return str(EXO_MODELS_DIR / model_id.normalize())
|
||||
|
||||
async def _download_progress_callback(
|
||||
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
@@ -80,7 +74,6 @@ class DownloadCoordinator:
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -100,7 +93,6 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
@@ -109,17 +101,13 @@ class DownloadCoordinator:
|
||||
self._last_progress_time[model_id] = current_time()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info(
|
||||
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
|
||||
)
|
||||
if not self.offline:
|
||||
self._test_internet_connection()
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
if not self.offline:
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
def _test_internet_connection(self) -> None:
|
||||
try:
|
||||
@@ -182,11 +170,7 @@ class DownloadCoordinator:
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
@@ -200,7 +184,6 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -208,20 +191,6 @@ class DownloadCoordinator:
|
||||
)
|
||||
return
|
||||
|
||||
if self.offline:
|
||||
logger.warning(
|
||||
f"Offline mode: model {model_id} is not fully available locally, cannot download"
|
||||
)
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=f"Model files not found locally in offline mode: {model_id}",
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
@@ -237,7 +206,6 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
@@ -251,7 +219,6 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
@@ -286,7 +253,6 @@ class DownloadCoordinator:
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
@@ -294,16 +260,15 @@ class DownloadCoordinator:
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self._system_id,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
@@ -330,18 +295,11 @@ class DownloadCoordinator:
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
@@ -350,9 +308,6 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -448,13 +448,12 @@ async def download_file_with_retry(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress, skip_internet
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
@@ -488,14 +487,10 @@ async def _download_file(
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
if skip_internet:
|
||||
return target_path
|
||||
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
@@ -515,11 +510,6 @@ async def _download_file(
|
||||
)
|
||||
return target_path
|
||||
|
||||
if skip_internet:
|
||||
raise FileNotFoundError(
|
||||
f"File {path} not found locally and cannot download in offline mode"
|
||||
)
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -824,7 +814,6 @@ async def download_shard(
|
||||
file, curr_bytes, total_bytes, is_renamed
|
||||
),
|
||||
on_connection_lost=on_connection_lost,
|
||||
skip_internet=skip_internet,
|
||||
)
|
||||
|
||||
if not skip_download:
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
"""Tests for offline/air-gapped mode."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
download_file_with_retry,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.worker.downloads import FileListEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestDownloadFileOffline:
|
||||
"""Tests for _download_file with skip_internet=True."""
|
||||
|
||||
async def test_returns_local_file_without_http_verification(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists locally, return it immediately
|
||||
without making any HTTP calls (no file_meta verification)."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"model weights data")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_raises_file_not_found_for_missing_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file does NOT exist locally,
|
||||
raise FileNotFoundError instead of attempting download."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="offline mode"):
|
||||
await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"missing_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
async def test_returns_local_file_in_subdirectory(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists in a subdirectory,
|
||||
return it without HTTP calls."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
subdir = target_dir / "transformer"
|
||||
await aios.makedirs(subdir, exist_ok=True)
|
||||
|
||||
local_file = subdir / "diffusion_pytorch_model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"weights")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"transformer/diffusion_pytorch_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
|
||||
class TestDownloadFileWithRetryOffline:
|
||||
"""Tests for download_file_with_retry with skip_internet=True."""
|
||||
|
||||
async def test_propagates_skip_internet_to_download_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Verify skip_internet is passed through to _download_file."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "config.json"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b'{"model_type": "qwen2"}')
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_file_not_found_does_not_retry(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""FileNotFoundError from offline mode should not trigger retries."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"nonexistent.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
|
||||
class TestFetchFileListOffline:
|
||||
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
|
||||
|
||||
async def test_uses_cached_file_list(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and cache file exists, use it without network."""
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
cache_dir = temp_models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
cached_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=200),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
assert result == cached_list
|
||||
mock_fetch.assert_not_called()
|
||||
|
||||
async def test_falls_back_to_local_directory_scan(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and no cache but local files exist,
|
||||
build file list from local directory."""
|
||||
import json
|
||||
|
||||
model_dir = temp_models_dir / model_id.normalize()
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(model_dir / "config.json", "w") as f:
|
||||
await f.write('{"model_type": "qwen2"}')
|
||||
|
||||
index_data = {
|
||||
"metadata": {},
|
||||
"weight_map": {"model.layers.0.weight": "model.safetensors"},
|
||||
}
|
||||
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
|
||||
await f.write(json.dumps(index_data))
|
||||
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
|
||||
await f.write(b"x" * 500)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
mock_fetch.assert_not_called()
|
||||
paths = {entry.path for entry in result}
|
||||
assert "config.json" in paths
|
||||
assert "model.safetensors" in paths
|
||||
|
||||
async def test_raises_when_no_cache_and_no_local_files(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and neither cache nor local files exist,
|
||||
raise FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="No internet"):
|
||||
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
|
||||
@@ -1,10 +1,11 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
from typing import Iterator, Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -37,11 +38,11 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
offline: bool
|
||||
event_index_counter: Iterator[int]
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> Self:
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
@@ -55,6 +56,9 @@ class Node:
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
@@ -63,7 +67,7 @@ class Node:
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
offline=args.offline,
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -89,6 +93,7 @@ class Node:
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -126,13 +131,11 @@ class Node:
|
||||
master,
|
||||
api,
|
||||
node_id,
|
||||
args.offline,
|
||||
event_index_counter,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -144,6 +147,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -205,6 +210,7 @@ class Node:
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
@@ -215,7 +221,7 @@ class Node:
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
offline=self.offline,
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -232,6 +238,7 @@ class Node:
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
@@ -253,9 +260,6 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
@@ -278,7 +282,6 @@ class Args(CamelCaseModel):
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -326,11 +329,6 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -17,7 +17,6 @@ from exo.shared.types.api import (
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -126,8 +125,6 @@ async def generate_chat_stream(
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
@@ -141,8 +138,6 @@ async def generate_chat_stream(
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
@@ -166,15 +161,12 @@ async def generate_chat_stream(
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -184,9 +176,7 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -194,7 +184,6 @@ async def collect_chat_response(
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -204,8 +193,6 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
@@ -236,7 +223,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -254,6 +241,4 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
).model_dump_json()
|
||||
return
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import FinishReason, Usage
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
@@ -161,14 +161,12 @@ async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -176,8 +174,6 @@ async def collect_claude_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
@@ -187,10 +183,12 @@ async def collect_claude_response(
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -210,11 +208,11 @@ async def collect_claude_response(
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data if available
|
||||
input_tokens = last_usage.prompt_tokens if last_usage else 0
|
||||
output_tokens = last_usage.completion_tokens if last_usage else 0
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
|
||||
yield ClaudeMessagesResponse(
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
@@ -223,8 +221,7 @@ async def collect_claude_response(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
).model_dump_json()
|
||||
return
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
@@ -252,7 +249,7 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
last_stats = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -260,9 +257,8 @@ async def generate_claude_stream(
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
@@ -294,6 +290,7 @@ async def generate_claude_stream(
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
@@ -305,9 +302,9 @@ async def generate_claude_stream(
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from usage if available
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
@@ -122,15 +121,13 @@ async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
) -> ResponsesResponse:
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -138,32 +135,32 @@ async def collect_responses_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=tool.id,
|
||||
call_id=tool.id,
|
||||
id=f"fc_{tool.id}",
|
||||
call_id=f"call_{tool.id}",
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from usage data if available
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_usage is not None:
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
@@ -175,15 +172,14 @@ async def collect_responses_response(
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
yield ResponsesResponse(
|
||||
return ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
).model_dump_json()
|
||||
return
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
@@ -239,16 +235,15 @@ async def generate_responses_stream(
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
last_stats = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{tool.id}"
|
||||
call_id = f"call_{tool.id}"
|
||||
@@ -307,6 +302,7 @@ async def generate_responses_stream(
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
@@ -350,13 +346,13 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from usage data if available
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_usage is not None:
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
|
||||
@@ -85,7 +85,6 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
ImageSize,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -101,7 +100,6 @@ from exo.shared.types.api import (
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
normalize_image_size,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -127,11 +125,10 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -183,7 +180,6 @@ class API:
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self._system_id = SystemId()
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
@@ -234,7 +230,6 @@ class API:
|
||||
self._event_log.close()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self.state = State()
|
||||
self._system_id = SystemId()
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._text_generation_queues = {}
|
||||
@@ -545,14 +540,16 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
@@ -647,14 +644,11 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
@@ -670,7 +664,8 @@ class API:
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
@@ -755,11 +750,9 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -890,11 +883,6 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -976,11 +964,6 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -1015,13 +998,12 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"stream": False,
|
||||
"partial_images": 0,
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1042,7 +1024,7 @@ class API:
|
||||
prompt: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: ImageSize,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
input_fidelity: Literal["low", "high"],
|
||||
stream: bool,
|
||||
@@ -1112,7 +1094,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str | None = Form(None),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
@@ -1138,7 +1120,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=normalize_image_size(size),
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=stream_bool,
|
||||
@@ -1174,7 +1156,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str | None = Form(None),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
@@ -1194,7 +1176,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=normalize_image_size(size),
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=False,
|
||||
@@ -1239,15 +1221,12 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
return await collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
@@ -1275,15 +1254,11 @@ class API:
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
return await collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
@@ -1410,7 +1385,7 @@ class API:
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
@@ -1474,12 +1449,12 @@ class API:
|
||||
while self.paused:
|
||||
await self.paused_ev.wait()
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def _send_download(self, command: DownloadCommand):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(origin=self._system_id, command=command)
|
||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def start_download(
|
||||
|
||||
@@ -24,12 +24,11 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
@@ -40,7 +39,6 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
@@ -90,8 +88,7 @@ class Master:
|
||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self._system_id = SystemId()
|
||||
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
@@ -282,14 +279,14 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self._system_id, command=cmd
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
@@ -302,7 +299,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -312,7 +309,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -322,18 +319,6 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -342,9 +327,10 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
@@ -416,7 +402,7 @@ class Master:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=self._system_id,
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
@@ -429,7 +415,7 @@ class Master:
|
||||
# Convenience method since this line is ugly
|
||||
await self.global_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=self._system_id,
|
||||
origin=self.node_id,
|
||||
origin_idx=event.idx,
|
||||
session=self.session_id,
|
||||
event=event.event,
|
||||
|
||||
@@ -22,15 +22,9 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
@@ -192,7 +186,6 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -208,18 +201,6 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
@@ -4,11 +4,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
ClaudeMessagesResponse,
|
||||
collect_claude_response,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
@@ -21,18 +17,6 @@ async def _chunks_to_stream(
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _collect_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Helper to consume the async generator and parse the JSON response."""
|
||||
parts: list[str] = []
|
||||
async for part in collect_claude_response(command_id, model, chunk_stream):
|
||||
parts.append(part)
|
||||
return ClaudeMessagesResponse.model_validate_json("".join(parts))
|
||||
|
||||
|
||||
MODEL = ModelId("test-model")
|
||||
COMMAND_ID = CommandId("cmd_test123")
|
||||
|
||||
@@ -63,7 +47,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await _collect_response(
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -93,7 +77,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await _collect_response(
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -118,7 +102,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await _collect_response(
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -132,7 +116,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
|
||||
async def test_no_content_produces_empty_text_block(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
|
||||
response = await _collect_response(
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
assert len(response.content) == 1
|
||||
|
||||
@@ -15,7 +15,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
@@ -75,12 +75,13 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=SystemId("Worker"),
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeGatheredInfo(
|
||||
@@ -107,7 +108,7 @@ async def test_master():
|
||||
logger.info("inject a CreateInstance Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=SystemId("API"),
|
||||
origin=node_id,
|
||||
command=(
|
||||
PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
@@ -132,7 +133,7 @@ async def test_master():
|
||||
logger.info("inject a TextGeneration Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=SystemId("API"),
|
||||
origin=node_id,
|
||||
command=(
|
||||
TextGeneration(
|
||||
command_id=CommandId(),
|
||||
|
||||
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
|
||||
target_instances = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 0
|
||||
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
|
||||
@@ -218,6 +218,11 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||
}
|
||||
# Clean up all granular node mappings
|
||||
node_identities = {
|
||||
key: value
|
||||
for key, value in state.node_identities.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_memory = {
|
||||
key: value for key, value in state.node_memory.items() if key != event.node_id
|
||||
}
|
||||
@@ -258,6 +263,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"downloads": downloads,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
"node_identities": node_identities,
|
||||
"node_memory": node_memory,
|
||||
"node_disk": node_disk,
|
||||
"node_system": node_system,
|
||||
|
||||
@@ -44,8 +44,7 @@ async def _refresh_card_cache():
|
||||
async for toml_file in path.rglob("*.toml"):
|
||||
try:
|
||||
card = await ModelCard.load_from_path(toml_file)
|
||||
if card.model_id not in _card_cache:
|
||||
_card_cache[card.model_id] = card
|
||||
_card_cache[card.model_id] = card
|
||||
except (ValidationError, TOMLKitError):
|
||||
pass
|
||||
|
||||
@@ -183,7 +182,6 @@ class ConfigData(BaseModel):
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["GlmMoeDsaForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
|
||||
@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import channel
|
||||
|
||||
# ======= #
|
||||
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
|
||||
# Pump local commands so our commands_seen is high before the round starts
|
||||
for _ in range(50):
|
||||
await co_tx.send(
|
||||
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
|
||||
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
|
||||
)
|
||||
|
||||
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal, get_args
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -227,6 +228,13 @@ class PlaceInstanceParams(BaseModel):
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
|
||||
@field_validator("sharding", "instance_meta", mode="plain")
|
||||
@classmethod
|
||||
def use_default(cls, v: object):
|
||||
if not v or not isinstance(v, (Sharding, InstanceMeta)):
|
||||
raise PydanticUseDefault()
|
||||
return v
|
||||
|
||||
|
||||
class CreateInstanceParams(BaseModel):
|
||||
instance: Instance
|
||||
@@ -262,27 +270,6 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
ImageSize = Literal[
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
]
|
||||
|
||||
|
||||
def normalize_image_size(v: object) -> ImageSize:
|
||||
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
|
||||
if v is None:
|
||||
return "auto"
|
||||
if v not in get_args(ImageSize):
|
||||
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
|
||||
return v # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
@@ -302,7 +289,7 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: ImageSize = "auto"
|
||||
size: str | None = "1024x1024"
|
||||
stream: bool | None = False
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
@@ -310,11 +297,6 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
|
||||
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
|
||||
bench: bool = True
|
||||
@@ -331,18 +313,13 @@ class ImageEditsTaskParams(BaseModel):
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: ImageSize = "auto"
|
||||
size: str | None = "1024x1024"
|
||||
image_strength: float | None = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
if name == "image_data":
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId, SystemId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -93,17 +89,16 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
|
||||
class ForwarderCommand(CamelCaseModel):
|
||||
origin: SystemId
|
||||
origin: NodeId
|
||||
command: Command
|
||||
|
||||
|
||||
class ForwarderDownloadCommand(CamelCaseModel):
|
||||
origin: SystemId
|
||||
origin: NodeId
|
||||
command: DownloadCommand
|
||||
|
||||
@@ -25,10 +25,6 @@ class NodeId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class SystemId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class ModelId(Id):
|
||||
def normalize(self) -> str:
|
||||
return self.replace("/", "--")
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -166,6 +166,6 @@ class ForwarderEvent(CamelCaseModel):
|
||||
"""An event the forwarder will serialize and send over the network"""
|
||||
|
||||
origin_idx: int = Field(ge=0)
|
||||
origin: SystemId
|
||||
origin: NodeId
|
||||
session: SessionId
|
||||
event: Event
|
||||
|
||||
@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -61,11 +60,6 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -93,7 +87,6 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -26,7 +26,6 @@ class DownloadProgressData(CamelCaseModel):
|
||||
class BaseDownloadProgress(TaggedModel):
|
||||
node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
model_directory: str = ""
|
||||
|
||||
|
||||
class DownloadPending(BaseDownloadProgress):
|
||||
|
||||
@@ -62,7 +62,6 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import sys
|
||||
|
||||
|
||||
def print_startup_banner(port: int) -> None:
|
||||
"""Print a prominent startup banner with API endpoint information."""
|
||||
dashboard_url = f"http://localhost:{port}"
|
||||
banner = f"""
|
||||
╔═══════════════════════════════════════════════════════════════════════╗
|
||||
@@ -29,4 +27,4 @@ def print_startup_banner(port: int) -> None:
|
||||
|
||||
"""
|
||||
|
||||
print(banner, file=sys.stderr)
|
||||
print(banner)
|
||||
|
||||
@@ -125,9 +125,7 @@ class MpSender[T]:
|
||||
self._state.buffer.put(item, block=True)
|
||||
|
||||
async def send_async(self, item: T) -> None:
|
||||
await to_thread.run_sync(
|
||||
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -14,7 +14,6 @@ from exo.shared.types.api import (
|
||||
ImageEditsTaskParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ImageSize,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
@@ -24,9 +23,9 @@ from exo.shared.types.worker.runner_response import (
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
def parse_size(size_str: ImageSize) -> tuple[int, int]:
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if size_str == "auto":
|
||||
if not size_str:
|
||||
return (1024, 1024)
|
||||
|
||||
try:
|
||||
@@ -110,9 +109,6 @@ def generate_image(
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
if task.size == "auto":
|
||||
with Image.open(image_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
|
||||
@@ -163,14 +163,11 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
|
||||
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
if self.is_prefill:
|
||||
mx.eval(output)
|
||||
if cache is not None:
|
||||
mx.eval(_cache.keys) # type: ignore
|
||||
mx.eval(cache.keys) # type: ignore
|
||||
|
||||
if not self.is_prefill:
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
@@ -310,9 +307,7 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
last = cache[-1] # type: ignore
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
return logits
|
||||
|
||||
@@ -338,9 +333,7 @@ def patch_tensor_model[T](model: T) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
|
||||
last = cache[-1] # pyright: ignore[reportAny]
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
|
||||
return logits
|
||||
|
||||
@@ -554,12 +547,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
@@ -590,18 +581,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -794,7 +779,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
@@ -907,7 +893,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE.
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
@@ -15,3 +15,8 @@ DEFAULT_TOP_LOGPROBS: int = 5
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
|
||||
# MTP enables speculative decoding using the model's built-in draft layer
|
||||
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
|
||||
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generator, cast, get_args
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
@@ -38,6 +41,8 @@ from exo.worker.engines.mlx.constants import (
|
||||
KV_BITS,
|
||||
KV_GROUP_SIZE,
|
||||
MAX_TOKENS,
|
||||
MTP_ENABLED,
|
||||
MTP_NUM_DRAFT_TOKENS,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
@@ -57,7 +62,6 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -87,9 +91,6 @@ def prefill(
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
@@ -196,6 +197,11 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def _has_mtp_module(model: Model) -> bool:
|
||||
"""Check if the model has an attached MTP module."""
|
||||
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -301,6 +307,24 @@ def mlx_generate(
|
||||
top_k=task.top_k if task.top_k is not None else 0,
|
||||
)
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
|
||||
# Check if we should use MTP speculative decoding
|
||||
use_mtp = MTP_ENABLED and _has_mtp_module(model)
|
||||
|
||||
if use_mtp:
|
||||
logger.info("Using MTP speculative decoding")
|
||||
yield from _mlx_generate_with_mtp(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
)
|
||||
return
|
||||
|
||||
# Normalize stop sequences to a list
|
||||
stop_sequences: list[str] = (
|
||||
([task.stop] if isinstance(task.stop, str) else task.stop)
|
||||
@@ -309,16 +333,22 @@ def mlx_generate(
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Ready to prefill")
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-2:]
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
@@ -328,7 +358,6 @@ def mlx_generate(
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
|
||||
logger.info("Starting decode")
|
||||
mx_barrier(group)
|
||||
|
||||
for completion_tokens, out in enumerate(
|
||||
@@ -391,11 +420,10 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
total_prompt_tokens = len(all_prompt_tokens)
|
||||
usage = Usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_prompt_tokens + completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
@@ -469,3 +497,66 @@ def mlx_generate(
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
|
||||
def _mlx_generate_with_mtp(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
|
||||
prompt_cache: KVCacheType,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""MTP speculative decoding generation path.
|
||||
|
||||
Uses the model's attached MTP module for speculative decoding,
|
||||
which can provide 1.5-2x speedup with ~81% acceptance rate.
|
||||
"""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
mtp_module = cast("MTPModule", model.mtp_module)
|
||||
|
||||
for out in mtp_speculative_generate(
|
||||
model=model,
|
||||
mtp_module=mtp_module,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=cast(list[Any], prompt_cache),
|
||||
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(f"{out.text} (from_draft={out.from_draft})")
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
6
src/exo/worker/engines/mlx/mtp/__init__.py
Normal file
6
src/exo/worker/engines/mlx/mtp/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
__all__ = ["MTPModule", "mtp_speculative_generate"]
|
||||
207
src/exo/worker/engines/mlx/mtp/module.py
Normal file
207
src/exo/worker/engines/mlx/mtp/module.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
|
||||
|
||||
The MTP architecture predicts one additional token ahead using:
|
||||
1. hnorm - RMSNorm for hidden state normalization
|
||||
2. enorm - RMSNorm for embedding normalization
|
||||
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
|
||||
4. transformer_block - Single decoder layer (attention + MLP)
|
||||
5. Shared embedding/lm_head from main model
|
||||
|
||||
Forward pass:
|
||||
h_norm = hnorm(hidden_state)
|
||||
e_norm = enorm(embed(token))
|
||||
projected = eh_proj(concat([h_norm, e_norm]))
|
||||
new_hidden = transformer_block(projected)
|
||||
logits = lm_head(output_norm(new_hidden))
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
DeepseekV3Attention,
|
||||
DeepseekV3MLP,
|
||||
ModelArgs,
|
||||
)
|
||||
|
||||
MTP_LAYER_INDEX = 61
|
||||
|
||||
|
||||
class MTPModule(nn.Module):
|
||||
"""Multi-Token Prediction module for DeepSeek V3.
|
||||
|
||||
This module is initialized from the layer 61 weights that are normally
|
||||
discarded during model loading. It enables speculative decoding by
|
||||
predicting one token ahead using the hidden state from the main model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
shared_embedding: nn.Embedding,
|
||||
shared_lm_head: nn.Linear,
|
||||
output_norm: nn.RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# MTP-specific normalization layers
|
||||
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Projection: concatenated [hidden, embedding] -> hidden_size
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
# Single transformer block for MTP
|
||||
# Use a dense MLP since this is just a single layer
|
||||
self.transformer_block = MTPTransformerBlock(config)
|
||||
|
||||
# Share embedding and lm_head with main model
|
||||
self._shared_embedding = shared_embedding
|
||||
self._shared_lm_head = shared_lm_head
|
||||
self._output_norm = output_norm
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
cache: KVCache | None = None,
|
||||
mask: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass for MTP.
|
||||
|
||||
Args:
|
||||
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
|
||||
draft_token: Token to embed and combine with hidden state [batch, seq_len]
|
||||
cache: Optional KV cache for the MTP transformer block
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
tuple of (logits, new_hidden_state)
|
||||
"""
|
||||
# Get embedding of draft token
|
||||
embedding = self._shared_embedding(draft_token)
|
||||
|
||||
# Normalize hidden state and embedding
|
||||
h_norm = self.hnorm(hidden_state)
|
||||
e_norm = self.enorm(embedding)
|
||||
|
||||
# Project concatenated representation
|
||||
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
|
||||
projected = self.eh_proj(concatenated)
|
||||
|
||||
# Pass through single transformer block
|
||||
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
|
||||
|
||||
# Apply output norm and get logits
|
||||
normed_hidden = self._output_norm(new_hidden)
|
||||
logits = self._shared_lm_head(normed_hidden)
|
||||
|
||||
return logits, new_hidden
|
||||
|
||||
|
||||
class MTPTransformerBlock(nn.Module):
|
||||
"""Single transformer block for MTP.
|
||||
|
||||
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
|
||||
instead of MoE since this is just for the single MTP layer.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
# MTP uses dense MLP, not MoE
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass with residual connections."""
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
|
||||
"""Extract MTP-specific weights from layer 61.
|
||||
|
||||
The MTP layer has these weight patterns:
|
||||
- model.layers.61.enorm.weight -> MTP embedding normalization
|
||||
- model.layers.61.hnorm.weight -> MTP hidden normalization
|
||||
- model.layers.61.eh_proj.weight -> MTP projection layer
|
||||
- model.layers.61.self_attn.* -> MTP attention
|
||||
- model.layers.61.input_layernorm.* -> MTP layer norms
|
||||
- model.layers.61.post_attention_layernorm.*
|
||||
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
|
||||
|
||||
Args:
|
||||
weights: Full model weights dict
|
||||
|
||||
Returns:
|
||||
Dict of MTP-specific weights with keys renamed for MTPModule
|
||||
"""
|
||||
mtp_weights: dict[str, mx.array] = {}
|
||||
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mtp_prefix):
|
||||
# Remove the layer prefix to get relative path
|
||||
new_key = key[len(mtp_prefix) :]
|
||||
mtp_weights[new_key] = value
|
||||
|
||||
return mtp_weights
|
||||
|
||||
|
||||
def load_mtp_weights_into_module(
|
||||
mtp_module: MTPModule,
|
||||
mtp_weights: dict[str, mx.array],
|
||||
) -> None:
|
||||
"""Load extracted MTP weights into the MTPModule.
|
||||
|
||||
Args:
|
||||
mtp_module: The MTPModule instance to load weights into
|
||||
mtp_weights: Extracted MTP weights from extract_mtp_weights()
|
||||
"""
|
||||
# Map weight names to module attributes
|
||||
weight_mapping: dict[str, str] = {
|
||||
"enorm.weight": "enorm.weight",
|
||||
"hnorm.weight": "hnorm.weight",
|
||||
"eh_proj.weight": "eh_proj.weight",
|
||||
}
|
||||
|
||||
# Load direct mappings
|
||||
for src_name, dst_name in weight_mapping.items():
|
||||
if src_name in mtp_weights:
|
||||
parts = dst_name.split(".")
|
||||
obj: Any = mtp_module
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part) # pyright: ignore[reportAny]
|
||||
setattr(obj, parts[-1], mtp_weights[src_name]) # pyright: ignore[reportAny]
|
||||
|
||||
# Load transformer block weights (self_attn, mlp, layer norms)
|
||||
transformer_prefixes = [
|
||||
"self_attn",
|
||||
"mlp",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
]
|
||||
|
||||
for prefix in transformer_prefixes:
|
||||
for key, value in mtp_weights.items():
|
||||
if key.startswith(prefix):
|
||||
# Navigate to the correct attribute
|
||||
parts = key.split(".")
|
||||
obj = mtp_module.transformer_block
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part) # pyright: ignore[reportAny]
|
||||
setattr(obj, parts[-1], value) # pyright: ignore[reportAny]
|
||||
483
src/exo/worker/engines/mlx/mtp/speculative_decode.py
Normal file
483
src/exo/worker/engines/mlx/mtp/speculative_decode.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""MTP Speculative Decoding for DeepSeek V3.
|
||||
|
||||
This module implements speculative decoding using the Multi-Token Prediction (MTP)
|
||||
layer from DeepSeek V3. The key difference from standard speculative decoding is
|
||||
that MTP requires hidden states from the main model, not just token predictions.
|
||||
|
||||
Based on vLLM/SGLang research:
|
||||
- 81-82% acceptance rate with k=1
|
||||
- 1.5-2x speedup at low QPS
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models import cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Generation stream for async operations
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPGenerationResponse:
|
||||
"""Response from MTP speculative generation.
|
||||
|
||||
Attributes:
|
||||
text: The next segment of decoded text.
|
||||
token: The next token.
|
||||
logprobs: A vector of log probabilities.
|
||||
from_draft: Whether the token was generated by the MTP draft module.
|
||||
prompt_tokens: The number of tokens in the prompt.
|
||||
prompt_tps: The prompt processing tokens-per-second.
|
||||
generation_tokens: The number of generated tokens.
|
||||
generation_tps: The tokens-per-second for generation.
|
||||
peak_memory: The peak memory used so far in GB.
|
||||
finish_reason: The reason the response is being sent: "length", "stop" or None.
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
"""Quantize KV cache entries if needed."""
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache): # pyright: ignore[reportAny]
|
||||
if (
|
||||
hasattr(c, "to_quantized") # pyright: ignore[reportAny]
|
||||
and hasattr(c, "offset") # pyright: ignore[reportAny]
|
||||
and c.offset >= quantized_kv_start # pyright: ignore[reportAny]
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
class ModelWithHiddenStates(nn.Module):
|
||||
"""Wrapper to extract hidden states before lm_head.
|
||||
|
||||
This wrapper allows capturing the hidden states from the transformer
|
||||
layers before the final lm_head projection, which is needed for MTP.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self._base = base_model
|
||||
|
||||
def forward_with_hidden(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass that returns both logits and hidden states.
|
||||
|
||||
Args:
|
||||
inputs: Input token ids
|
||||
model_cache: KV cache
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, hidden_states)
|
||||
"""
|
||||
# Call the inner model (transformer layers + norm)
|
||||
hidden = cast(mx.array, self._base.model(inputs, model_cache)) # pyright: ignore[reportUnknownMemberType]
|
||||
# Get logits from lm_head
|
||||
logits = cast(mx.array, self._base.lm_head(hidden)) # pyright: ignore[reportUnknownMemberType]
|
||||
return logits, hidden
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> mx.array:
|
||||
"""Standard forward pass returning only logits."""
|
||||
return cast(mx.array, self._base(inputs, cache=model_cache))
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
"""Access layers for cache creation."""
|
||||
return cast(list[nn.Module], self._base.layers)
|
||||
|
||||
|
||||
def mtp_speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
*,
|
||||
num_draft_tokens: int = 1,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
mtp_cache: KVCache | None = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[int, mx.array, bool], None, None]:
|
||||
"""MTP speculative decoding generator.
|
||||
|
||||
Unlike standard speculative decoding where the draft model only needs tokens,
|
||||
MTP requires the hidden states from the main model. This generator:
|
||||
|
||||
1. Runs the main model to get logits AND hidden states
|
||||
2. Uses MTP module with hidden state + sampled token to predict next token
|
||||
3. Verifies MTP predictions with the main model
|
||||
4. Accepts/rejects based on matching
|
||||
|
||||
Args:
|
||||
prompt: The input prompt as token ids
|
||||
model: The main model (must support return_hidden=True)
|
||||
mtp_module: The MTP module for draft prediction
|
||||
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
sampler: Optional sampler function for token selection
|
||||
logits_processors: Optional list of logits processors
|
||||
prompt_cache: KV cache for the main model
|
||||
mtp_cache: KV cache for the MTP module
|
||||
prefill_step_size: Step size for prompt processing
|
||||
kv_bits: Bits for KV cache quantization
|
||||
kv_group_size: Group size for KV cache quantization
|
||||
quantized_kv_start: Step to begin cache quantization
|
||||
|
||||
Yields:
|
||||
Tuple of (token, logprobs, from_draft)
|
||||
"""
|
||||
y = prompt.astype(mx.uint32)
|
||||
prev_tokens: mx.array | None = None
|
||||
|
||||
# Wrap model to get hidden states
|
||||
wrapped_model = (
|
||||
model
|
||||
if isinstance(model, ModelWithHiddenStates)
|
||||
else ModelWithHiddenStates(model)
|
||||
)
|
||||
|
||||
# Create caches if needed
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
if mtp_cache is None:
|
||||
mtp_cache = KVCache()
|
||||
|
||||
final_sampler: Callable[[mx.array], mx.array] = (
|
||||
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
|
||||
)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _process_and_sample(
|
||||
tokens: mx.array | None,
|
||||
logits: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Process logits and sample tokens."""
|
||||
nonlocal logits_processors
|
||||
processed_logits = logits
|
||||
if logits_processors:
|
||||
for processor in logits_processors:
|
||||
processed_logits = processor(
|
||||
tokens if tokens is not None else mx.array([]), processed_logits
|
||||
)
|
||||
|
||||
logprobs = processed_logits - mx.logsumexp(
|
||||
processed_logits, axis=-1, keepdims=True
|
||||
)
|
||||
sampled = final_sampler(logprobs)
|
||||
return sampled, logprobs
|
||||
|
||||
def _main_model_step_with_hidden(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array, mx.array]:
|
||||
"""Run main model step with hidden state return."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits, hidden = wrapped_model.forward_with_hidden(
|
||||
input_y[None], prompt_cache
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
|
||||
|
||||
def _mtp_draft(
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Generate draft token using MTP module."""
|
||||
with mx.stream(generation_stream):
|
||||
logits, new_hidden = mtp_module(
|
||||
hidden_state,
|
||||
draft_token,
|
||||
cache=mtp_cache,
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
sampled, _ = _process_and_sample(None, logits)
|
||||
return sampled, new_hidden
|
||||
|
||||
def _prefill(input_y: mx.array) -> mx.array:
|
||||
"""Prefill the prompt cache."""
|
||||
result_y = input_y
|
||||
while result_y.size > prefill_step_size:
|
||||
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache]) # pyright: ignore[reportAny]
|
||||
result_y = result_y[prefill_step_size:]
|
||||
mx.clear_cache()
|
||||
return result_y
|
||||
|
||||
def _rewind_cache(num_draft: int, num_accept: int) -> None:
|
||||
"""Rewind caches after rejection."""
|
||||
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
|
||||
|
||||
# Prefill phase
|
||||
with mx.stream(generation_stream):
|
||||
y = _prefill(y)
|
||||
|
||||
ntoks = 0
|
||||
num_draft = 0
|
||||
n_accepted = 0
|
||||
last_hidden: mx.array | None = None
|
||||
|
||||
try:
|
||||
# Initial step to get first token and hidden state
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
while ntoks < max_tokens:
|
||||
# Draft phase: use MTP to predict next token
|
||||
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
|
||||
|
||||
if num_draft > 0 and last_hidden is not None: # pyright: ignore[reportUnnecessaryComparison]
|
||||
# Use MTP to draft
|
||||
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
|
||||
mx.eval(draft_token, draft_hidden)
|
||||
|
||||
# Verify with main model
|
||||
# Feed the drafted token to main model
|
||||
verify_input = mx.concatenate([y, draft_token.flatten()])
|
||||
verify_sampled, verify_logprobs, new_hidden = (
|
||||
_main_model_step_with_hidden(verify_input)
|
||||
)
|
||||
mx.eval(verify_sampled, verify_logprobs, new_hidden)
|
||||
|
||||
# Check if draft matches verification
|
||||
draft_token_val = int(draft_token.item())
|
||||
verify_token_val = (
|
||||
int(verify_sampled[0].item())
|
||||
if verify_sampled.shape[0] > 1
|
||||
else int(verify_sampled.item())
|
||||
)
|
||||
|
||||
# Yield the current token (not from draft)
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
if draft_token_val == verify_token_val:
|
||||
# Draft accepted
|
||||
n_accepted += 1
|
||||
ntoks += 1
|
||||
draft_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
yield draft_token_val, draft_logprobs, True
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
# Continue with the token after the draft
|
||||
y = (
|
||||
verify_sampled[-1:]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[-1]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden
|
||||
else:
|
||||
# Draft rejected - rewind and use verified token
|
||||
_rewind_cache(1, 0)
|
||||
y = (
|
||||
verify_sampled[:1]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden[:, :1, :]
|
||||
else:
|
||||
# No drafting, just do normal generation
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
if ntoks % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
finally:
|
||||
_rewind_cache(num_draft, n_accepted)
|
||||
|
||||
|
||||
def mtp_speculative_generate(
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
num_draft_tokens: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int | None = None,
|
||||
) -> Generator[MTPGenerationResponse, None, None]:
|
||||
"""High-level MTP speculative generation with text output.
|
||||
|
||||
Args:
|
||||
model: The main model
|
||||
mtp_module: The MTP module for draft prediction
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
prompt: Input prompt (string, array, or token list)
|
||||
max_tokens: Maximum tokens to generate
|
||||
sampler: Optional sampler function
|
||||
logits_processors: Optional logits processors
|
||||
prompt_cache: Optional KV cache
|
||||
num_draft_tokens: Number of draft tokens
|
||||
prefill_step_size: Prefill step size
|
||||
kv_group_size: KV group size
|
||||
kv_bits: KV bits
|
||||
|
||||
Yields:
|
||||
MTPGenerationResponse objects with text and metadata
|
||||
"""
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
bos_token: str | None = getattr(tokenizer, "bos_token", None)
|
||||
add_special_tokens = bos_token is None or not prompt.startswith(
|
||||
str(bos_token)
|
||||
)
|
||||
encoded: list[int] = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
prompt = mx.array(encoded)
|
||||
else:
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer: Any = tokenizer.detokenizer
|
||||
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
|
||||
|
||||
token_generator = mtp_speculative_generate_step(
|
||||
prompt,
|
||||
model,
|
||||
mtp_module,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_tps = 0.0
|
||||
token = 0
|
||||
logprobs: mx.array = mx.array([0.0])
|
||||
from_draft = False
|
||||
n = 0
|
||||
|
||||
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = float(prompt.size) / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token) # pyright: ignore[reportAny]
|
||||
if (n + 1) == max_tokens:
|
||||
break
|
||||
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment), # pyright: ignore[reportAny]
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
detokenizer.finalize() # pyright: ignore[reportAny]
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment), # pyright: ignore[reportAny]
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop" if token in eos_token_ids else "length",
|
||||
)
|
||||
1
src/exo/worker/engines/mlx/mtp/tests/__init__.py
Normal file
1
src/exo/worker/engines/mlx/mtp/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for MTP module."""
|
||||
426
src/exo/worker/engines/mlx/mtp/tests/test_mtp_module.py
Normal file
426
src/exo/worker/engines/mlx/mtp/tests/test_mtp_module.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Unit tests for MTP module components."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTP_LAYER_INDEX,
|
||||
MTPModule,
|
||||
MTPTransformerBlock,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
|
||||
class MockModelArgs:
|
||||
"""Mock ModelArgs for testing without importing deepseek_v3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
vocab_size: int = 1000,
|
||||
q_lora_rank: int | None = None,
|
||||
kv_lora_rank: int = 64,
|
||||
qk_rope_head_dim: int = 16,
|
||||
v_head_dim: int = 32,
|
||||
qk_nope_head_dim: int = 32,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
attention_bias: bool = False,
|
||||
max_position_embeddings: int = 2048,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
|
||||
class TestExtractMTPWeights:
|
||||
"""Tests for extract_mtp_weights function."""
|
||||
|
||||
def test_extracts_layer_61_weights(self) -> None:
|
||||
"""Should extract only layer 61 weights."""
|
||||
weights = {
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.61.enorm.weight": mx.ones((10,)),
|
||||
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
|
||||
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
|
||||
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.embed_tokens.weight": mx.zeros((100, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 3
|
||||
assert "enorm.weight" in mtp_weights
|
||||
assert "hnorm.weight" in mtp_weights
|
||||
assert "eh_proj.weight" in mtp_weights
|
||||
# Check values are preserved
|
||||
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
|
||||
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
|
||||
|
||||
def test_returns_empty_dict_when_no_layer_61(self) -> None:
|
||||
"""Should return empty dict when layer 61 doesn't exist."""
|
||||
weights = {
|
||||
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 0
|
||||
|
||||
def test_handles_nested_layer_61_weights(self) -> None:
|
||||
"""Should handle nested weight paths like self_attn.q_proj.weight."""
|
||||
weights = {
|
||||
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
|
||||
(10, 10)
|
||||
),
|
||||
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert "self_attn.q_a_proj.weight" in mtp_weights
|
||||
assert "mlp.gate_proj.weight" in mtp_weights
|
||||
|
||||
|
||||
class TestMTPTransformerBlock:
|
||||
"""Tests for MTPTransformerBlock."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64, intermediate_size=128, num_attention_heads=2
|
||||
)
|
||||
|
||||
def test_forward_shape(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should preserve input shape."""
|
||||
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
|
||||
output = block(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_forward_with_mask(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should work with attention mask."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
# Create causal mask
|
||||
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
|
||||
|
||||
output = block(x, mask=mask)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestMTPModule:
|
||||
"""Tests for MTPModule."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def shared_components(
|
||||
self, config: MockModelArgs
|
||||
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
return embedding, lm_head, output_norm
|
||||
|
||||
def test_initialization(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should initialize with correct components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
assert mtp.hnorm is not None
|
||||
assert mtp.enorm is not None
|
||||
assert mtp.eh_proj is not None
|
||||
assert mtp.transformer_block is not None
|
||||
|
||||
def test_forward_output_shapes(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""Forward pass should return correct output shapes."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1
|
||||
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
|
||||
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
|
||||
|
||||
logits, new_hidden = mtp(hidden_state, draft_token)
|
||||
|
||||
assert logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_shares_embedding_and_lm_head(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should use shared embedding and lm_head."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Verify they're the same objects
|
||||
assert mtp._shared_embedding is embedding # pyright: ignore[reportPrivateUsage]
|
||||
assert mtp._shared_lm_head is lm_head # pyright: ignore[reportPrivateUsage]
|
||||
assert mtp._output_norm is output_norm # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class TestLoadMTPWeights:
|
||||
"""Tests for load_mtp_weights_into_module."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
|
||||
"""Should load enorm and hnorm weights."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Create test weights
|
||||
test_enorm = mx.ones((config.hidden_size,)) * 3.0
|
||||
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
|
||||
mtp_weights = {
|
||||
"enorm.weight": test_enorm,
|
||||
"hnorm.weight": test_hnorm,
|
||||
}
|
||||
|
||||
load_mtp_weights_into_module(mtp, mtp_weights)
|
||||
|
||||
assert mx.allclose(mtp.enorm.weight, test_enorm) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
assert mx.allclose(mtp.hnorm.weight, test_hnorm) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
|
||||
class TestSanitizePatch:
|
||||
"""Tests for the sanitize patching logic."""
|
||||
|
||||
def test_patch_preserves_layer_61(self) -> None:
|
||||
"""Patching sanitize should preserve layer 61 weights."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
# Get original sanitize behavior
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
try:
|
||||
# Apply patch
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
# Note: we can't easily test the full sanitize without a real model
|
||||
# This test verifies the patch is applied
|
||||
assert model_cls.sanitize is not original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
# Verify restore worked
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
def test_restore_sanitize(self) -> None:
|
||||
"""Restoring sanitize should return to original behavior."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is not original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
def test_double_patch_is_safe(self) -> None:
|
||||
"""Calling patch twice should be safe (idempotent)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
try:
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
patched_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
# Patch again - should be no-op
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is patched_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
class TestModelIdDetection:
|
||||
"""Tests for DeepSeek V3 model ID detection."""
|
||||
|
||||
def test_detects_deepseek_v3(self) -> None:
|
||||
"""Should detect DeepSeek V3 model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
|
||||
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
|
||||
|
||||
def test_detects_deepseek_r1(self) -> None:
|
||||
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
|
||||
|
||||
def test_rejects_non_deepseek(self) -> None:
|
||||
"""Should reject non-DeepSeek model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
|
||||
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
|
||||
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Detection should be case insensitive."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
|
||||
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
|
||||
|
||||
|
||||
class TestFlattenParams:
|
||||
"""Tests for parameter flattening utility."""
|
||||
|
||||
def test_flattens_nested_dict(self) -> None:
|
||||
"""Should flatten nested parameter dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_flatten_params, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
params = {
|
||||
"model": {
|
||||
"layers": {
|
||||
"0": {
|
||||
"weight": mx.zeros((10,)),
|
||||
}
|
||||
},
|
||||
"embed": mx.ones((5,)),
|
||||
}
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert "model.layers.0.weight" in flat
|
||||
assert "model.embed" in flat
|
||||
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
|
||||
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
|
||||
|
||||
def test_handles_flat_dict(self) -> None:
|
||||
"""Should handle already-flat dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_flatten_params, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
params = {
|
||||
"weight": mx.zeros((10,)),
|
||||
"bias": mx.ones((10,)),
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert flat == params
|
||||
260
src/exo/worker/engines/mlx/mtp/tests/test_speculative_decode.py
Normal file
260
src/exo/worker/engines/mlx/mtp/tests/test_speculative_decode.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Unit tests for MTP speculative decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import (
|
||||
ModelWithHiddenStates,
|
||||
maybe_quantize_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
class MockModel(nn.Module):
|
||||
"""Mock model for testing speculative decoding."""
|
||||
|
||||
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Create simple model components
|
||||
self.model = MockInnerModel(hidden_size)
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list[object] | None = None,
|
||||
) -> mx.array:
|
||||
hidden = self.model(inputs, cache)
|
||||
return self.lm_head(hidden)
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Linear]:
|
||||
return self._layers
|
||||
|
||||
|
||||
class MockInnerModel(nn.Module):
|
||||
"""Mock inner model (like DeepseekV3Model)."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(100, hidden_size)
|
||||
self.norm = nn.RMSNorm(hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list[object] | None = None,
|
||||
) -> mx.array:
|
||||
# Simple embedding + norm
|
||||
embedded = self.embed_tokens(inputs)
|
||||
return self.norm(embedded)
|
||||
|
||||
|
||||
class TestModelWithHiddenStates:
|
||||
"""Tests for ModelWithHiddenStates wrapper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel(hidden_size=64, vocab_size=100)
|
||||
|
||||
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
|
||||
"""Standard forward should return logits."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits = wrapped.forward(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
|
||||
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
|
||||
"""Forward with hidden should return (logits, hidden)."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits, hidden = wrapped.forward_with_hidden(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
assert hidden.shape == (1, 3, mock_model.hidden_size)
|
||||
|
||||
def test_layers_property(self, mock_model: MockModel) -> None:
|
||||
"""Should expose layers property from base model."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
|
||||
assert wrapped.layers == mock_model.layers
|
||||
assert len(wrapped.layers) == 3
|
||||
|
||||
|
||||
class TestMaybeQuantizeKVCache:
|
||||
"""Tests for KV cache quantization."""
|
||||
|
||||
def test_no_quantization_when_bits_none(self) -> None:
|
||||
"""Should not quantize when kv_bits is None."""
|
||||
cache = [MockCache(offset=100)]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
cache,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=None,
|
||||
)
|
||||
|
||||
# Cache should be unchanged
|
||||
assert not hasattr(cache[0], "quantized")
|
||||
|
||||
def test_respects_quantized_kv_start(self) -> None:
|
||||
"""Should only quantize caches past the start threshold."""
|
||||
cache_below = MockCache(offset=30)
|
||||
cache_above = MockCache(offset=100)
|
||||
caches = [cache_below, cache_above]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
caches,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=4,
|
||||
)
|
||||
|
||||
# Only cache_above should be quantized
|
||||
assert not getattr(cache_below, "was_quantized", False)
|
||||
assert getattr(caches[1], "was_quantized", False)
|
||||
|
||||
|
||||
class MockCache:
|
||||
"""Mock KV cache for testing."""
|
||||
|
||||
def __init__(self, offset: int = 0) -> None:
|
||||
self.offset = offset
|
||||
self.was_quantized = False
|
||||
|
||||
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
|
||||
quantized = MockCache(self.offset)
|
||||
quantized.was_quantized = True
|
||||
return quantized
|
||||
|
||||
|
||||
class TestSpeculativeDecodingLogic:
|
||||
"""Tests for the core speculative decoding logic."""
|
||||
|
||||
def test_draft_acceptance_identical_tokens(self) -> None:
|
||||
"""When draft matches verification, both should be accepted."""
|
||||
# This tests the logic, not the full generator
|
||||
draft_token = 42
|
||||
verify_token = 42
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert accepted
|
||||
|
||||
def test_draft_rejection_different_tokens(self) -> None:
|
||||
"""When draft differs from verification, draft should be rejected."""
|
||||
draft_token = 42
|
||||
verify_token = 99
|
||||
|
||||
accepted = int(draft_token) == int(verify_token)
|
||||
assert not accepted
|
||||
|
||||
|
||||
class TestMTPGenerationResponse:
|
||||
"""Tests for MTPGenerationResponse dataclass."""
|
||||
|
||||
def test_response_creation(self) -> None:
|
||||
"""Should create response with all fields."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="Hello",
|
||||
token=42,
|
||||
logprobs=mx.array([0.1, 0.2]),
|
||||
from_draft=True,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=5,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert response.text == "Hello"
|
||||
assert response.token == 42
|
||||
assert response.from_draft is True
|
||||
assert response.finish_reason is None
|
||||
|
||||
def test_response_with_finish_reason(self) -> None:
|
||||
"""Should handle finish_reason."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="",
|
||||
token=0,
|
||||
logprobs=mx.array([0.0]),
|
||||
from_draft=False,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=100,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason="length",
|
||||
)
|
||||
|
||||
assert response.finish_reason == "length"
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full MTP pipeline."""
|
||||
|
||||
def test_mtp_module_with_mock_model(self) -> None:
|
||||
"""Test MTP module can be created and run with mock components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Create mock config
|
||||
class MockConfig:
|
||||
hidden_size = 64
|
||||
intermediate_size = 128
|
||||
num_attention_heads = 2
|
||||
num_key_value_heads = 2
|
||||
rms_norm_eps = 1e-6
|
||||
q_lora_rank = None
|
||||
kv_lora_rank = 32
|
||||
qk_rope_head_dim = 8
|
||||
v_head_dim = 16
|
||||
qk_nope_head_dim = 16
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
attention_bias = False
|
||||
max_position_embeddings = 2048
|
||||
|
||||
config = MockConfig()
|
||||
embedding = nn.Embedding(100, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
hidden = mx.random.normal((1, 1, config.hidden_size))
|
||||
token = mx.array([[5]])
|
||||
|
||||
logits, new_hidden = mtp(hidden, token)
|
||||
|
||||
assert logits.shape == (1, 1, 100)
|
||||
assert new_hidden.shape == (1, 1, config.hidden_size)
|
||||
# Verify outputs are valid (not NaN)
|
||||
# mx.isnan forces GPU evaluation which may fail on CI runners
|
||||
# with incompatible Metal toolchains (e.g. missing steel_gemm kernels)
|
||||
try:
|
||||
assert not mx.any(mx.isnan(logits))
|
||||
assert not mx.any(mx.isnan(new_hidden))
|
||||
except RuntimeError as e:
|
||||
if "Unable to load kernel" in str(e):
|
||||
pytest.skip(f"Metal kernel not available on this device: {e}")
|
||||
raise
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -23,6 +24,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
MTP_ENABLED,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -64,6 +66,144 @@ from exo.worker.runner.bootstrap import logger
|
||||
Group = mx.distributed.Group
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
# MTP (Multi-Token Prediction) support for DeepSeek V3
|
||||
MTP_LAYER_INDEX = 61
|
||||
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
|
||||
|
||||
|
||||
def _is_deepseek_v3_model(model: nn.Module) -> bool:
|
||||
"""Check if the model is DeepSeek V3."""
|
||||
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def _might_be_deepseek_v3(model_id: str) -> bool:
|
||||
"""Check if model ID suggests this might be DeepSeek V3."""
|
||||
model_id_lower = model_id.lower()
|
||||
return "deepseek" in model_id_lower and (
|
||||
"v3" in model_id_lower or "r1" in model_id_lower
|
||||
)
|
||||
|
||||
|
||||
def _patch_deepseek_sanitize_for_mtp() -> None:
|
||||
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights."""
|
||||
global _original_deepseek_sanitize
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
if _original_deepseek_sanitize is not None:
|
||||
return
|
||||
|
||||
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
|
||||
|
||||
def sanitize_with_mtp(
|
||||
self: DeepSeekV3Model, weights: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if _original_deepseek_sanitize is None:
|
||||
raise RuntimeError(
|
||||
"_original_deepseek_sanitize is None - patch not applied correctly"
|
||||
)
|
||||
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
|
||||
mtp_weights = {
|
||||
k: v
|
||||
for k, v in weights.items() # pyright: ignore[reportAny]
|
||||
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
|
||||
}
|
||||
return {**original_result, **mtp_weights}
|
||||
|
||||
DeepSeekV3Model.sanitize = sanitize_with_mtp
|
||||
|
||||
|
||||
def _restore_deepseek_sanitize() -> None:
|
||||
"""Restore the original DeepSeek V3 sanitize method."""
|
||||
global _original_deepseek_sanitize
|
||||
if _original_deepseek_sanitize is None:
|
||||
return
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
|
||||
_original_deepseek_sanitize = None
|
||||
|
||||
|
||||
def _flatten_params(
|
||||
params: dict[str, Any],
|
||||
prefix: str = "",
|
||||
) -> dict[str, mx.array]:
|
||||
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
|
||||
result: dict[str, mx.array] = {}
|
||||
for key, value in params.items(): # pyright: ignore[reportAny]
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, mx.array):
|
||||
result[full_key] = value
|
||||
elif isinstance(value, dict):
|
||||
result.update(_flatten_params(cast(dict[str, Any], value), full_key))
|
||||
return result
|
||||
|
||||
|
||||
def _extract_mtp_module(model: nn.Module) -> Any | None:
|
||||
"""Extract MTP module from a loaded DeepSeek V3 model."""
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTPModule,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
try:
|
||||
inner_model: Any = getattr(model, "model", None)
|
||||
if inner_model is None or not hasattr(inner_model, "layers"): # pyright: ignore[reportAny]
|
||||
logger.debug("Model doesn't have expected structure for MTP extraction")
|
||||
return None
|
||||
|
||||
layers: list[nn.Module] = inner_model.layers # pyright: ignore[reportAny]
|
||||
if len(layers) <= MTP_LAYER_INDEX:
|
||||
logger.debug(
|
||||
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
|
||||
)
|
||||
return None
|
||||
|
||||
config: Any = getattr(model, "args", None)
|
||||
if config is None:
|
||||
logger.debug("Could not get model config for MTP module")
|
||||
return None
|
||||
|
||||
embed_tokens: Any = getattr(inner_model, "embed_tokens", None) # pyright: ignore[reportAny]
|
||||
lm_head: Any = getattr(model, "lm_head", None)
|
||||
norm: Any = getattr(inner_model, "norm", None) # pyright: ignore[reportAny]
|
||||
|
||||
if embed_tokens is None or lm_head is None or norm is None:
|
||||
logger.debug("Could not get required model components for MTP")
|
||||
return None
|
||||
|
||||
mtp_module = MTPModule(
|
||||
config=config, # pyright: ignore[reportAny]
|
||||
shared_embedding=embed_tokens, # pyright: ignore[reportAny]
|
||||
shared_lm_head=lm_head, # pyright: ignore[reportAny]
|
||||
output_norm=norm, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
|
||||
model_weights = _flatten_params(raw_params)
|
||||
mtp_weights = extract_mtp_weights(model_weights)
|
||||
|
||||
if not mtp_weights:
|
||||
logger.debug("No MTP weights found in model parameters")
|
||||
return None
|
||||
|
||||
load_mtp_weights_into_module(mtp_module, mtp_weights)
|
||||
|
||||
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
|
||||
inner_model.layers = new_layers # noqa: B010
|
||||
|
||||
logger.info(
|
||||
f"Extracted MTP module, main model now has {len(new_layers)} layers"
|
||||
)
|
||||
return mtp_module
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract MTP module: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -81,6 +221,30 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -172,28 +336,52 @@ def load_mlx_items(
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
model_id = bound_instance.bound_shard.model_card.model_id
|
||||
mtp_module = None
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
# Patch sanitize for MTP if this might be DeepSeek V3
|
||||
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
|
||||
if should_try_mtp:
|
||||
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
try:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=not should_try_mtp)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
# Extract MTP module if available
|
||||
if should_try_mtp and _is_deepseek_v3_model(model):
|
||||
mtp_module = _extract_mtp_module(model)
|
||||
if mtp_module is not None:
|
||||
logger.info("Successfully extracted MTP module from DeepSeek V3")
|
||||
|
||||
finally:
|
||||
if should_try_mtp:
|
||||
_restore_deepseek_sanitize()
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
# Store MTP module on the model for later access
|
||||
if mtp_module is not None:
|
||||
model.mtp_module = mtp_module # noqa: B010
|
||||
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@@ -285,12 +473,10 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
|
||||
# For GLM-5 and GLM-4.7
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
@@ -355,13 +541,7 @@ def load_tokenizer_for_model_id(
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
return TokenizerWrapper(
|
||||
hf_tokenizer,
|
||||
eos_token_ids=eos_token_ids,
|
||||
tool_call_start="<|tool_calls_section_begin|>",
|
||||
tool_call_end="<|tool_calls_section_end|>",
|
||||
tool_parser=_parse_kimi_tool_calls,
|
||||
)
|
||||
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
@@ -573,61 +753,3 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _parse_kimi_tool_calls(text: str):
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
_tool_call_split_regex = re.compile(
|
||||
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
|
||||
)
|
||||
|
||||
def _parse_single_tool(text: str) -> dict[str, Any]:
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError("No tool call found.")
|
||||
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
|
||||
func_name = func_name_match.group(2) # e.g. "get_weather"
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError("No tool call arguments found.")
|
||||
func_args = func_args_match.group(1)
|
||||
try:
|
||||
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
arg_dct = None
|
||||
|
||||
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
|
||||
|
||||
tool_matches = _tool_call_split_regex.findall(text)
|
||||
if tool_matches:
|
||||
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
|
||||
else:
|
||||
return [_parse_single_tool(text)]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
@@ -16,7 +17,7 @@ from exo.shared.types.commands import (
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
@@ -32,7 +33,6 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -63,12 +63,14 @@ class Worker:
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.event_index_counter = event_index_counter
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
@@ -83,8 +85,6 @@ class Worker:
|
||||
self._nack_base_seconds: float = 0.5
|
||||
self._nack_cap_seconds: float = 10.0
|
||||
|
||||
self._system_id = SystemId()
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
@@ -131,7 +131,7 @@ class Worker:
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
event_id = f_event.event.event_id
|
||||
@@ -211,7 +211,7 @@ class Worker:
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self._system_id,
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
@@ -224,22 +224,15 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -277,18 +270,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self._start_runner_task(modified_task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self._start_runner_task(task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -311,7 +304,7 @@ class Worker:
|
||||
)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self._system_id,
|
||||
origin=self.node_id,
|
||||
command=RequestEventLog(since_idx=since_idx),
|
||||
)
|
||||
)
|
||||
@@ -327,6 +320,8 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
@@ -338,16 +333,15 @@ class Worker:
|
||||
return runner
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self._system_id,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -54,14 +53,13 @@ def plan(
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -272,7 +270,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -286,7 +284,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -294,33 +292,16 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner_id, runner in runners.items():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id,
|
||||
cancelled_task_id=task.task_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,7 +15,6 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -39,7 +38,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import base64
|
||||
import math
|
||||
import json
|
||||
import resource
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Literal
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -15,6 +15,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
@@ -87,12 +88,9 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
from .tool_parsers import ToolParser, make_mlx_parser
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
"""Check if this node is the primary output node for image generation.
|
||||
@@ -114,7 +112,6 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
@@ -132,16 +129,11 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
tool_parser: ToolParser | None = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -154,7 +146,6 @@ def main(
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -200,28 +191,19 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
tool_parser = make_mlx_parser(
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -229,6 +211,8 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -240,31 +224,16 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert inference_model
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
mx.max(
|
||||
mx.distributed.all_gather(
|
||||
mx.array([check_for_cancel_every]), group=group
|
||||
)
|
||||
).item()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||
)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
@@ -272,8 +241,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -293,9 +262,9 @@ def main(
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
assert inference_model
|
||||
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert check_for_cancel_every
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params)
|
||||
@@ -305,7 +274,7 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=inference_model,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
@@ -320,25 +289,34 @@ def main(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
elif isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
@@ -386,7 +364,6 @@ def main(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -411,7 +388,7 @@ def main(
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert image_model
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -424,9 +401,7 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
@@ -476,7 +451,7 @@ def main(
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert image_model
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -489,9 +464,7 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -550,20 +523,14 @@ def main(
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
was_cancelled = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
if not was_cancelled:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del inference_model, image_model, tokenizer, group
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -577,8 +544,21 @@ def get_gpt_oss_encoding():
|
||||
return encoding
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
assert isinstance(resp, GenerationResponse)
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
):
|
||||
continue
|
||||
yield resp
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
@@ -635,9 +615,9 @@ def parse_gpt_oss(
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse],
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
) -> Generator[GenerationResponse]:
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
@@ -758,55 +738,218 @@ def _process_image_response(
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse], tool_parser: ToolParser
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
if response.text.startswith(tool_parser.start_parsing):
|
||||
assert isinstance(response, GenerationResponse)
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
in_tool_call = True
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
if response.text.endswith(tool_parser.end_parsing):
|
||||
# parse the actual tool calls from the tool call text
|
||||
parsed = tool_parser.parse_tool_calls(
|
||||
"".join(tool_call_text_parts).strip()
|
||||
)
|
||||
continue
|
||||
# assumption: the tool call end is one token
|
||||
if in_tool_call and response.text == tool_call_end:
|
||||
try:
|
||||
# tool_parser returns an arbitrarily nested python dictionary
|
||||
# we actually don't want the python dictionary, we just want to
|
||||
# parse the top level { function: ..., arguments: ... } structure
|
||||
# as we're just gonna hand it back to the api anyway
|
||||
parsed = tool_parser("".join(tool_call_text_parts).strip())
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if parsed is not None:
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed, usage=response.usage, stats=response.stats
|
||||
)
|
||||
if isinstance(parsed, list):
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
logger.warning(
|
||||
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
|
||||
)
|
||||
response.text = "".join(tool_call_text_parts)
|
||||
yield response
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if response.finish_reason is not None:
|
||||
logger.info(
|
||||
"tool call parsing interrupted, yield partial tool call as text"
|
||||
)
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"text": "".join(tool_call_text_parts),
|
||||
"token": 0,
|
||||
}
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
|
||||
)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
if response.finish_reason is not None:
|
||||
logger.info(
|
||||
"toll call parsing interrupted, yield partial tool call as text"
|
||||
)
|
||||
yield GenerationResponse(
|
||||
text=tool_call_start + "".join(tool_call_text_parts),
|
||||
token=0,
|
||||
finish_reason=response.finish_reason,
|
||||
usage=None,
|
||||
)
|
||||
continue
|
||||
# fallthrough
|
||||
yield response
|
||||
|
||||
|
||||
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Version of to-be-upstreamed kimi-k2 tool parser
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
|
||||
tool_call_start = "<|tool_call_begin|>"
|
||||
tool_call_end = "<|tool_call_end|>"
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
original_func_name = func_name_match.group(1)
|
||||
tool_id = func_name_match.group(2)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = original_func_name[original_func_name.find(".") + 1 :]
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(
|
||||
id=f"{original_func_name}:{tool_id}",
|
||||
name=func_name,
|
||||
arguments=arg_dct, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
and ((args := obj.get("arguments")) is not None)
|
||||
and isinstance(name, str)
|
||||
):
|
||||
raw_id: object = obj.get("id")
|
||||
extra = {"id": str(raw_id)} if raw_id is not None else {}
|
||||
return ToolCallItem(
|
||||
**extra,
|
||||
name=name,
|
||||
arguments=json.dumps(args),
|
||||
)
|
||||
else:
|
||||
raise ValidationError
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -47,11 +47,9 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -62,8 +60,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -71,7 +69,6 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -86,7 +83,6 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -101,8 +97,6 @@ class RunnerSupervisor:
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
@@ -118,6 +112,14 @@ class RunnerSupervisor:
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
@@ -139,17 +141,6 @@ class RunnerSupervisor:
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
with anyio.move_on_after(0.5) as scope:
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if scope.cancel_called:
|
||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
@@ -191,7 +182,7 @@ class RunnerSupervisor:
|
||||
logger.info("Checking runner's status")
|
||||
if self.runner_process.is_alive():
|
||||
logger.info("Runner was found to be alive, attempting to join process")
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
await to_thread.run_sync(self.runner_process.join, 1)
|
||||
rc = self.runner_process.exitcode
|
||||
logger.info(f"RunnerSupervisor exited with exit code {rc}")
|
||||
if rc == 0:
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParser:
|
||||
start_parsing: str
|
||||
end_parsing: str
|
||||
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
|
||||
|
||||
|
||||
def make_mlx_parser(
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> ToolParser:
|
||||
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix(tool_call_start)
|
||||
text = text.removesuffix(tool_call_end)
|
||||
parsed = tool_parser(text)
|
||||
if isinstance(parsed, list):
|
||||
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
|
||||
else:
|
||||
return [ToolCallItem.model_validate(_flatten(parsed))]
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return ToolParser(
|
||||
start_parsing=tool_call_start,
|
||||
end_parsing=tool_call_end,
|
||||
parse_tool_calls=parse_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
# TODO / example code:
|
||||
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix("<tool_call>")
|
||||
text = text.removesuffix("</tool_call>")
|
||||
top_level = {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else v
|
||||
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
|
||||
}
|
||||
return [ToolCallItem.model_validate(top_level)]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _flatten(p: dict[str, Any]) -> dict[str, str]:
|
||||
return {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
|
||||
for k, v in p.items() # pyright: ignore[reportAny]
|
||||
}
|
||||
|
||||
|
||||
json_tool_parser = ToolParser(
|
||||
start_parsing="<tool_call>",
|
||||
end_parsing="</tool_call>",
|
||||
parse_tool_calls=_parse_json_calls,
|
||||
)
|
||||
|
||||
|
||||
def infer_tool_parser(chat_template: str) -> ToolParser | None:
|
||||
"""Attempt to auto-infer a tool parser from the chat template."""
|
||||
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
|
||||
return json_tool_parser
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user