mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 07:17:30 -05:00
Compare commits
11 Commits
fix-partia
...
iroh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30021ea887 | ||
|
|
6c322ebb72 | ||
|
|
2ebe6216b4 | ||
|
|
f54c80b121 | ||
|
|
48b8f86395 | ||
|
|
5cbd6377a2 | ||
|
|
8f01523ddb | ||
|
|
3addeadea8 | ||
|
|
f2be929211 | ||
|
|
83af8c63fa | ||
|
|
eccc6298d1 |
46
.mlx_typings/mlx_lm/models/glm_moe_dsa.pyi
Normal file
46
.mlx_typings/mlx_lm/models/glm_moe_dsa.pyi
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .deepseek_v32 import Model as DSV32Model
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
index_head_dim: int
|
||||
index_n_heads: int
|
||||
index_topk: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: int
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_parameters: Dict[str, Any]
|
||||
attention_bias: bool
|
||||
rope_scaling: Dict[str, Any] | None
|
||||
rope_theta: float | None
|
||||
|
||||
class Model(DSV32Model):
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
150
Cargo.lock
generated
150
Cargo.lock
generated
@@ -141,12 +141,6 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
@@ -304,19 +298,6 @@ version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bimap"
|
||||
version = "0.6.3"
|
||||
@@ -516,15 +497,6 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -701,17 +673,6 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "delegate"
|
||||
version = "0.13.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.10"
|
||||
@@ -746,29 +707,6 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc_version",
|
||||
"syn 2.0.111",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -938,25 +876,16 @@ dependencies = [
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
"extend",
|
||||
"futures",
|
||||
"impl-trait-for-tuples",
|
||||
"libp2p",
|
||||
"futures-lite",
|
||||
"log",
|
||||
"networking",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"thiserror 2.0.17",
|
||||
"thread_local",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1640,17 +1569,6 @@ dependencies = [
|
||||
"xmltree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "impl-trait-for-tuples"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
@@ -1829,12 +1747,6 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "libp2p"
|
||||
version = "0.56.0"
|
||||
@@ -2823,20 +2735,13 @@ dependencies = [
|
||||
name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"futures-lite",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2918,17 +2823,6 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -3279,28 +3173,14 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"either",
|
||||
"hashbrown 0.16.1",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"memoffset",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"rust_decimal",
|
||||
"smallvec",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
@@ -3741,16 +3621,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust_decimal"
|
||||
version = "1.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
@@ -4615,24 +4485,12 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "1.3.0"
|
||||
@@ -4713,10 +4571,6 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "util"
|
||||
version = "0.0.1"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.19.0"
|
||||
|
||||
30
Cargo.toml
30
Cargo.toml
@@ -3,7 +3,6 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -24,51 +23,22 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
@@ -103,7 +103,7 @@
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
return caps.includes("thinking_toggle") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
|
||||
@@ -185,7 +185,11 @@
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
|
||||
@@ -59,13 +59,14 @@
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -176,92 +177,90 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
{params.size.toUpperCase()}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -311,7 +310,7 @@
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition()
|
||||
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
} | null;
|
||||
nodes?: Record<string, NodeInfo>;
|
||||
sharding?: "Pipeline" | "Tensor";
|
||||
runtime?: "MlxRing" | "MlxJaccl";
|
||||
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
onLaunch?: () => void;
|
||||
tags?: string[];
|
||||
apiPreview?: PlacementPreview | null;
|
||||
@@ -348,7 +348,7 @@
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === "MlxJaccl");
|
||||
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
@@ -575,7 +575,7 @@
|
||||
>
|
||||
{runtime === "MlxRing"
|
||||
? "MLX Ring"
|
||||
: runtime === "MlxJaccl"
|
||||
: runtime === "MlxIbv" || runtime === "MlxJaccl"
|
||||
? "MLX RDMA"
|
||||
: runtime}
|
||||
</span>
|
||||
|
||||
@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
|
||||
export interface PlacementPreview {
|
||||
model_id: string;
|
||||
sharding: "Pipeline" | "Tensor";
|
||||
instance_meta: "MlxRing" | "MlxJaccl";
|
||||
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
instance: unknown | null;
|
||||
memory_delta_by_node: Record<string, number> | null;
|
||||
error: string | null;
|
||||
@@ -219,6 +219,7 @@ interface RawStateResponse {
|
||||
string,
|
||||
{
|
||||
MlxRingInstance?: Instance;
|
||||
MlxIbvInstance?: Instance;
|
||||
MlxJacclInstance?: Instance;
|
||||
}
|
||||
>;
|
||||
@@ -249,20 +250,6 @@ interface RawStateResponse {
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
// MetaInstances (declarative instance constraints)
|
||||
metaInstances?: Record<string, MetaInstanceData>;
|
||||
}
|
||||
|
||||
export interface MetaInstanceData {
|
||||
metaInstanceId: string;
|
||||
modelId: string;
|
||||
sharding: string;
|
||||
instanceMeta: string;
|
||||
minNodes: number;
|
||||
nodeIds: string[] | null;
|
||||
placementError: string | null;
|
||||
consecutiveFailures: number;
|
||||
lastFailureError: string | null;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -319,13 +306,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "auto"
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
| "1024x1536"
|
||||
| "1536x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -349,7 +337,7 @@ export interface EditingImage {
|
||||
}
|
||||
|
||||
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "1024x1024",
|
||||
size: "auto",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
@@ -550,7 +538,6 @@ class AppStore {
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
|
||||
metaInstances = $state<Record<string, MetaInstanceData>>({});
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderbolt = $state<
|
||||
Record<
|
||||
@@ -909,7 +896,11 @@ class AppStore {
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
@@ -1283,8 +1274,6 @@ class AppStore {
|
||||
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
|
||||
// RDMA ctl status per node
|
||||
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
|
||||
// MetaInstances
|
||||
this.metaInstances = data.metaInstances ?? {};
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
@@ -3056,7 +3045,6 @@ export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const metaInstances = () => appStore.metaInstances;
|
||||
export const runners = () => appStore.runners;
|
||||
export const downloads = () => appStore.downloads;
|
||||
export const nodeDisk = () => appStore.nodeDisk;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260217+50487b41"; in
|
||||
version = let v = "0.30.7.dev20260218+14841977"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
@@ -49,8 +49,8 @@ let
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
|
||||
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
|
||||
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
|
||||
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"mlx-lm==0.30.7",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "4bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "8bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "5bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 790517400864
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405478939008
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2.5"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "3bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
@@ -5,7 +5,6 @@ edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
path = "src/lib.rs"
|
||||
name = "exo_pyo3_bindings"
|
||||
|
||||
@@ -25,17 +24,17 @@ workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
@@ -44,34 +43,11 @@ pyo3-log = "0.13.2"
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
#tracing-subscriber = "0.3"
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
TODO: do something here....
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -10,31 +6,17 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
#[pin_project]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct AllowThreads<F>(#[pin] F);
|
||||
|
||||
impl<F> AllowThreads<F>
|
||||
where
|
||||
Self: Future,
|
||||
{
|
||||
pub fn new(f: F) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
pub struct AllowThreads<F>(pub(crate) F);
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
F: Future + Unpin + Send,
|
||||
F::Output: Send,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,217 +1,5 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = ResultExt)]
|
||||
impl<T, E> Result<T, E>
|
||||
where
|
||||
E: ToString,
|
||||
{
|
||||
fn pyerr(self) -> PyResult<T> {
|
||||
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
{
|
||||
AllowThreads::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Future> FutureExt for T {}
|
||||
|
||||
#[ext(pub, name = PyErrExt)]
|
||||
impl PyErr {
|
||||
fn receiver_channel_closed() -> Self {
|
||||
PyConnectionError::new_err("Receiver channel closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
match self {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
// write error back to python
|
||||
e.write_unraisable(py, None);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule(name = "exo_pyo3_bindings")]
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
mod allow_threading;
|
||||
|
||||
define_stub_info_gatherer!(stub_info);
|
||||
|
||||
@@ -1,572 +0,0 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
|
||||
pub struct PyNoPeersSubscribedToTopicError {}
|
||||
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
const MSG: &'static str = "\
|
||||
No peers are currently subscribed to receive messages on this topic. \
|
||||
Wait for peers to subscribe or check your network connectivity.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
|
||||
pub struct PyAllQueuesFullError {}
|
||||
|
||||
impl PyAllQueuesFullError {
|
||||
const MSG: &'static str =
|
||||
"All libp2p peers are unresponsive, resend the message or reconnect.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAllQueuesFullError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNetworkingHandle {
|
||||
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
|
||||
// immediately beforehand to release the interpreter.
|
||||
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
|
||||
// ---- Lifecycle management methods ----
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & convert any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
|
||||
/// Publishes a message with multiple topics to the `GossipSub` network.
|
||||
///
|
||||
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors => ignore messageID for now!!!
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
@@ -1,81 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use core::mem::drop;
|
||||
use core::option::Option::Some;
|
||||
use core::time::Duration;
|
||||
use tokio;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_channel() {
|
||||
struct Ping;
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Ping>(10);
|
||||
|
||||
let _ = tokio::spawn(async move {
|
||||
println!("TASK: entered");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.recv() => {
|
||||
match result {
|
||||
Some(_) => {
|
||||
println!("TASK: pinged");
|
||||
}
|
||||
None => {
|
||||
println!("TASK: closing channel");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
|
||||
println!("TASK: heartbeat");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("TASK: exited");
|
||||
});
|
||||
|
||||
let tx2 = tx.clone();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx.send(Ping).await.expect("Should not fail");
|
||||
drop(tx);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx2.send(Ping).await.expect("Should not fail");
|
||||
drop(tx2);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep_on_multiple_items() -> None:
|
||||
print("PYTHON: starting handle")
|
||||
h = NetworkingHandle(Keypair.generate_ed25519())
|
||||
|
||||
ct = asyncio.create_task(_await_cons(h))
|
||||
mt = asyncio.create_task(_await_msg(h))
|
||||
|
||||
# sleep for 4 ticks
|
||||
for i in range(4):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
await h.gossipsub_publish("topic", b"somehting or other")
|
||||
except NoPeersSubscribedToTopicError as e:
|
||||
print("caught it", e)
|
||||
|
||||
|
||||
async def _await_cons(h: NetworkingHandle):
|
||||
while True:
|
||||
c = await h.connection_update_recv()
|
||||
print(f"PYTHON: connection update: {c}")
|
||||
|
||||
|
||||
async def _await_msg(h: NetworkingHandle):
|
||||
while True:
|
||||
m = await h.gossipsub_recv()
|
||||
print(f"PYTHON: message: {m}")
|
||||
@@ -5,7 +5,6 @@ edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "networking"
|
||||
path = "src/lib.rs"
|
||||
|
||||
@@ -13,27 +12,14 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
@@ -41,4 +27,4 @@ keccak-const = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
use futures::stream::StreamExt as _;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
@@ -10,65 +6,4 @@ async fn main() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
// on gossipsub incoming
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
|
||||
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
|
||||
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
|
||||
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
|
||||
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
|
||||
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
|
||||
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
|
||||
|
||||
https://stackoverflow.com/questions/17169298/af-packet-on-osx
|
||||
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
|
||||
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
|
||||
|
||||
https://www.unix.com/man_page/mojave/4/ip/
|
||||
^OS X manpages for IP
|
||||
|
||||
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
|
||||
^driver kit, system extensions & kexts for macOS
|
||||
|
||||
----
|
||||
|
||||
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
|
||||
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
|
||||
----> here is a guide on how to set up thunderbolt-ethernet on linux
|
||||
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
|
||||
|
||||
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
|
||||
^GPT discussion about making socket interface
|
||||
|
||||
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
|
||||
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
|
||||
^GPT discussion about accessing TB on MacOS low level interactions
|
||||
|
||||
--------------------------------
|
||||
|
||||
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
|
||||
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
|
||||
|
||||
|
||||
---------------------------------
|
||||
|
||||
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
|
||||
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
|
||||
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
|
||||
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
@@ -1,383 +0,0 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{
|
||||
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
|
||||
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
|
||||
THandlerOutEvent, ToSwarm, dummy,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(
|
||||
ping::Config::new()
|
||||
.with_timeout(PING_TIMEOUT)
|
||||
.with_interval(PING_INTERVAL),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emmit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => {
|
||||
self.managed
|
||||
.on_connection_handler_event(peer_id, connection_id, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
}
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use extend::ext;
|
||||
use libp2p::Multiaddr;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
use std::net::IpAddr;
|
||||
|
||||
#[ext(pub, name = MultiaddrExt)]
|
||||
impl Multiaddr {
|
||||
/// If the multiaddress corresponds to a TCP address, extracts it
|
||||
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
|
||||
let mut ps = self.into_iter();
|
||||
let ip = if let Some(p) = ps.next() {
|
||||
match p {
|
||||
Protocol::Ip4(ip) => IpAddr::V4(ip),
|
||||
Protocol::Ip6(ip) => IpAddr::V6(ip),
|
||||
_ => return None,
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
let Some(Protocol::Tcp(port)) = ps.next() else {
|
||||
return None;
|
||||
};
|
||||
Some((ip, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
///
|
||||
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
|
||||
/// even be, and how to inject the right version into this config/initialization. E.g. should
|
||||
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
|
||||
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
.with_behaviour(Behaviour::new)?
|
||||
.build();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
use libp2p::pnet::{PnetError, PnetOutput};
|
||||
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
|
||||
use std::{env, sync::LazyLock};
|
||||
|
||||
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
|
||||
/// See [`pnet_upgrade`] for more.
|
||||
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
|
||||
let builder = Sha3_256::new().update(b"exo_discovery_network");
|
||||
|
||||
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
|
||||
let bytes = var.into_bytes();
|
||||
builder.update(&bytes)
|
||||
} else {
|
||||
builder.update(NETWORK_VERSION)
|
||||
}
|
||||
.finalize()
|
||||
});
|
||||
|
||||
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
|
||||
/// also different-versioned instances of this same network.
|
||||
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
|
||||
async fn pnet_upgrade<TSocket>(
|
||||
socket: TSocket,
|
||||
_: impl Sized,
|
||||
) -> Result<PnetOutput<TSocket>, PnetError>
|
||||
where
|
||||
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
use pnet::{PnetConfig, PreSharedKey};
|
||||
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
|
||||
.handshake(socket)
|
||||
.await
|
||||
}
|
||||
|
||||
/// TCP/IP transport layer configuration.
|
||||
pub fn tcp_transport(
|
||||
keypair: &identity::Keypair,
|
||||
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
|
||||
use libp2p::{
|
||||
core::upgrade::Version,
|
||||
tcp::{Config, tokio},
|
||||
};
|
||||
|
||||
// `TCP_NODELAY` enabled => avoid latency
|
||||
let tcp_config = Config::default().nodelay(true);
|
||||
|
||||
// V1 + lazy flushing => 0-RTT negotiation
|
||||
let upgrade_version = Version::V1Lazy;
|
||||
|
||||
// Noise is faster than TLS + we don't care much for security
|
||||
let noise_config = noise::Config::new(keypair)?;
|
||||
|
||||
// Use default Yamux config for multiplexing
|
||||
let yamux_config = yamux::Config::default();
|
||||
|
||||
// Create new Tokio-driven TCP/IP transport layer
|
||||
let base_transport = tokio::Transport::new(tcp_config)
|
||||
.and_then(pnet_upgrade)
|
||||
.upgrade(upgrade_version)
|
||||
.authenticate(noise_config)
|
||||
.multiplex(yamux_config);
|
||||
|
||||
// Return boxed transport (to flatten complex type)
|
||||
Ok(base_transport.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
// => signed message authenticity + strict validation mode means the message-ID is
|
||||
// automatically provided by gossipsub w/out needing to provide custom message-ID function
|
||||
gossipsub::Behaviour::new(
|
||||
MessageAuthenticity::Signed(keypair.clone()),
|
||||
ConfigBuilder::default()
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.validation_mode(ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("the configuration should always be valid"),
|
||||
)
|
||||
.expect("creating gossipsub behavior should always work")
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
// maybe this will hold test in the future...??
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn does_nothing() {}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
[package]
|
||||
name = "util"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "util"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
@@ -1 +0,0 @@
|
||||
pub mod wakerdeque;
|
||||
@@ -1,55 +0,0 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
|
||||
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
|
||||
pub struct WakerDeque<T> {
|
||||
waker: Option<Waker>,
|
||||
deque: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Debug for WakerDeque<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.deque.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WakerDeque<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
waker: None,
|
||||
deque: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, cx: &mut Context<'_>) {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
fn wake(&mut self) {
|
||||
let Some(ref mut w) = self.waker else { return };
|
||||
w.wake_by_ref();
|
||||
self.waker = None;
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_front()
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_back()
|
||||
}
|
||||
|
||||
pub fn push_front(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_front(value);
|
||||
}
|
||||
|
||||
pub fn push_back(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_back(value);
|
||||
}
|
||||
}
|
||||
@@ -47,6 +47,7 @@ class DownloadCoordinator:
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool = False
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
@@ -62,6 +63,8 @@ class DownloadCoordinator:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
if self.offline:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
@@ -107,13 +110,17 @@ class DownloadCoordinator:
|
||||
self._last_progress_time[model_id] = current_time()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
logger.info(
|
||||
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
|
||||
)
|
||||
if not self.offline:
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
if not self.offline:
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
def _test_internet_connection(self) -> None:
|
||||
try:
|
||||
@@ -202,6 +209,20 @@ class DownloadCoordinator:
|
||||
)
|
||||
return
|
||||
|
||||
if self.offline:
|
||||
logger.warning(
|
||||
f"Offline mode: model {model_id} is not fully available locally, cannot download"
|
||||
)
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=f"Model files not found locally in offline mode: {model_id}",
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
@@ -314,17 +335,7 @@ class DownloadCoordinator:
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if (
|
||||
progress.downloaded_bytes.in_bytes
|
||||
>= progress.total_bytes.in_bytes
|
||||
> 0
|
||||
):
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.downloaded_bytes.in_bytes == 0:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
|
||||
@@ -448,12 +448,13 @@ async def download_file_with_retry(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
model_id, revision, path, target_dir, on_progress, skip_internet
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
@@ -487,10 +488,14 @@ async def _download_file(
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
if skip_internet:
|
||||
return target_path
|
||||
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
@@ -510,6 +515,11 @@ async def _download_file(
|
||||
)
|
||||
return target_path
|
||||
|
||||
if skip_internet:
|
||||
raise FileNotFoundError(
|
||||
f"File {path} not found locally and cannot download in offline mode"
|
||||
)
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -814,6 +824,7 @@ async def download_shard(
|
||||
file, curr_bytes, total_bytes, is_renamed
|
||||
),
|
||||
on_connection_lost=on_connection_lost,
|
||||
skip_internet=skip_internet,
|
||||
)
|
||||
|
||||
if not skip_download:
|
||||
|
||||
230
src/exo/download/tests/test_offline_mode.py
Normal file
230
src/exo/download/tests/test_offline_mode.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Tests for offline/air-gapped mode."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
download_file_with_retry,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.worker.downloads import FileListEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestDownloadFileOffline:
|
||||
"""Tests for _download_file with skip_internet=True."""
|
||||
|
||||
async def test_returns_local_file_without_http_verification(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists locally, return it immediately
|
||||
without making any HTTP calls (no file_meta verification)."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"model weights data")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_raises_file_not_found_for_missing_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file does NOT exist locally,
|
||||
raise FileNotFoundError instead of attempting download."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="offline mode"):
|
||||
await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"missing_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
async def test_returns_local_file_in_subdirectory(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists in a subdirectory,
|
||||
return it without HTTP calls."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
subdir = target_dir / "transformer"
|
||||
await aios.makedirs(subdir, exist_ok=True)
|
||||
|
||||
local_file = subdir / "diffusion_pytorch_model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"weights")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"transformer/diffusion_pytorch_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
|
||||
class TestDownloadFileWithRetryOffline:
|
||||
"""Tests for download_file_with_retry with skip_internet=True."""
|
||||
|
||||
async def test_propagates_skip_internet_to_download_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Verify skip_internet is passed through to _download_file."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "config.json"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b'{"model_type": "qwen2"}')
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_file_not_found_does_not_retry(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""FileNotFoundError from offline mode should not trigger retries."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"nonexistent.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
|
||||
class TestFetchFileListOffline:
|
||||
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
|
||||
|
||||
async def test_uses_cached_file_list(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and cache file exists, use it without network."""
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
cache_dir = temp_models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
cached_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=200),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
assert result == cached_list
|
||||
mock_fetch.assert_not_called()
|
||||
|
||||
async def test_falls_back_to_local_directory_scan(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and no cache but local files exist,
|
||||
build file list from local directory."""
|
||||
import json
|
||||
|
||||
model_dir = temp_models_dir / model_id.normalize()
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(model_dir / "config.json", "w") as f:
|
||||
await f.write('{"model_type": "qwen2"}')
|
||||
|
||||
index_data = {
|
||||
"metadata": {},
|
||||
"weight_map": {"model.layers.0.weight": "model.safetensors"},
|
||||
}
|
||||
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
|
||||
await f.write(json.dumps(index_data))
|
||||
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
|
||||
await f.write(b"x" * 500)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
mock_fetch.assert_not_called()
|
||||
paths = {entry.path for entry in result}
|
||||
assert "config.json" in paths
|
||||
assert "model.safetensors" in paths
|
||||
|
||||
async def test_raises_when_no_cache_and_no_local_files(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and neither cache nor local files exist,
|
||||
raise FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="No internet"):
|
||||
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
|
||||
@@ -39,6 +39,7 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -68,6 +69,7 @@ class Node:
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
offline=args.offline,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -132,6 +134,7 @@ class Node:
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
args.offline,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -222,6 +225,7 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
offline=self.offline,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -254,12 +258,15 @@ def main():
|
||||
target = min(max(soft, 65535), hard)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
@@ -282,6 +289,7 @@ class Args(CamelCaseModel):
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -329,6 +337,11 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -71,11 +71,8 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
CreateMetaInstanceParams,
|
||||
CreateMetaInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteMetaInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -88,6 +85,7 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
ImageSize,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -103,6 +101,7 @@ from exo.shared.types.api import (
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
normalize_image_size,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -118,10 +117,8 @@ from exo.shared.types.claude_api import (
|
||||
from exo.shared.types.commands import (
|
||||
Command,
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -134,7 +131,7 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -143,7 +140,6 @@ from exo.shared.types.events import (
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
@@ -282,9 +278,6 @@ class API:
|
||||
self.app.get("/instance/previews")(self.get_placement_previews)
|
||||
self.app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/meta_instances")(self.list_meta_instances)
|
||||
self.app.post("/meta_instance")(self.create_meta_instance)
|
||||
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/models/add")(self.add_custom_model)
|
||||
@@ -314,27 +307,12 @@ class API:
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
model_card = await ModelCard.load(payload.model_id)
|
||||
command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
)
|
||||
|
||||
# Validate placement before sending — fail fast with a clear error
|
||||
# instead of silently dropping the command in the master.
|
||||
try:
|
||||
get_instance_placements(
|
||||
command,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
@@ -546,44 +524,6 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
|
||||
return dict(self.state.meta_instances)
|
||||
|
||||
async def create_meta_instance(
|
||||
self, payload: CreateMetaInstanceParams
|
||||
) -> CreateMetaInstanceResponse:
|
||||
meta_instance = MetaInstance(
|
||||
model_id=payload.model_id,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
node_ids=payload.node_ids,
|
||||
)
|
||||
command = CreateMetaInstance(meta_instance=meta_instance)
|
||||
await self._send(command)
|
||||
return CreateMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
)
|
||||
|
||||
async def delete_meta_instance(
|
||||
self, meta_instance_id: MetaInstanceId
|
||||
) -> DeleteMetaInstanceResponse:
|
||||
meta = self.state.meta_instances.get(meta_instance_id)
|
||||
if not meta:
|
||||
raise HTTPException(status_code=404, detail="MetaInstance not found")
|
||||
|
||||
# Command processor handles cascade-deleting backing instances
|
||||
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
|
||||
await self._send(command)
|
||||
|
||||
return DeleteMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
async def _token_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
@@ -603,10 +543,10 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -813,9 +753,11 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -946,10 +888,10 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1032,10 +974,10 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1071,12 +1013,13 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"stream": False,
|
||||
"partial_images": 0,
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1097,7 +1040,7 @@ class API:
|
||||
prompt: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: str,
|
||||
size: ImageSize,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
input_fidelity: Literal["low", "high"],
|
||||
stream: bool,
|
||||
@@ -1167,7 +1110,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
size: str | None = Form(None),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
@@ -1193,7 +1136,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
size=normalize_image_size(size),
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=stream_bool,
|
||||
@@ -1229,7 +1172,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
size: str | None = Form(None),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
@@ -1249,7 +1192,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
size=normalize_image_size(size),
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=False,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -13,22 +12,11 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.process_managers import ProcessManager
|
||||
from exo.master.process_managers.instance_health import InstanceHealthReconciler
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
@@ -48,12 +36,8 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
@@ -76,8 +60,7 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
|
||||
|
||||
@@ -101,16 +84,16 @@ class Master:
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
|
||||
self._process_managers: Sequence[ProcessManager] = [
|
||||
InstanceHealthReconciler(),
|
||||
NodeTimeoutReconciler(),
|
||||
MetaInstanceReconciler(),
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -119,12 +102,15 @@ class Master:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._reconcile)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -306,86 +292,6 @@ class Master:
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateMetaInstance():
|
||||
logger.info(
|
||||
f"Creating MetaInstance for {command.meta_instance.model_id}"
|
||||
f" (min_nodes={command.meta_instance.min_nodes},"
|
||||
f" sharding={command.meta_instance.sharding})"
|
||||
)
|
||||
# Apply immediately so self.state is fresh across
|
||||
# the await below and the reconciler won't race.
|
||||
await self._apply_and_broadcast(
|
||||
MetaInstanceCreated(meta_instance=command.meta_instance)
|
||||
)
|
||||
# Immediate placement attempt for responsiveness
|
||||
model_card = await ModelCard.load(
|
||||
command.meta_instance.model_id
|
||||
)
|
||||
# Re-check: reconciler may have satisfied it during the await
|
||||
meta_id = command.meta_instance.meta_instance_id
|
||||
still_unsatisfied = any(
|
||||
m.meta_instance_id == meta_id
|
||||
for m in find_unsatisfied_meta_instances(
|
||||
self.state.meta_instances,
|
||||
self.state.instances,
|
||||
self.state.topology,
|
||||
)
|
||||
)
|
||||
if still_unsatisfied:
|
||||
result = try_place_for_meta_instance(
|
||||
command.meta_instance,
|
||||
model_card,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
self.state.tasks,
|
||||
)
|
||||
generated_events.extend(result.events)
|
||||
if result.error is not None:
|
||||
generated_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
case DeleteMetaInstance():
|
||||
backing_count = sum(
|
||||
1
|
||||
for inst in self.state.instances.values()
|
||||
if inst.meta_instance_id == command.meta_instance_id
|
||||
)
|
||||
logger.info(
|
||||
f"Deleting MetaInstance {command.meta_instance_id}"
|
||||
f" (cascade-deleting {backing_count} backing instance(s))"
|
||||
)
|
||||
generated_events.append(
|
||||
MetaInstanceDeleted(
|
||||
meta_instance_id=command.meta_instance_id
|
||||
)
|
||||
)
|
||||
# Cascade-delete backing instances atomically,
|
||||
# cancelling any active tasks first.
|
||||
for iid, inst in self.state.instances.items():
|
||||
if inst.meta_instance_id == command.meta_instance_id:
|
||||
for task in self.state.tasks.values():
|
||||
if (
|
||||
task.instance_id == iid
|
||||
and task.task_status
|
||||
in (
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
)
|
||||
):
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
generated_events.append(
|
||||
InstanceDeleted(instance_id=iid)
|
||||
)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
command,
|
||||
@@ -417,19 +323,16 @@ class Master:
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
command.cancelled_command_id
|
||||
in self.command_task_mapping
|
||||
):
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
del self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -438,10 +341,9 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
@@ -452,32 +354,31 @@ class Master:
|
||||
):
|
||||
await self._send_event(IndexedEvent(idx=i, event=event))
|
||||
for event in generated_events:
|
||||
await self._apply_and_broadcast(event)
|
||||
await self.event_sender.send(event)
|
||||
except ValueError as e:
|
||||
logger.opt(exception=e).warning("Error in command processor")
|
||||
|
||||
async def _apply_and_broadcast(self, event: Event) -> None:
|
||||
"""Apply event to state, persist to disk, and broadcast to workers.
|
||||
|
||||
State is updated synchronously (before any await), so callers can
|
||||
rely on ``self.state`` reflecting this event immediately after the
|
||||
call. Python's cooperative scheduling guarantees no interleaving
|
||||
between the state read and write.
|
||||
"""
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _reconcile(self) -> None:
|
||||
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
|
||||
async def _plan(self) -> None:
|
||||
while True:
|
||||
for pm in self._process_managers:
|
||||
events = await pm.reconcile(self.state)
|
||||
for event in events:
|
||||
await self._apply_and_broadcast(event)
|
||||
await anyio.sleep(1)
|
||||
# kill broken instances
|
||||
connected_node_ids = set(self.state.topology.list_nodes())
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
await self.event_sender.send(
|
||||
InstanceDeleted(instance_id=instance_id)
|
||||
)
|
||||
break
|
||||
|
||||
# time out dead nodes
|
||||
for node_id, time in self.state.last_seen.items():
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if now - time > timedelta(seconds=30):
|
||||
logger.info(f"Manually removing node {node_id} due to inactivity")
|
||||
await self.event_sender.send(NodeTimedOut(node_id=node_id))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _event_processor(self) -> None:
|
||||
with self.local_event_receiver as local_events:
|
||||
@@ -495,15 +396,32 @@ class Master:
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
if isinstance(event, JacclSideChannelData):
|
||||
await self._apply_and_broadcast(event)
|
||||
await self._handle_jaccl_side_channel(event)
|
||||
continue
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
if isinstance(event, NodeGatheredInfo):
|
||||
event.when = str(datetime.now(tz=timezone.utc))
|
||||
|
||||
await self._apply_and_broadcast(event)
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
@@ -535,49 +453,10 @@ class Master:
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
all_trace_data.extend(trace_data)
|
||||
|
||||
await self._apply_and_broadcast(
|
||||
await self.event_sender.send(
|
||||
TracesMerged(task_id=task_id, traces=all_trace_data)
|
||||
)
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
|
||||
"""Accumulate SideChannel contributions; when all runners for an instance
|
||||
have submitted for the same sequence, emit JacclSideChannelGathered."""
|
||||
iid = event.instance_id
|
||||
seq = event.sequence
|
||||
|
||||
if iid not in self._jaccl_pending:
|
||||
self._jaccl_pending[iid] = {}
|
||||
if seq not in self._jaccl_pending[iid]:
|
||||
self._jaccl_pending[iid][seq] = {}
|
||||
self._jaccl_pending[iid][seq][event.runner_id] = event.data
|
||||
|
||||
instance = self.state.instances.get(iid)
|
||||
if instance is None:
|
||||
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
|
||||
return
|
||||
|
||||
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
|
||||
submitted = set(self._jaccl_pending[iid][seq].keys())
|
||||
|
||||
logger.info(
|
||||
f"JACCL side channel: instance={iid} seq={seq} "
|
||||
f"submitted={len(submitted)}/{len(expected_runners)}"
|
||||
)
|
||||
|
||||
if submitted >= expected_runners:
|
||||
gathered = dict(self._jaccl_pending[iid][seq])
|
||||
del self._jaccl_pending[iid][seq]
|
||||
if not self._jaccl_pending[iid]:
|
||||
del self._jaccl_pending[iid]
|
||||
|
||||
await self._apply_and_broadcast(
|
||||
JacclSideChannelGathered(
|
||||
instance_id=iid,
|
||||
sequence=seq,
|
||||
gathered_data=gathered,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import Sequence
|
||||
from exo.master.placement_utils import (
|
||||
Cycle,
|
||||
filter_cycles_by_memory,
|
||||
get_largest_cycles,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_jaccl_devices_matrix,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
@@ -106,27 +106,23 @@ def place_instance(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
largest_rdma_cycles = [
|
||||
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl:
|
||||
if not largest_rdma_cycles:
|
||||
raise ValueError(
|
||||
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
|
||||
)
|
||||
largest_cycles = largest_rdma_cycles
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
for cycle in largest_cycles
|
||||
for cycle in smallest_cycles
|
||||
if any(topology.node_is_leaf(node_id) for node_id in cycle)
|
||||
]
|
||||
|
||||
selected_cycle = max(
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
||||
key=lambda cycle: sum(
|
||||
(node_memory[node_id].ram_available for node_id in cycle),
|
||||
start=Memory(),
|
||||
|
||||
@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
|
||||
return filtered_cycles
|
||||
|
||||
|
||||
def get_largest_cycles(
|
||||
def get_smallest_cycles(
|
||||
cycles: list[Cycle],
|
||||
) -> list[Cycle]:
|
||||
max_nodes = max(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == max_nodes]
|
||||
min_nodes = min(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
|
||||
|
||||
def allocate_layers_proportionally(
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.state import State
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProcessManager(Protocol):
|
||||
"""A reconciliation step that examines state and returns corrective events."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]: ...
|
||||
@@ -1,62 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
|
||||
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
|
||||
from exo.shared.types.state import State
|
||||
|
||||
MAX_INSTANCE_RETRIES = 3
|
||||
|
||||
|
||||
@final
|
||||
class InstanceHealthReconciler:
|
||||
"""Delete instances whose network connections are broken or whose runners have all failed."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
for instance_id, instance in state.instances.items():
|
||||
if not instance_connections_healthy(instance, state.topology):
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error="Network connection lost",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
is_failed, error_message = instance_runners_failed(
|
||||
instance, state.runners, state.node_identities
|
||||
)
|
||||
if is_failed:
|
||||
# Retry within the same instance if backed by a MetaInstance
|
||||
mid = instance.meta_instance_id
|
||||
mi = state.meta_instances.get(mid) if mid else None
|
||||
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
|
||||
logger.info(
|
||||
f"Instance {instance_id} failed (attempt"
|
||||
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
|
||||
f" retrying: {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceRetrying(
|
||||
instance_id=instance_id,
|
||||
meta_instance_id=mid,
|
||||
failure_error=error_message or "Runner failed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
if mid and mi:
|
||||
logger.warning(
|
||||
f"Instance {instance_id} exceeded retry limit"
|
||||
f" ({MAX_INSTANCE_RETRIES}), deleting:"
|
||||
f" {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error=error_message,
|
||||
)
|
||||
)
|
||||
return events
|
||||
@@ -1,92 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
import anyio
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
|
||||
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstanceReconciler:
|
||||
"""Place instances for unsatisfied MetaInstances."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
all_events: list[Event] = []
|
||||
# Local copy for intermediate tracking — so placement of B
|
||||
# sees A's instance and doesn't double-place on same resources.
|
||||
current_instances: dict[InstanceId, Instance] = dict(state.instances)
|
||||
|
||||
unsatisfied = find_unsatisfied_meta_instances(
|
||||
state.meta_instances,
|
||||
current_instances,
|
||||
state.topology,
|
||||
)
|
||||
for meta_instance in unsatisfied:
|
||||
try:
|
||||
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
|
||||
model_card = await ModelCard.load(meta_instance.model_id)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
|
||||
)
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
|
||||
)
|
||||
error = f"Failed to load model card: {exc}"
|
||||
if meta_instance.placement_error != error:
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=error,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
result = try_place_for_meta_instance(
|
||||
meta_instance,
|
||||
model_card,
|
||||
state.topology,
|
||||
current_instances,
|
||||
state.node_memory,
|
||||
state.node_network,
|
||||
state.tasks,
|
||||
)
|
||||
# Update local instance map so next placement sees this one
|
||||
for event in result.events:
|
||||
if isinstance(event, InstanceCreated):
|
||||
logger.info(
|
||||
f"MetaInstance reconciler placed instance"
|
||||
f" {event.instance.instance_id} for"
|
||||
f" {meta_instance.model_id}"
|
||||
)
|
||||
current_instances[event.instance.instance_id] = event.instance
|
||||
all_events.extend(result.events)
|
||||
|
||||
# Emit placement failure if error differs from what's already in state
|
||||
if (
|
||||
result.error is not None
|
||||
and meta_instance.placement_error != result.error
|
||||
):
|
||||
logger.warning(
|
||||
f"MetaInstance placement failed for"
|
||||
f" {meta_instance.model_id}: {result.error}"
|
||||
)
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
return all_events
|
||||
@@ -1,27 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, NodeTimedOut
|
||||
from exo.shared.types.state import State
|
||||
|
||||
_DEFAULT_TIMEOUT = timedelta(seconds=30)
|
||||
|
||||
|
||||
@final
|
||||
class NodeTimeoutReconciler:
|
||||
"""Time out nodes that haven't been seen recently."""
|
||||
|
||||
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
|
||||
self.timeout = timeout
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
events: list[Event] = []
|
||||
for node_id, last_seen in state.last_seen.items():
|
||||
if now - last_seen > self.timeout:
|
||||
logger.info(f"Removing node {node_id} due to inactivity")
|
||||
events.append(NodeTimedOut(node_id=node_id))
|
||||
return events
|
||||
@@ -1,244 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import NamedTuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import get_transition_events, place_instance
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
BaseInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
)
|
||||
|
||||
|
||||
class PlacementResult(NamedTuple):
|
||||
"""Result of a placement attempt: events to apply and optional error reason."""
|
||||
|
||||
events: Sequence[Event]
|
||||
error: str | None
|
||||
|
||||
|
||||
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
|
||||
"""Reconstruct ring order from shard device_rank."""
|
||||
node_ranks: list[tuple[NodeId, int]] = []
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
shard = instance.shard_assignments.runner_to_shard[runner_id]
|
||||
node_ranks.append((node_id, shard.device_rank))
|
||||
node_ranks.sort(key=lambda x: x[1])
|
||||
return [node_id for node_id, _ in node_ranks]
|
||||
|
||||
|
||||
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific IPs used by a ring instance still exist in the topology."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for node in ring:
|
||||
hosts = instance.hosts_by_node[node]
|
||||
for idx in range(n):
|
||||
host = hosts[idx]
|
||||
if host.ip in ("0.0.0.0", "198.51.100.1"):
|
||||
continue # self or placeholder
|
||||
# Real connection: node → ring[idx]. Check specific IP.
|
||||
connections = topology.get_all_connections_between(node, ring[idx])
|
||||
if not any(
|
||||
isinstance(c, SocketConnection)
|
||||
and c.sink_multiaddr.ip_address == host.ip
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
iface = instance.jaccl_devices[i][j]
|
||||
if iface is None:
|
||||
continue
|
||||
connections = topology.get_all_connections_between(ring[i], ring[j])
|
||||
if not any(
|
||||
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
|
||||
"""Check that an instance's nodes and specific connections are still in the topology."""
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
if not all(topology.contains_node(n) for n in instance_nodes):
|
||||
return False
|
||||
if len(instance_nodes) <= 1:
|
||||
return True
|
||||
match instance:
|
||||
case MlxRingInstance():
|
||||
return _ring_connections_healthy(instance, topology)
|
||||
case MlxJacclInstance():
|
||||
return _jaccl_connections_healthy(instance, topology)
|
||||
|
||||
|
||||
def instance_runners_failed(
|
||||
instance: Instance,
|
||||
runners: Mapping[RunnerId, RunnerStatus],
|
||||
node_identities: Mapping[NodeId, NodeIdentity],
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if an instance's runners have all reached terminal failure states.
|
||||
|
||||
Returns ``(True, error_message)`` when ALL runners are terminal
|
||||
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
|
||||
|
||||
Returns ``(False, None)`` when runners are still active, haven't reported
|
||||
yet, or all gracefully shut down (no ``RunnerFailed``).
|
||||
"""
|
||||
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
|
||||
|
||||
if not instance_runner_ids:
|
||||
return False, None
|
||||
|
||||
# Build reverse mapping: runner_id -> node_id
|
||||
runner_to_node: dict[RunnerId, NodeId] = {
|
||||
runner_id: node_id
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
}
|
||||
|
||||
has_any_failed = False
|
||||
error_messages: list[str] = []
|
||||
|
||||
for runner_id in instance_runner_ids:
|
||||
status = runners.get(runner_id)
|
||||
if status is None:
|
||||
# Runner hasn't reported yet — instance is still starting
|
||||
return False, None
|
||||
if isinstance(status, RunnerFailed):
|
||||
has_any_failed = True
|
||||
if status.error_message:
|
||||
node_id = runner_to_node.get(runner_id)
|
||||
name = (
|
||||
node_identities[node_id].friendly_name
|
||||
if node_id and node_id in node_identities
|
||||
else node_id or "unknown"
|
||||
)
|
||||
error_messages.append(f"{name}: {status.error_message}")
|
||||
elif isinstance(status, RunnerShutdown):
|
||||
pass # Terminal but not a failure indicator on its own
|
||||
else:
|
||||
# Runner is still active (connecting, loading, running, etc.)
|
||||
return False, None
|
||||
|
||||
if has_any_failed:
|
||||
return True, "; ".join(error_messages) if error_messages else "Runner failed"
|
||||
|
||||
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
|
||||
return False, None
|
||||
|
||||
|
||||
def instance_satisfies_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
instance: Instance,
|
||||
) -> bool:
|
||||
"""Check if a single instance satisfies a meta-instance's constraints.
|
||||
|
||||
This is a pure constraint check (model, min_nodes, node_ids).
|
||||
Use ``instance_connections_healthy`` separately for topology health.
|
||||
"""
|
||||
if instance.shard_assignments.model_id != meta_instance.model_id:
|
||||
return False
|
||||
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
|
||||
if len(instance_nodes) < meta_instance.min_nodes:
|
||||
return False
|
||||
|
||||
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
|
||||
instance_nodes
|
||||
)
|
||||
|
||||
|
||||
def find_unsatisfied_meta_instances(
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
topology: Topology,
|
||||
) -> Sequence[MetaInstance]:
|
||||
"""Return meta-instances that have no healthy backing instance."""
|
||||
unsatisfied: list[MetaInstance] = []
|
||||
for meta_id, meta_instance in meta_instances.items():
|
||||
has_healthy_backing = any(
|
||||
instance.meta_instance_id == meta_id
|
||||
and instance_connections_healthy(instance, topology)
|
||||
for instance in instances.values()
|
||||
)
|
||||
if not has_healthy_backing:
|
||||
unsatisfied.append(meta_instance)
|
||||
return unsatisfied
|
||||
|
||||
|
||||
def try_place_for_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
model_card: ModelCard,
|
||||
topology: Topology,
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> PlacementResult:
|
||||
"""Try to place an instance satisfying the meta-instance constraints.
|
||||
|
||||
Returns a :class:`PlacementResult` with events on success, or an error
|
||||
reason on failure.
|
||||
"""
|
||||
command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=meta_instance.sharding,
|
||||
instance_meta=meta_instance.instance_meta,
|
||||
min_nodes=meta_instance.min_nodes,
|
||||
)
|
||||
try:
|
||||
target_instances = place_instance(
|
||||
command,
|
||||
topology,
|
||||
current_instances,
|
||||
node_memory,
|
||||
node_network,
|
||||
required_nodes=(
|
||||
set(meta_instance.node_ids) if meta_instance.node_ids else None
|
||||
),
|
||||
)
|
||||
# Tag the new instance with meta_instance_id
|
||||
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
|
||||
if new_instance_ids:
|
||||
new_id = next(iter(new_instance_ids))
|
||||
target_instances[new_id] = target_instances[new_id].model_copy(
|
||||
update={"meta_instance_id": meta_instance.meta_instance_id}
|
||||
)
|
||||
return PlacementResult(
|
||||
events=list(
|
||||
get_transition_events(current_instances, target_instances, tasks)
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
|
||||
)
|
||||
return PlacementResult(events=[], error=str(e))
|
||||
@@ -1,778 +0,0 @@
|
||||
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.process_managers.instance_health import (
|
||||
MAX_INSTANCE_RETRIES,
|
||||
InstanceHealthReconciler,
|
||||
)
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
instance_connections_healthy,
|
||||
instance_runners_failed,
|
||||
instance_satisfies_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NodeIdentity
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceId,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerReady,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
# --- Helpers (copied from test_reconcile.py for independence) ---
|
||||
|
||||
|
||||
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _topology(*node_ids: str, connect: bool = True) -> Topology:
|
||||
t = Topology()
|
||||
nodes = [NodeId(n) for n in node_ids]
|
||||
for n in nodes:
|
||||
t.add_node(n)
|
||||
if connect and len(nodes) > 1:
|
||||
for i in range(len(nodes)):
|
||||
j = (i + 1) % len(nodes)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[i],
|
||||
sink=nodes[j],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[j],
|
||||
sink=nodes[i],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
def _meta_instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
*,
|
||||
min_nodes: int = 1,
|
||||
node_ids: list[NodeId] | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
consecutive_failures: int = 0,
|
||||
last_failure_error: str | None = None,
|
||||
placement_error: str | None = None,
|
||||
) -> MetaInstance:
|
||||
return MetaInstance(
|
||||
meta_instance_id=meta_instance_id or MetaInstanceId(),
|
||||
model_id=ModelId(model_id),
|
||||
min_nodes=min_nodes,
|
||||
node_ids=node_ids,
|
||||
consecutive_failures=consecutive_failures,
|
||||
last_failure_error=last_failure_error,
|
||||
placement_error=placement_error,
|
||||
)
|
||||
|
||||
|
||||
def _instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
node_ids: list[str] | None = None,
|
||||
instance_id: InstanceId | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> tuple[InstanceId, MlxRingInstance]:
|
||||
iid = instance_id or InstanceId()
|
||||
nodes = node_ids or ["node-a"]
|
||||
n = len(nodes)
|
||||
mc = _model_card(model_id)
|
||||
ephemeral_port = 50000
|
||||
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
|
||||
runner_to_shard = {
|
||||
runner_id: PipelineShardMetadata(
|
||||
model_card=mc,
|
||||
device_rank=i,
|
||||
world_size=n,
|
||||
start_layer=0,
|
||||
end_layer=mc.n_layers,
|
||||
n_layers=mc.n_layers,
|
||||
)
|
||||
for i, runner_id in enumerate(node_to_runner.values())
|
||||
}
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
for r, node_str in enumerate(nodes):
|
||||
hosts: list[Host] = []
|
||||
for idx in range(n):
|
||||
if idx == r:
|
||||
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
|
||||
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
|
||||
else:
|
||||
hosts.append(Host(ip="198.51.100.1", port=0))
|
||||
hosts_by_node[NodeId(node_str)] = hosts
|
||||
return iid, MlxRingInstance(
|
||||
instance_id=iid,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(model_id),
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
),
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. MetaInstance lifecycle edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_meta_instance_model_is_frozen():
|
||||
"""MetaInstance should be immutable (frozen model)."""
|
||||
meta = _meta_instance()
|
||||
try:
|
||||
meta.model_id = ModelId("something-else")
|
||||
raise AssertionError("Should have raised")
|
||||
except Exception:
|
||||
pass # Expected — frozen model
|
||||
|
||||
|
||||
def test_meta_instance_created_then_deleted_roundtrip():
|
||||
"""Create and delete a MetaInstance through apply — state should be clean."""
|
||||
state = State()
|
||||
meta = _meta_instance()
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
|
||||
)
|
||||
assert meta.meta_instance_id in state.meta_instances
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
),
|
||||
)
|
||||
assert meta.meta_instance_id not in state.meta_instances
|
||||
assert len(state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_delete_nonexistent_meta_instance_is_safe():
|
||||
"""Deleting a MetaInstance that doesn't exist should not crash."""
|
||||
state = State()
|
||||
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
|
||||
"""MetaInstancePlacementFailed for unknown ID should not crash."""
|
||||
state = State()
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=MetaInstanceId("nonexistent"),
|
||||
reason="test",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_multiple_meta_instances_for_same_model():
|
||||
"""Multiple MetaInstances for the same model are tracked independently."""
|
||||
state = State()
|
||||
meta_a = _meta_instance("test-org/model-x")
|
||||
meta_b = _meta_instance("test-org/model-x")
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
|
||||
)
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
|
||||
)
|
||||
assert len(state.meta_instances) == 2
|
||||
assert meta_a.meta_instance_id in state.meta_instances
|
||||
assert meta_b.meta_instance_id in state.meta_instances
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. Retry logic edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_retry_counter_resets_on_successful_instance_creation():
|
||||
"""When a new instance is created for a meta-instance, failures should reset."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
# last_failure_error is preserved (for UI display)
|
||||
assert mi.last_failure_error == "old"
|
||||
|
||||
|
||||
async def test_retry_count_increments_through_full_cycle():
|
||||
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
topology = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=topology,
|
||||
)
|
||||
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
|
||||
# Simulate runners failing
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
|
||||
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
|
||||
|
||||
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
|
||||
|
||||
# Next failure should result in deletion
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_respects_exact_limit():
|
||||
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_at_limit_minus_one_retries():
|
||||
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. Error handling edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_runners_failed_with_empty_error_message():
|
||||
"""RunnerFailed with empty error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message="")
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
# Empty error message means we get the fallback
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_with_none_error_message():
|
||||
"""RunnerFailed with None error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message=None)
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_collects_all_error_messages():
|
||||
"""With multiple failed runners, all error messages should be collected."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
|
||||
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
|
||||
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "OOM on GPU 0" in error
|
||||
assert "OOM on GPU 1" in error
|
||||
assert "OOM on GPU 2" in error
|
||||
|
||||
|
||||
def test_runners_failed_includes_friendly_name():
|
||||
"""Error messages should include node friendly names when available."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
node_id = NodeId("node-a")
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
|
||||
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
|
||||
is_failed, error = instance_runners_failed(inst, runners, identities)
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "My Mac Studio" in error
|
||||
|
||||
|
||||
def test_instance_retrying_for_missing_instance_is_safe():
|
||||
"""InstanceRetrying for an instance not in state should not crash.
|
||||
|
||||
NOTE: When the instance is missing, the handler returns early WITHOUT
|
||||
incrementing the MetaInstance failure counter. This means stale retry
|
||||
events for already-deleted instances are silently dropped. This is
|
||||
acceptable since the InstanceDeleted handler already increments failures.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceRetrying(
|
||||
instance_id=InstanceId("nonexistent"),
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Does not crash, but failure count is NOT incremented (early return)
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. Backward compatibility
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_instance_without_meta_instance_id_works():
|
||||
"""Instances created without meta_instance_id should still function normally."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
assert inst.meta_instance_id is None
|
||||
topology = _topology("node-a")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
|
||||
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0 # unchanged
|
||||
|
||||
|
||||
def test_satisfies_ignores_meta_instance_id_binding():
|
||||
"""instance_satisfies_meta_instance checks constraints only, not binding."""
|
||||
meta = _meta_instance()
|
||||
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
|
||||
# Should match on constraints (model, min_nodes) regardless of binding
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
def test_find_unsatisfied_uses_binding_not_constraints():
|
||||
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
|
||||
meta = _meta_instance()
|
||||
# Instance matches constraints but is NOT bound to this meta_instance
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta}, {iid: inst}, topology
|
||||
)
|
||||
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 5. Concurrent / multi-instance scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_health_reconciler_handles_multiple_failing_instances():
|
||||
"""Multiple instances failing simultaneously should each get their own event."""
|
||||
meta_a = _meta_instance()
|
||||
meta_b = _meta_instance()
|
||||
iid_a, inst_a = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
|
||||
)
|
||||
iid_b, inst_b = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
|
||||
)
|
||||
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
|
||||
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
instances={iid_a: inst_a, iid_b: inst_b},
|
||||
runners={
|
||||
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 2
|
||||
# Both should be InstanceRetrying since failures < MAX
|
||||
assert all(isinstance(e, InstanceRetrying) for e in events)
|
||||
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
|
||||
assert instance_ids == {iid_a, iid_b}
|
||||
|
||||
|
||||
async def test_health_reconciler_mixed_healthy_and_failing():
|
||||
"""Only failing instances should produce events; healthy ones should not."""
|
||||
meta_healthy = _meta_instance()
|
||||
meta_failing = _meta_instance()
|
||||
iid_h, inst_h = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
|
||||
)
|
||||
iid_f, inst_f = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
|
||||
)
|
||||
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
|
||||
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_healthy.meta_instance_id: meta_healthy,
|
||||
meta_failing.meta_instance_id: meta_failing,
|
||||
},
|
||||
instances={iid_h: inst_h, iid_f: inst_f},
|
||||
runners={
|
||||
runner_ids_h[0]: RunnerReady(),
|
||||
runner_ids_f[0]: RunnerFailed(error_message="crash"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid_f
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_empty_state():
|
||||
"""MetaInstanceReconciler with no meta_instances should produce no events."""
|
||||
state = State()
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 6. Placement error tracking
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_placement_failed_sets_error():
|
||||
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="Not enough memory",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error == "Not enough memory"
|
||||
|
||||
|
||||
def test_instance_created_clears_placement_error():
|
||||
"""InstanceCreated should clear placement_error on the MetaInstance."""
|
||||
meta = _meta_instance(placement_error="Not enough memory")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
def test_placement_error_does_not_increment_failures():
|
||||
"""Placement failures should only set placement_error, not increment consecutive_failures."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="No resources",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.placement_error == "No resources"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. State serialization roundtrip
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_state_with_meta_instances_serializes():
|
||||
"""State with meta_instances should serialize and deserialize correctly."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
json_str = state.model_dump_json()
|
||||
restored = State.model_validate_json(json_str)
|
||||
assert meta.meta_instance_id in restored.meta_instances
|
||||
mi = restored.meta_instances[meta.meta_instance_id]
|
||||
assert mi.model_id == meta.model_id
|
||||
assert mi.consecutive_failures == 2
|
||||
assert mi.last_failure_error == "test"
|
||||
assert iid in restored.instances
|
||||
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. MetaInstanceReconciler error handling
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance()
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 1
|
||||
assert "Failed to load model card" in placement_failed[0].reason
|
||||
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance(placement_error="Failed to load model card: Network error")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Error matches existing placement_error, so no duplicate event emitted
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_continues_after_error(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""Reconciler should continue to next meta-instance after one fails to load."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta_a = _meta_instance(model_id="org/model-a")
|
||||
meta_b = _meta_instance(model_id="org/model-b")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _load_second_fails(model_id: ModelId) -> ModelCard:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError(f"Cannot load {model_id}")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Both meta-instances should have been attempted (not short-circuited)
|
||||
assert call_count == 2
|
||||
# Both should have placement failed events
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. Cascade delete with task cancellation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_cascade_delete_cancels_active_tasks():
|
||||
"""Deleting a MetaInstance should cancel tasks on backing instances.
|
||||
|
||||
Regression test: previously, cascade-deleting backing instances via
|
||||
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
|
||||
tasks, leaving orphaned task references in state.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
task_id = TaskId()
|
||||
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
|
||||
|
||||
# Build state with meta-instance, backing instance, and active task
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
tasks={task_id: task},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
|
||||
# Simulate the cascade-delete event sequence produced by main.py:
|
||||
# 1. MetaInstanceDeleted
|
||||
# 2. TaskStatusUpdated(Cancelled) for active tasks
|
||||
# 3. InstanceDeleted
|
||||
idx = 0
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
|
||||
)
|
||||
|
||||
# Verify everything is cleaned up
|
||||
assert len(state.meta_instances) == 0
|
||||
assert len(state.instances) == 0
|
||||
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
|
||||
|
||||
|
||||
def test_cascade_delete_skips_completed_tasks():
|
||||
"""Cascade delete should only cancel Pending/Running tasks, not completed ones."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
|
||||
running_task_id = TaskId()
|
||||
completed_task_id = TaskId()
|
||||
running_task = LoadModel(
|
||||
task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
|
||||
)
|
||||
completed_task = LoadModel(
|
||||
task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
|
||||
)
|
||||
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
tasks={running_task_id: running_task, completed_task_id: completed_task},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
|
||||
# Only the running task should be cancelled — we verify the logic pattern
|
||||
# by checking which tasks are Pending or Running
|
||||
active_tasks = [
|
||||
t
|
||||
for t in state.tasks.values()
|
||||
if t.instance_id == iid
|
||||
and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
|
||||
]
|
||||
assert len(active_tasks) == 1
|
||||
assert active_tasks[0].task_id == running_task_id
|
||||
@@ -3,10 +3,10 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_largest_cycles,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
create_node_memory,
|
||||
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
|
||||
}
|
||||
|
||||
|
||||
def test_get_largest_cycles():
|
||||
def test_get_smallest_cycles():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
@@ -175,12 +175,12 @@ def test_get_largest_cycles():
|
||||
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
|
||||
|
||||
# act
|
||||
largest_cycles = get_largest_cycles(cycles)
|
||||
smallest_cycles = get_smallest_cycles(cycles)
|
||||
|
||||
# assert
|
||||
assert len(largest_cycles) == 1
|
||||
assert len(largest_cycles[0]) == 3
|
||||
assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
|
||||
assert len(smallest_cycles) == 1
|
||||
assert len(smallest_cycles[0]) == 2
|
||||
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,742 +0,0 @@
|
||||
from exo.master.process_managers.instance_health import InstanceHealthReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
instance_connections_healthy,
|
||||
instance_runners_failed,
|
||||
instance_satisfies_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceId,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerShutdown,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
|
||||
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _topology(*node_ids: str, connect: bool = True) -> Topology:
|
||||
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
|
||||
|
||||
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
|
||||
between consecutive nodes (including wrap-around).
|
||||
"""
|
||||
t = Topology()
|
||||
nodes = [NodeId(n) for n in node_ids]
|
||||
for n in nodes:
|
||||
t.add_node(n)
|
||||
if connect and len(nodes) > 1:
|
||||
for i in range(len(nodes)):
|
||||
j = (i + 1) % len(nodes)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[i],
|
||||
sink=nodes[j],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[j],
|
||||
sink=nodes[i],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
def _meta_instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
*,
|
||||
min_nodes: int = 1,
|
||||
node_ids: list[NodeId] | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> MetaInstance:
|
||||
return MetaInstance(
|
||||
meta_instance_id=meta_instance_id or MetaInstanceId(),
|
||||
model_id=ModelId(model_id),
|
||||
min_nodes=min_nodes,
|
||||
node_ids=node_ids,
|
||||
)
|
||||
|
||||
|
||||
def _instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
node_ids: list[str] | None = None,
|
||||
instance_id: InstanceId | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> tuple[InstanceId, MlxRingInstance]:
|
||||
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
|
||||
iid = instance_id or InstanceId()
|
||||
nodes = node_ids or ["node-a"]
|
||||
n = len(nodes)
|
||||
mc = _model_card(model_id)
|
||||
ephemeral_port = 50000
|
||||
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
|
||||
runner_to_shard = {
|
||||
runner_id: PipelineShardMetadata(
|
||||
model_card=mc,
|
||||
device_rank=i,
|
||||
world_size=n,
|
||||
start_layer=0,
|
||||
end_layer=mc.n_layers,
|
||||
n_layers=mc.n_layers,
|
||||
)
|
||||
for i, runner_id in enumerate(node_to_runner.values())
|
||||
}
|
||||
# Build hosts_by_node with IPs matching _topology() convention:
|
||||
# node at index idx has IP 10.0.0.{idx+1}
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
for r, node_str in enumerate(nodes):
|
||||
hosts: list[Host] = []
|
||||
for idx in range(n):
|
||||
if idx == r:
|
||||
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
|
||||
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
|
||||
else:
|
||||
hosts.append(Host(ip="198.51.100.1", port=0))
|
||||
hosts_by_node[NodeId(node_str)] = hosts
|
||||
return iid, MlxRingInstance(
|
||||
instance_id=iid,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(model_id),
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
),
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
|
||||
# --- instance_satisfies_meta_instance (pure constraint matching) ---
|
||||
|
||||
|
||||
def test_satisfies_matching_model():
|
||||
meta = _meta_instance()
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
def test_not_satisfies_wrong_model():
|
||||
meta = _meta_instance("test-org/model-a")
|
||||
_, inst = _instance("test-org/model-b")
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_not_satisfies_missing_required_node():
|
||||
meta = _meta_instance(node_ids=[NodeId("node-c")])
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_not_satisfies_fewer_than_min_nodes():
|
||||
meta = _meta_instance(min_nodes=3)
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_satisfies_with_node_ids_specified():
|
||||
meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
# --- instance_connections_healthy ---
|
||||
|
||||
|
||||
def test_healthy_single_node_present():
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
topology = _topology("node-a")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_single_node_missing():
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
topology = Topology() # empty
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_two_node_ring():
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_two_node_edge_removed():
|
||||
"""Nodes present but edge removed — ring broken."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b", connect=False)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_unhealthy_two_node_ip_changed():
|
||||
"""Edge exists but with a different IP than instance was configured with."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
# Build topology with different IPs than _instance() expects
|
||||
topology = Topology()
|
||||
topology.add_node(NodeId("node-a"))
|
||||
topology.add_node(NodeId("node-b"))
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=NodeId("node-a"),
|
||||
sink=NodeId("node-b"),
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=NodeId("node-b"),
|
||||
sink=NodeId("node-a"),
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_three_node_ring():
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
topology = _topology("node-a", "node-b", "node-c")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_three_node_one_edge_removed():
|
||||
"""Remove one edge from a three-node ring — instance unhealthy."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
# Build topology with one direction of one edge missing
|
||||
topology = Topology()
|
||||
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
|
||||
for n in nodes:
|
||||
topology.add_node(n)
|
||||
# Add all edges except node-a → node-b
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[1],
|
||||
sink=nodes[0],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[1],
|
||||
sink=nodes[2],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[2],
|
||||
sink=nodes[1],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[2],
|
||||
sink=nodes[0],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[0],
|
||||
sink=nodes[2],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
# Missing: node-a → node-b (ip 10.0.0.2)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_unhealthy_node_missing_from_topology():
|
||||
"""Instance has a node that's not in the topology at all."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a") # node-b not present
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_extra_nodes_in_topology():
|
||||
"""Extra nodes in topology don't affect instance health."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b", "node-c")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
# --- find_unsatisfied_meta_instances ---
|
||||
|
||||
|
||||
def test_unsatisfied_no_meta_instances():
|
||||
result = find_unsatisfied_meta_instances({}, {}, Topology())
|
||||
assert list(result) == []
|
||||
|
||||
|
||||
def test_unsatisfied_one_satisfied():
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == []
|
||||
|
||||
|
||||
def test_unsatisfied_one_not_satisfied():
|
||||
meta = _meta_instance("test-org/model-x")
|
||||
id_a, inst_a = _instance("test-org/model-y")
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_mix():
|
||||
meta_satisfied = _meta_instance("test-org/model-a")
|
||||
meta_unsatisfied = _meta_instance("test-org/model-b")
|
||||
id_a, inst_a = _instance(
|
||||
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{
|
||||
meta_satisfied.meta_instance_id: meta_satisfied,
|
||||
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
|
||||
},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta_unsatisfied]
|
||||
|
||||
|
||||
def test_unsatisfied_node_disconnect():
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a") # node-b disconnected
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_edge_break():
|
||||
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_idempotent():
|
||||
meta = _meta_instance("test-org/model-x")
|
||||
topology = _topology("node-a")
|
||||
meta_instances = {meta.meta_instance_id: meta}
|
||||
instances: dict[InstanceId, MlxRingInstance] = {}
|
||||
result_1 = list(
|
||||
find_unsatisfied_meta_instances(meta_instances, instances, topology)
|
||||
)
|
||||
result_2 = list(
|
||||
find_unsatisfied_meta_instances(meta_instances, instances, topology)
|
||||
)
|
||||
assert result_1 == result_2
|
||||
|
||||
|
||||
def test_unsatisfied_exclusive_binding():
|
||||
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
|
||||
meta_a = _meta_instance("test-org/model-x")
|
||||
meta_b = _meta_instance("test-org/model-x")
|
||||
id_inst, inst = _instance(
|
||||
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
{id_inst: inst},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta_b]
|
||||
|
||||
|
||||
# --- apply handlers ---
|
||||
|
||||
|
||||
def test_apply_meta_instance_created():
|
||||
state = State()
|
||||
meta = _meta_instance()
|
||||
event = MetaInstanceCreated(meta_instance=meta)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id in new_state.meta_instances
|
||||
assert new_state.meta_instances[meta.meta_instance_id] == meta
|
||||
|
||||
|
||||
def test_apply_meta_instance_deleted():
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id not in new_state.meta_instances
|
||||
|
||||
|
||||
def test_apply_meta_instance_deleted_clears_failure_info():
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
|
||||
)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id not in new_state.meta_instances
|
||||
|
||||
|
||||
# --- instance_runners_failed ---
|
||||
|
||||
|
||||
def test_runners_failed_all_failed():
|
||||
"""All runners in RunnerFailed -> instance is failed."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message="OOM")
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "OOM" in error
|
||||
|
||||
|
||||
def test_runners_failed_mixed_failed_shutdown():
|
||||
"""One Failed + one Shutdown = failed."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="crash"),
|
||||
runner_ids[1]: RunnerShutdown(),
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "crash" in error
|
||||
|
||||
|
||||
def test_runners_not_failed_all_shutdown():
|
||||
"""All Shutdown (graceful) = not a failure."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_still_active():
|
||||
"""Some runners still active = not failed yet."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids[1]: RunnerLoading(),
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_no_status():
|
||||
"""Runner not yet reported = not failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
is_failed, _ = instance_runners_failed(inst, {}, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_healthy():
|
||||
"""Runners in Ready state = not failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
# --- failure tracking in apply_instance_deleted ---
|
||||
|
||||
|
||||
def test_apply_instance_deleted_tracks_failure():
|
||||
"""InstanceDeleted with failure_error increments meta instance failure count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "Runner OOM"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_increments_failure():
|
||||
"""Subsequent failures increment the counter."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
|
||||
)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="new error")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 3
|
||||
assert mi.last_failure_error == "new error"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_no_failure_no_tracking():
|
||||
"""InstanceDeleted without failure_error does not track."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_apply_instance_deleted_orphan_no_tracking():
|
||||
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
state = State(instances={iid: inst})
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
# --- InstanceRetrying ---
|
||||
|
||||
|
||||
def test_apply_instance_retrying_removes_runners():
|
||||
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids[1]: RunnerShutdown(),
|
||||
}
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners=runners,
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="OOM",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Instance still exists
|
||||
assert iid in new_state.instances
|
||||
# Runners removed
|
||||
assert runner_ids[0] not in new_state.runners
|
||||
assert runner_ids[1] not in new_state.runners
|
||||
|
||||
|
||||
def test_apply_instance_retrying_increments_failure():
|
||||
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "crash"
|
||||
|
||||
|
||||
def test_apply_instance_retrying_skips_missing_runners():
|
||||
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
# No runners in state at all
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
# Should not raise
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert iid in new_state.instances
|
||||
|
||||
|
||||
def test_apply_instance_created_resets_failure_counter():
|
||||
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 3, "last_failure_error": "old error"}
|
||||
)
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceCreated(instance=inst)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.last_failure_error == "old error"
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
# --- InstanceHealthReconciler retry-vs-delete ---
|
||||
|
||||
|
||||
async def test_health_reconciler_retries_when_under_limit():
|
||||
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid
|
||||
assert events[0].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_when_limit_reached():
|
||||
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
|
||||
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_without_meta_instance():
|
||||
"""Instances without a MetaInstance are deleted immediately on runner failure."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_network_failure_always_deletes():
|
||||
"""Network failure always triggers InstanceDeleted regardless of retry count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=_topology("node-a"), # node-b missing
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
assert events[0].failure_error == "Network connection lost"
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -12,12 +12,6 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -34,7 +28,6 @@ from exo.shared.types.events import (
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
@@ -73,22 +66,12 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| InputChunkReceived()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
| JacclSideChannelData()
|
||||
| JacclSideChannelGathered()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case InstanceRetrying():
|
||||
return apply_instance_retrying(event, state)
|
||||
case MetaInstanceCreated():
|
||||
return apply_meta_instance_created(event, state)
|
||||
case MetaInstanceDeleted():
|
||||
return apply_meta_instance_deleted(event, state)
|
||||
case MetaInstancePlacementFailed():
|
||||
return apply_meta_instance_placement_failed(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeDownloadProgress():
|
||||
@@ -191,123 +174,20 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
def _update_meta_instance(
|
||||
state: State, mid: MetaInstanceId, **fields: object
|
||||
) -> Mapping[MetaInstanceId, MetaInstance]:
|
||||
mi = state.meta_instances[mid]
|
||||
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
|
||||
|
||||
|
||||
def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
instance = event.instance
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
instance.instance_id: instance,
|
||||
}
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
# Reset failure tracking when a new instance is created for a meta-instance
|
||||
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
|
||||
mi = state.meta_instances[instance.meta_instance_id]
|
||||
if mi.placement_error is not None or mi.consecutive_failures > 0:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
instance.meta_instance_id,
|
||||
placement_error=None,
|
||||
consecutive_failures=0,
|
||||
)
|
||||
return state.model_copy(update=update)
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
deleted_instance = state.instances.get(event.instance_id)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
|
||||
}
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
|
||||
# Track failure on the MetaInstance itself
|
||||
if (
|
||||
event.failure_error
|
||||
and deleted_instance
|
||||
and deleted_instance.meta_instance_id
|
||||
and deleted_instance.meta_instance_id in state.meta_instances
|
||||
):
|
||||
mid = deleted_instance.meta_instance_id
|
||||
mi = state.meta_instances[mid]
|
||||
update["meta_instances"] = {
|
||||
**state.meta_instances,
|
||||
mid: mi.model_copy(
|
||||
update={
|
||||
"consecutive_failures": mi.consecutive_failures + 1,
|
||||
"last_failure_error": event.failure_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
|
||||
"""Runners failed but retry limit not reached — remove runners, keep instance."""
|
||||
instance = state.instances.get(event.instance_id)
|
||||
if instance is None:
|
||||
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
|
||||
# The InstanceDeleted handler already incremented consecutive_failures
|
||||
# on the MetaInstance, so skipping here avoids double-counting.
|
||||
return state
|
||||
|
||||
# Remove all runners belonging to this instance from state
|
||||
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
|
||||
}
|
||||
|
||||
update: dict[str, object] = {"runners": new_runners}
|
||||
|
||||
# Increment failure count on the MetaInstance
|
||||
if event.meta_instance_id in state.meta_instances:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
event.meta_instance_id,
|
||||
consecutive_failures=state.meta_instances[
|
||||
event.meta_instance_id
|
||||
].consecutive_failures
|
||||
+ 1,
|
||||
last_failure_error=event.failure_error,
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
**state.meta_instances,
|
||||
event.meta_instance.meta_instance_id: event.meta_instance,
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
mid: mi
|
||||
for mid, mi in state.meta_instances.items()
|
||||
if mid != event.meta_instance_id
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_placement_failed(
|
||||
event: MetaInstancePlacementFailed, state: State
|
||||
) -> State:
|
||||
if event.meta_instance_id not in state.meta_instances:
|
||||
return state
|
||||
return state.model_copy(
|
||||
update={
|
||||
"meta_instances": _update_meta_instance(
|
||||
state, event.meta_instance_id, placement_error=event.reason
|
||||
)
|
||||
}
|
||||
)
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
|
||||
|
||||
@@ -44,7 +44,8 @@ async def _refresh_card_cache():
|
||||
async for toml_file in path.rglob("*.toml"):
|
||||
try:
|
||||
card = await ModelCard.load_from_path(toml_file)
|
||||
_card_cache[card.model_id] = card
|
||||
if card.model_id not in _card_cache:
|
||||
_card_cache[card.model_id] = card
|
||||
except (ValidationError, TOMLKitError):
|
||||
pass
|
||||
|
||||
@@ -182,6 +183,7 @@ class ConfigData(BaseModel):
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["GlmMoeDsaForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any, Literal, get_args
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -262,24 +262,25 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class CreateMetaInstanceParams(BaseModel):
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
node_ids: list[NodeId] | None = None
|
||||
ImageSize = Literal[
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
]
|
||||
|
||||
|
||||
class CreateMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class DeleteMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
def normalize_image_size(v: object) -> ImageSize:
|
||||
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
|
||||
if v is None:
|
||||
return "auto"
|
||||
if v not in get_args(ImageSize):
|
||||
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
|
||||
return v # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
@@ -301,7 +302,7 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
size: ImageSize = "auto"
|
||||
stream: bool | None = False
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
@@ -309,6 +310,11 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
|
||||
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
|
||||
bench: bool = True
|
||||
@@ -325,13 +331,18 @@ class ImageEditsTaskParams(BaseModel):
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
size: ImageSize = "auto"
|
||||
image_strength: float | None = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
if name == "image_data":
|
||||
|
||||
@@ -6,8 +6,7 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -53,14 +52,6 @@ class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class CreateMetaInstance(BaseCommand):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class DeleteMetaInstance(BaseCommand):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -103,8 +94,6 @@ Command = (
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| CreateMetaInstance
|
||||
| DeleteMetaInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -42,10 +42,6 @@ class CommandId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class MetaInstanceId(Id):
|
||||
"""Identifier for a MetaInstance."""
|
||||
|
||||
|
||||
class Host(CamelCaseModel):
|
||||
ip: str
|
||||
port: int
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import base64
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Annotated, final
|
||||
from typing import final
|
||||
|
||||
from pydantic import BeforeValidator, Field, PlainSerializer
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -17,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
def _decode_base64_bytes(v: bytes | str) -> bytes:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
return base64.b64decode(v)
|
||||
|
||||
|
||||
def _encode_base64_bytes(v: bytes) -> str:
|
||||
return base64.b64encode(v).decode("ascii")
|
||||
|
||||
|
||||
Base64Bytes = Annotated[
|
||||
bytes,
|
||||
BeforeValidator(_decode_base64_bytes),
|
||||
PlainSerializer(_encode_base64_bytes, return_type=str),
|
||||
]
|
||||
"""bytes that serialize to/from base64 strings in JSON.
|
||||
|
||||
Needed because TaggedModel's wrap validator converts JSON→Python validation
|
||||
context, which breaks strict-mode bytes deserialization from JSON strings.
|
||||
"""
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
"""
|
||||
Newtype around `ID`
|
||||
@@ -91,30 +66,6 @@ class InstanceCreated(BaseEvent):
|
||||
|
||||
class InstanceDeleted(BaseEvent):
|
||||
instance_id: InstanceId
|
||||
failure_error: str | None = None
|
||||
|
||||
|
||||
class MetaInstanceCreated(BaseEvent):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class MetaInstanceDeleted(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstancePlacementFailed(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
reason: str
|
||||
|
||||
|
||||
@final
|
||||
class InstanceRetrying(BaseEvent):
|
||||
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
|
||||
|
||||
instance_id: InstanceId
|
||||
meta_instance_id: MetaInstanceId
|
||||
failure_error: str
|
||||
|
||||
|
||||
class RunnerStatusUpdated(BaseEvent):
|
||||
@@ -181,25 +132,6 @@ class TracesMerged(BaseEvent):
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelData(BaseEvent):
|
||||
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
runner_id: RunnerId
|
||||
sequence: int
|
||||
data: Base64Bytes
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelGathered(BaseEvent):
|
||||
"""Gathered result of a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
sequence: int
|
||||
gathered_data: Mapping[RunnerId, Base64Bytes]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -209,10 +141,6 @@ Event = (
|
||||
| TaskAcknowledged
|
||||
| InstanceCreated
|
||||
| InstanceDeleted
|
||||
| InstanceRetrying
|
||||
| MetaInstanceCreated
|
||||
| MetaInstanceDeleted
|
||||
| MetaInstancePlacementFailed
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeTimedOut
|
||||
@@ -224,8 +152,6 @@ Event = (
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
| TracesMerged
|
||||
| JacclSideChannelData
|
||||
| JacclSideChannelGathered
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from typing import final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstance(FrozenModel):
|
||||
"""Declarative constraint: ensure an instance matching these parameters always exists."""
|
||||
|
||||
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
node_ids: list[NodeId] | None = None
|
||||
# Failure tracking
|
||||
placement_error: str | None = None
|
||||
consecutive_failures: int = 0
|
||||
last_failure_error: str | None = None
|
||||
@@ -6,8 +6,7 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import (
|
||||
DiskUsage,
|
||||
MemoryUsage,
|
||||
@@ -42,7 +41,6 @@ class State(CamelCaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
|
||||
@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -19,7 +19,6 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
meta_instance_id: MetaInstanceId | None = None
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -125,7 +125,9 @@ class MpSender[T]:
|
||||
self._state.buffer.put(item, block=True)
|
||||
|
||||
async def send_async(self, item: T) -> None:
|
||||
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
|
||||
await to_thread.run_sync(
|
||||
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -14,6 +14,7 @@ from exo.shared.types.api import (
|
||||
ImageEditsTaskParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ImageSize,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
@@ -23,9 +24,9 @@ from exo.shared.types.worker.runner_response import (
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
def parse_size(size_str: ImageSize) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str:
|
||||
if size_str == "auto":
|
||||
return (1024, 1024)
|
||||
|
||||
try:
|
||||
@@ -109,6 +110,9 @@ def generate_image(
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
if task.size == "auto":
|
||||
with Image.open(image_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
|
||||
@@ -163,11 +163,14 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
|
||||
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
|
||||
if self.is_prefill:
|
||||
mx.eval(output)
|
||||
if cache is not None:
|
||||
mx.eval(cache.keys) # type: ignore
|
||||
mx.eval(_cache.keys) # type: ignore
|
||||
|
||||
if not self.is_prefill:
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
@@ -307,7 +310,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
last = cache[-1] # type: ignore
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
|
||||
|
||||
return logits
|
||||
|
||||
@@ -333,7 +338,9 @@ def patch_tensor_model[T](model: T) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
last = cache[-1] # pyright: ignore[reportAny]
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
|
||||
return logits
|
||||
|
||||
@@ -547,10 +554,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
@@ -581,12 +590,18 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
# Shard the MoE.
|
||||
else:
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -779,8 +794,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
# Shard the MoE.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
@@ -893,8 +907,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
# Shard the MoE.
|
||||
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
@@ -57,6 +57,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -86,6 +87,9 @@ def prefill(
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
@@ -305,16 +309,9 @@ def mlx_generate(
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
@@ -331,6 +328,7 @@ def mlx_generate(
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
|
||||
logger.info("Starting decode")
|
||||
mx_barrier(group)
|
||||
|
||||
for completion_tokens, out in enumerate(
|
||||
|
||||
@@ -285,10 +285,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
|
||||
# For GLM-5 and GLM-4.7
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
@@ -574,11 +576,6 @@ def mlx_cleanup(
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
"""Synchronize a boolean across all distributed nodes.
|
||||
|
||||
Returns True if any node has bool_=True. Uses all_sum so every
|
||||
node participates in the collective — preventing GPU deadlocks.
|
||||
"""
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
|
||||
@@ -24,7 +24,6 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
JacclSideChannelGathered,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -34,6 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -159,15 +159,6 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
|
||||
if isinstance(event, JacclSideChannelGathered):
|
||||
for runner in self.runners.values():
|
||||
if (
|
||||
runner.bound_instance.instance.instance_id
|
||||
== event.instance_id
|
||||
):
|
||||
runner.notify_gathered(event)
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
@@ -234,15 +225,22 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
await runner.start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -280,18 +278,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
await self._start_runner_task(modified_task)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
await self._start_runner_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
|
||||
@@ -35,7 +35,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
@@ -57,7 +56,7 @@ def plan(
|
||||
return (
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances, all_runners)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
@@ -76,12 +75,6 @@ def _kill_runner(
|
||||
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
# Master removed our runner from state (retry signal) and process is dead
|
||||
if runner_id not in all_runners and isinstance(
|
||||
runner.status, (RunnerFailed, RunnerShutdown)
|
||||
):
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
for (
|
||||
global_runner_id
|
||||
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
|
||||
@@ -99,7 +92,6 @@ def _create_runner(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
) -> CreateRunner | None:
|
||||
for instance in instances.values():
|
||||
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
|
||||
@@ -109,16 +101,6 @@ def _create_runner(
|
||||
if runner_id in runners:
|
||||
continue
|
||||
|
||||
# Don't create while any peer runner is in a terminal state — wait for
|
||||
# the master to emit InstanceRetrying which removes them from state.
|
||||
has_terminal_peer = any(
|
||||
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
|
||||
for peer_rid in instance.shard_assignments.node_to_runner.values()
|
||||
if peer_rid != runner_id
|
||||
)
|
||||
if has_terminal_peer:
|
||||
continue
|
||||
|
||||
shard = instance.shard(runner_id)
|
||||
assert shard is not None
|
||||
|
||||
@@ -328,8 +310,7 @@ def _pending_tasks(
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> CancelTask | None:
|
||||
"""Find a cancelled task that hasn't been sent to the runner yet."""
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
|
||||
@@ -17,7 +17,6 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
pipe_fifo_paths: tuple[str, str] | None = None,
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
@@ -31,16 +30,6 @@ def entrypoint(
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
|
||||
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
|
||||
if pipe_fifo_paths is not None:
|
||||
fifo_c2p, fifo_p2c = pipe_fifo_paths
|
||||
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
|
||||
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
|
||||
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
|
||||
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
|
||||
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
@@ -67,9 +56,7 @@ def entrypoint(
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
cancel_receiver.close()
|
||||
finally:
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
cancel_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -243,7 +243,7 @@ def main(
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
t = time.perf_counter()
|
||||
t = time.monotonic()
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -251,7 +251,7 @@ def main(
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
@@ -18,14 +14,12 @@ from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
@@ -40,26 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
|
||||
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
|
||||
data = b""
|
||||
while len(data) < n:
|
||||
chunk = os.read(fd, n - len(data))
|
||||
if not chunk:
|
||||
return None
|
||||
data += chunk
|
||||
return data
|
||||
|
||||
|
||||
def _pipe_write_all(fd: int, data: bytes) -> None:
|
||||
"""Write all bytes to a file descriptor."""
|
||||
view = memoryview(data)
|
||||
while view:
|
||||
written = os.write(fd, view)
|
||||
view = view[written:]
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
|
||||
@@ -72,21 +46,12 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_event_sender: Sender[Event]
|
||||
_pipe_read_fd: int | None = None # Python reads runner's pipe output
|
||||
_pipe_write_fd: int | None = None # Python writes gathered data to runner
|
||||
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
|
||||
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
|
||||
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
|
||||
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
_gathered_waiters: dict[
|
||||
int, tuple[anyio.Event, JacclSideChannelGathered | None]
|
||||
] = field(default_factory=dict, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -100,23 +65,6 @@ class RunnerSupervisor:
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
||||
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
||||
# FIFO c2p: C++ writes local data → Python reads it
|
||||
# FIFO p2c: Python writes gathered data → C++ reads it
|
||||
fifo_dir: str | None = None
|
||||
fifo_c2p: str | None = None
|
||||
fifo_p2c: str | None = None
|
||||
pipe_fifo_paths: tuple[str, str] | None = None
|
||||
|
||||
if isinstance(bound_instance.instance, MlxJacclInstance):
|
||||
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
||||
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
||||
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
||||
os.mkfifo(fifo_c2p)
|
||||
os.mkfifo(fifo_p2c)
|
||||
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
@@ -125,7 +73,6 @@ class RunnerSupervisor:
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
pipe_fifo_paths,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -141,54 +88,21 @@ class RunnerSupervisor:
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
_fifo_dir=fifo_dir,
|
||||
_fifo_c2p=fifo_c2p,
|
||||
_fifo_p2c=fifo_p2c,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
|
||||
if self._fifo_c2p is not None and self._fifo_p2c is not None:
|
||||
# Open FIFOs from parent side. These block until child opens the other end,
|
||||
# so we run them in threads concurrently to avoid deadlock.
|
||||
fifo_c2p = self._fifo_c2p
|
||||
fifo_p2c = self._fifo_p2c
|
||||
|
||||
async def open_read() -> None:
|
||||
self._pipe_read_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_c2p, os.O_RDONLY)
|
||||
)
|
||||
|
||||
async def open_write() -> None:
|
||||
self._pipe_write_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_p2c, os.O_WRONLY)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as open_tg:
|
||||
open_tg.start_soon(open_read)
|
||||
open_tg.start_soon(open_write)
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(self._pipe_relay)
|
||||
tg.start_soon(self._forward_events)
|
||||
else:
|
||||
await self._forward_events()
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self._event_sender.close()
|
||||
self._close_pipe_fds()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
@@ -226,7 +140,6 @@ class RunnerSupervisor:
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
"""Send a cancellation signal to the runner process."""
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
@@ -268,110 +181,6 @@ class RunnerSupervisor:
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
|
||||
def _close_pipe_fds(self) -> None:
|
||||
if self._pipe_read_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_read_fd)
|
||||
self._pipe_read_fd = None
|
||||
if self._pipe_write_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_write_fd)
|
||||
self._pipe_write_fd = None
|
||||
if self._child_pipe_fds is not None:
|
||||
for fd in self._child_pipe_fds:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
self._child_pipe_fds = None
|
||||
# Clean up FIFO files
|
||||
if self._fifo_c2p is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_c2p)
|
||||
self._fifo_c2p = None
|
||||
if self._fifo_p2c is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_p2c)
|
||||
self._fifo_p2c = None
|
||||
if self._fifo_dir is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.rmdir(self._fifo_dir)
|
||||
self._fifo_dir = None
|
||||
|
||||
async def _pipe_relay(self) -> None:
|
||||
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
|
||||
assert self._pipe_read_fd is not None
|
||||
assert self._pipe_write_fd is not None
|
||||
read_fd = self._pipe_read_fd
|
||||
write_fd = self._pipe_write_fd
|
||||
sequence = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 1. Read local data from runner: [uint32 size][size bytes]
|
||||
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
|
||||
if header is None:
|
||||
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
|
||||
break
|
||||
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
|
||||
local_data = await to_thread.run_sync(
|
||||
partial(_pipe_read_exact, read_fd, data_size)
|
||||
)
|
||||
if local_data is None:
|
||||
logger.warning("JACCL pipe relay: EOF reading data payload")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
|
||||
)
|
||||
|
||||
# 2. Emit JacclSideChannelData event
|
||||
waiter = anyio.Event()
|
||||
self._gathered_waiters[sequence] = (waiter, None)
|
||||
await self._event_sender.send(
|
||||
JacclSideChannelData(
|
||||
instance_id=self.bound_instance.instance.instance_id,
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
sequence=sequence,
|
||||
data=local_data,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Wait for gathered result
|
||||
await waiter.wait()
|
||||
_, gathered_event = self._gathered_waiters.pop(sequence)
|
||||
assert gathered_event is not None
|
||||
|
||||
# 4. Order gathered data by runner rank and concatenate
|
||||
instance = self.bound_instance.instance
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
|
||||
ordered_data = b"".join(
|
||||
gathered_event.gathered_data[rid] for rid in runner_order
|
||||
)
|
||||
|
||||
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
|
||||
total_size = len(ordered_data)
|
||||
response = struct.pack("<I", total_size) + ordered_data
|
||||
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
|
||||
)
|
||||
sequence += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"JACCL pipe relay: OS error: {e}")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
|
||||
|
||||
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
|
||||
"""Called by the worker when a JacclSideChannelGathered event arrives."""
|
||||
seq = event.sequence
|
||||
if seq not in self._gathered_waiters:
|
||||
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
|
||||
return
|
||||
waiter, _ = self._gathered_waiters[seq]
|
||||
self._gathered_waiters[seq] = (waiter, event)
|
||||
waiter.set()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||
@@ -382,7 +191,7 @@ class RunnerSupervisor:
|
||||
logger.info("Checking runner's status")
|
||||
if self.runner_process.is_alive():
|
||||
logger.info("Runner was found to be alive, attempting to join process")
|
||||
await to_thread.run_sync(self.runner_process.join, 1)
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
rc = self.runner_process.exitcode
|
||||
logger.info(f"RunnerSupervisor exited with exit code {rc}")
|
||||
if rc == 0:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user