Compare commits


6 Commits

Author SHA1 Message Date
Evan
b0da9dd56b runner opts 2026-02-26 17:51:31 +00:00
rltakashige
152a27ea5d Fix pipeline mismatched send after 1587 (#1629)
## Motivation

Tests caught a bug. It was a real bug.
2026-02-26 16:48:34 +00:00
rltakashige
db36bd5ac6 Add custom prefill for pipeline (#1587)
## Motivation

Since we need to perform distributed communication between prefill steps,
the out-of-the-box stream_generate that we currently use prevents
pipeline parallel models from doing overlapped computation. While this
was technically a regression, this communication is necessary for
cancellation, and we will need various distributed communications in the
future (e.g. for coordinating batching).

500 of the lines are a single test file, so the diffs aren't as bad as
they look!

## Changes

- Added a special prefill function for pipeline parallel models (sketched below)
- Edited the model to handle 
- Added a test to verify this new prefill and the original prefill produce identical results
- Improved type stubs to remove some `type: ignore` comments
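
For context, a rough sketch of the chunked prefill idea (illustrative only; `pipeline_prefill`, `sync_ranks`, and `should_cancel` are made-up names, not the actual implementation):

```python
# Illustrative sketch of a chunked prefill for a pipeline-parallel model.
# `sync_ranks` and `should_cancel` are hypothetical callbacks, not exo APIs.
import mlx.core as mx


def pipeline_prefill(model, prompt: mx.array, cache, step_size: int,
                     sync_ranks, should_cancel) -> None:
    """Feed the prompt into the KV cache in fixed-size steps."""
    for start in range(0, prompt.shape[0], step_size):
        chunk = prompt[start : start + step_size][None]  # add batch dim
        out = model(chunk, cache=cache)  # forward pass fills the KV cache
        mx.eval(out)                     # force evaluation of this step
        sync_ranks()                     # distributed sync between prefill steps
        if should_cancel():              # natural place to observe cancellation
            return
```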

## Why It Works
<img width="768" height="1246" alt="image"
src="https://github.com/user-attachments/assets/8986ff17-ac23-4a02-9bd7-e6253a0ca799"
/>

## Test Plan

### Manual Testing
Needs more testing, but seems good so far.

### Automated Testing
Passes CI; benchmarks show a considerable speedup in prefill speed (up
to 1.98x).

Before:
<img width="3280" height="1238" alt="image"
src="https://github.com/user-attachments/assets/9abc1cbc-ecdb-4e48-a675-2c4cb04a32a0"
/>


After:
<img width="3344" height="1236" alt="image"
src="https://github.com/user-attachments/assets/e03c7987-41b4-4950-9ac3-2840e774ce30"
/>
2026-02-26 16:00:38 +00:00
Evan Quiney
639243aa09 event router (#1572)
replace the nack & resend logic in the worker/download coordinator with
a dedicated subsystem in front of the topic router. this centralizes
that logic (and the concept of system ids) to reduce duplication in the
codebase. each system reading or writing events now gets a clean
incoming stream of events and can trust that written events will be
retried until acknowledged.
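
roughly, the retry-until-acknowledged pattern being centralized looks like this (an illustrative python sketch, not the actual EventRouter API; all names are made up):

```python
# Illustrative sketch only — not the actual EventRouter API. Written events
# stay "out for delivery" and are periodically resent until the acknowledged
# copy comes back on the global stream.
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field


@dataclass
class RetryingPublisher:
    send: Callable[[object], Awaitable[None]]  # writes an event to the topic router
    _out_for_delivery: dict[int, object] = field(default_factory=dict)

    async def publish(self, event_id: int, event: object) -> None:
        # track the event until acknowledged, then send it
        self._out_for_delivery[event_id] = event
        await self.send(event)

    def acknowledge(self, event_id: int) -> None:
        # called when the event is observed on the global (acknowledged) stream
        self._out_for_delivery.pop(event_id, None)

    async def resend_loop(self, interval: float = 1.0) -> None:
        # periodically resend anything not yet acknowledged
        while True:
            await asyncio.sleep(interval)
            for event in list(self._out_for_delivery.values()):
                await self.send(event)
```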
2026-02-26 14:17:02 +00:00
Evan Quiney
db73c4fd5d move messaging into rust (#1549)
the main body of the rust refactor. fixes the tokio panic on shutdown.
simplifies the networking module significantly. doesn't touch libp2p
behaviour
2026-02-26 13:58:22 +00:00
ciaranbor
eaed92952c Use tmpdir for coordination file (#1624)
## Motivation

Coordination files for MLX distributed init were written to the current
working directory (./hosts_*.json)

## Changes

- Move coordination file creation to a tempfile.TemporaryDirectory(),
which auto-cleans on context manager exit
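
Roughly (illustrative names, not the exact code):

```python
# Illustrative sketch (names are made up): the coordination file now lives in
# a TemporaryDirectory instead of the CWD, so it is removed automatically.
import json
import tempfile
from pathlib import Path


def run_distributed_init(hosts: list[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        hosts_path = Path(tmpdir) / "hosts.json"
        hosts_path.write_text(json.dumps(hosts))
        # ... launch MLX distributed init pointing at hosts_path ...
    # the directory and hosts file are deleted here, on context-manager exit
```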
2026-02-26 10:59:36 +00:00
46 changed files with 1650 additions and 1775 deletions

View File

@@ -73,9 +73,11 @@ class GenerationResponse:
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
prompt_cache: Any,
quantized_kv_start: int | None,
kv_group_size: int | None,
kv_bits: int | None,
) -> None: ...
def generate_step(
prompt: mx.array,
model: nn.Module,

View File

@@ -16,7 +16,7 @@ class Cache(Protocol):
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@property
def meta_state(self) -> Literal[""]: ...
@meta_state.setter
def meta_state(self, v) -> None: ...
def trim(self, n: int) -> int: ...
def is_trimmable(self) -> Literal[False]: ...
@classmethod
def from_state(cls, state, meta_state) -> Self: ...
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
@property
def state(
self,
) -> tuple[array, array]: ...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@@ -249,15 +239,13 @@ class ChunkedKVCache(KVCache):
...
class CacheList(_BaseCache):
caches: tuple[_BaseCache, ...]
def __init__(self, *caches: _BaseCache) -> None: ...
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
def trim(self, n: int) -> int: ...
@property
def state(self): # -> list[Any]:
...
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
@state.setter
def state(self, v): # -> None:
...

Cargo.lock generated
View File

@@ -216,6 +216,28 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "async-stream"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -2759,6 +2781,7 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
"async-stream",
"delegate",
"either",
"extend",
@@ -2767,6 +2790,7 @@ dependencies = [
"keccak-const",
"libp2p",
"log",
"pin-project",
"tokio",
"tracing-subscriber",
"util",

View File

@@ -34,6 +34,7 @@ delegate = "0.13"
keccak-const = "0.2"
# Async dependencies
async-stream = "0.3"
tokio = "1.46"
futures-lite = "2.6.1"
futures-timer = "3.0"

View File

@@ -170,30 +170,5 @@
{/if}
Downloads
</a>
<a
href="/#/settings"
class="text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
title="Settings"
>
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
/>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
/>
</svg>
Settings
</a>
</nav>
</header>

View File

@@ -1,87 +0,0 @@
/**
* SettingsStore - Manages exo runtime settings via the /settings API.
*/
export interface MemorySettings {
oom_prevention: boolean;
memory_threshold: number;
memory_floor_gb: number;
}
export interface GenerationSettings {
prefill_step_size: number;
max_tokens: number;
kv_cache_bits: 4 | 8 | null;
}
export interface ExoSettings {
memory: MemorySettings;
generation: GenerationSettings;
}
function defaultSettings(): ExoSettings {
return {
memory: {
oom_prevention: false,
memory_threshold: 0.8,
memory_floor_gb: 5.0,
},
generation: {
prefill_step_size: 4096,
max_tokens: 32168,
kv_cache_bits: null,
},
};
}
class SettingsStore {
settings = $state<ExoSettings>(defaultSettings());
loading = $state(false);
error = $state<string | null>(null);
async load(): Promise<void> {
this.loading = true;
this.error = null;
try {
const response = await fetch("/settings");
if (!response.ok) {
throw new Error(`Failed to fetch settings: ${response.status}`);
}
this.settings = (await response.json()) as ExoSettings;
} catch (err) {
console.error("Failed to load settings:", err);
this.error = err instanceof Error ? err.message : "Unknown error";
} finally {
this.loading = false;
}
}
async save(updated: ExoSettings): Promise<boolean> {
this.loading = true;
this.error = null;
try {
const response = await fetch("/settings", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(updated),
});
if (!response.ok) {
throw new Error(`Failed to save settings: ${response.status}`);
}
this.settings = (await response.json()) as ExoSettings;
return true;
} catch (err) {
console.error("Failed to save settings:", err);
this.error = err instanceof Error ? err.message : "Unknown error";
return false;
} finally {
this.loading = false;
}
}
resetToDefaults(): ExoSettings {
return defaultSettings();
}
}
export const settingsStore = new SettingsStore();

View File

@@ -1,193 +0,0 @@
<script lang="ts">
import { onMount } from "svelte";
import { fade } from "svelte/transition";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import { settingsStore, type ExoSettings } from "$lib/stores/settings.svelte";
import { addToast } from "$lib/stores/toast.svelte";
let draft = $state<ExoSettings | null>(null);
const loading = $derived(settingsStore.loading);
onMount(async () => {
await settingsStore.load();
draft = structuredClone(settingsStore.settings);
});
async function handleSave() {
if (!draft) return;
const ok = await settingsStore.save(draft);
if (ok) {
addToast({ type: "success", message: "Settings saved" });
} else {
addToast({ type: "error", message: settingsStore.error ?? "Failed to save settings" });
}
}
function handleReset() {
draft = settingsStore.resetToDefaults();
}
const KV_OPTIONS: { label: string; value: 4 | 8 | null }[] = [
{ label: "None (full precision)", value: null },
{ label: "4-bit", value: 4 },
{ label: "8-bit", value: 8 },
];
</script>
<HeaderNav showHome={true} />
{#if draft}
<div class="min-h-screen bg-background text-foreground" in:fade={{ duration: 200 }}>
<div class="max-w-2xl mx-auto px-6 py-8">
<h1 class="text-2xl font-bold text-exo-yellow tracking-wider uppercase mb-8">Settings</h1>
<!-- Memory / Safety -->
<section class="mb-10">
<h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Memory / Safety</h2>
<div class="space-y-5">
<!-- OOM Prevention Toggle -->
<div class="flex items-center justify-between">
<div>
<div class="text-sm text-white/90">OOM Prevention</div>
<div class="text-xs text-white/40 mt-0.5">Stop generation when memory is low</div>
</div>
<button
onclick={() => { if (draft) draft.memory.oom_prevention = !draft.memory.oom_prevention; }}
class="relative w-11 h-6 rounded-full transition-colors duration-200 cursor-pointer {draft.memory.oom_prevention ? 'bg-exo-yellow' : 'bg-exo-medium-gray'}"
role="switch"
aria-checked={draft.memory.oom_prevention}
>
<span
class="absolute top-0.5 left-0.5 w-5 h-5 rounded-full bg-white shadow transition-transform duration-200 {draft.memory.oom_prevention ? 'translate-x-5' : 'translate-x-0'}"
></span>
</button>
</div>
<!-- Memory Threshold Slider -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Memory Threshold</div>
<div class="text-xs text-white/40 mt-0.5">KV cache eviction triggers above this level</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{(draft.memory.memory_threshold * 100).toFixed(0)}%</span>
</div>
<input
type="range"
min="0.5"
max="0.99"
step="0.01"
bind:value={draft.memory.memory_threshold}
class="w-full h-1.5 rounded-full appearance-none cursor-pointer bg-exo-medium-gray accent-exo-yellow"
/>
</div>
<!-- Memory Floor -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Memory Floor</div>
<div class="text-xs text-white/40 mt-0.5">Minimum free memory to reserve (GB)</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.memory.memory_floor_gb.toFixed(1)} GB</span>
</div>
<input
type="number"
min="0"
max="64"
step="0.5"
bind:value={draft.memory.memory_floor_gb}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
</div>
</section>
<!-- Generation / Performance -->
<section class="mb-10">
<h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Generation / Performance</h2>
<div class="space-y-5">
<!-- Prefill Step Size -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Prefill Step Size</div>
<div class="text-xs text-white/40 mt-0.5">Token chunk size during prompt processing</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.generation.prefill_step_size.toLocaleString()}</span>
</div>
<input
type="number"
min="128"
max="32768"
step="128"
bind:value={draft.generation.prefill_step_size}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<!-- Max Tokens -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Max Tokens</div>
<div class="text-xs text-white/40 mt-0.5">Maximum generation length per response</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.generation.max_tokens.toLocaleString()}</span>
</div>
<input
type="number"
min="1"
max="131072"
step="1024"
bind:value={draft.generation.max_tokens}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<!-- KV Cache Bits -->
<div>
<div class="mb-1.5">
<div class="text-sm text-white/90">KV Cache Quantization</div>
<div class="text-xs text-white/40 mt-0.5">Lower bits save memory at slight quality cost</div>
</div>
<select
bind:value={draft.generation.kv_cache_bits}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50 cursor-pointer"
>
{#each KV_OPTIONS as opt}
<option value={opt.value}>{opt.label}</option>
{/each}
</select>
</div>
</div>
</section>
<!-- Action Buttons -->
<div class="flex items-center gap-3">
<button
onclick={handleSave}
disabled={loading}
class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker
disabled:opacity-50 disabled:cursor-not-allowed"
>
{loading ? "Saving..." : "Save"}
</button>
<button
onclick={handleReset}
disabled={loading}
class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
border border-exo-light-gray/30 text-white/70 hover:border-exo-yellow/50 hover:text-exo-yellow
disabled:opacity-50 disabled:cursor-not-allowed"
>
Reset to Defaults
</button>
</div>
</div>
</div>
{:else}
<div class="min-h-screen bg-background flex items-center justify-center">
<div class="text-white/40 text-sm">Loading settings...</div>
</div>
{/if}

View File

@@ -2,7 +2,6 @@
# ruff: noqa: E501, F401
import builtins
import enum
import typing
@typing.final
@@ -11,29 +10,6 @@ class AllQueuesFullError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
@@ -58,21 +34,15 @@ class Keypair:
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
"""
@typing.final
class MessageTooLargeError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
@@ -91,24 +61,7 @@ class NetworkingHandle:
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final
class MessageTooLargeError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
async def recv(self) -> PyFromSwarm: ...
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
@@ -116,11 +69,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...
class PyFromSwarm:
@typing.final
class Connection(PyFromSwarm):
__match_args__ = ("peer_id", "connected",)
@property
def peer_id(self) -> builtins.str: ...
@property
def connected(self) -> builtins.bool: ...
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
@typing.final
class Message(PyFromSwarm):
__match_args__ = ("origin", "topic", "data",)
@property
def origin(self) -> builtins.str: ...
@property
def topic(self) -> builtins.str: ...
@property
def data(self) -> bytes: ...
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
...

View File

@@ -4,11 +4,12 @@ build-backend = "maturin"
[project]
name = "exo_pyo3_bindings"
version = "0.1.0"
version = "0.2.0"
description = "Add your description here"
readme = "README.md"
authors = [
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
]
requires-python = ">=3.13"
dependencies = []

View File

@@ -155,6 +155,9 @@ pub(crate) mod ext {
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder.enable_all();
pyo3_async_runtimes::tokio::init(builder);
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without

View File

@@ -1,26 +1,24 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use std::pin::Pin;
use std::sync::Arc;
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::networking::exception::{
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
};
use crate::pyclass;
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use networking::discovery;
use networking::swarm::create_swarm;
use futures_lite::{Stream, StreamExt as _};
use libp2p::gossipsub::PublishError;
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
use pyo3_stub_gen::derive::{
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
};
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
@@ -131,237 +129,45 @@ mod exception {
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else if let Err(PublishError::MessageTooLarge) = result {
Err(exception::PyMessageTooLargeError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
pub to_swarm: mpsc::Sender<ToSwarm>,
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
#[gen_stub_pyclass_complex_enum]
#[pyclass]
enum PyFromSwarm {
Connection {
peer_id: String,
connected: bool,
},
Message {
origin: String,
topic: String,
data: Py<PyBytes>,
},
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
impl From<FromSwarm> for PyFromSwarm {
fn from(value: FromSwarm) -> Self {
match value {
FromSwarm::Discovered { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: true,
},
FromSwarm::Expired { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: false,
},
FromSwarm::Message { from, topic, data } => Self::Message {
origin: from.to_base58(),
topic: topic,
data: data.pybytes(),
},
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
@@ -375,97 +181,36 @@ impl PyNetworkingHandle {
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
Ok(Self {
swarm: Arc::new(Mutex::new(swarm)),
to_swarm,
})
}
#[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let swarm = Arc::clone(&self.swarm);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
swarm
.try_lock()
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
.next()
.await
.ok_or(PyErr::receiver_channel_closed())
.map(PyFromSwarm::from)
})
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
@@ -475,10 +220,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
self.to_swarm
.send_py(ToSwarm::Subscribe {
topic,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -487,6 +232,7 @@ impl PyNetworkingHandle {
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
.pyerr()
}
/// Unsubscribes from a `GossipSub` topic.
@@ -496,10 +242,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
self.to_swarm
.send_py(ToSwarm::Unsubscribe {
topic,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -518,11 +264,11 @@ impl PyNetworkingHandle {
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
self.to_swarm
.send_py(ToSwarm::Publish {
topic,
data,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -531,64 +277,26 @@ impl PyNetworkingHandle {
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
.map_err(|_| PyErr::receiver_channel_closed())?
.map_err(|e| match e {
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
PublishError::NoPeersSubscribedToTopic => {
PyNoPeersSubscribedToTopicError::new_err()
}
e => PyRuntimeError::new_err(e.to_string()),
})?;
Ok(())
}
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
pyo3_stub_gen::inventory::submit! {
gen_methods_from_python! {
r#"
class PyNetworkingHandle:
async def recv() -> PyFromSwarm: ...
"#
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -596,10 +304,8 @@ pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<exception::PyMessageTooLargeError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
m.add_class::<PyFromSwarm>()?;
Ok(())
}

View File

@@ -21,9 +21,10 @@ extend = { workspace = true }
delegate = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
async-stream = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
tokio = { workspace = true, features = ["full"] }
# utility dependencies
util = { workspace = true }
@@ -35,3 +36,4 @@ log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -1,7 +1,9 @@
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use libp2p::identity;
use networking::swarm;
use networking::swarm::{FromSwarm, ToSwarm};
use tokio::sync::{mpsc, oneshot};
use tokio::{io, io::AsyncBufReadExt as _};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -11,64 +13,69 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (to_swarm, from_client) = mpsc::channel(20);
// Configure swarm
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
.expect("Swarm creation failed")
.into_stream();
// Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
let (tx, rx) = oneshot::channel();
_ = to_swarm
.send(ToSwarm::Subscribe {
topic: "test-net".to_string(),
result_sender: tx,
})
.await
.expect("should send");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
tokio::task::spawn(async move {
rx.await
.expect("tx not dropped")
.expect("subscribe shouldn't fail");
loop {
if let Ok(Some(line)) = stdin.next_line().await {
let (tx, rx) = oneshot::channel();
if let Err(e) = to_swarm
.send(swarm::ToSwarm::Publish {
topic: "test-net".to_string(),
data: line.as_bytes().to_vec(),
result_sender: tx,
})
.await
{
println!("Send error: {e:?}");
return;
};
match rx.await {
Ok(Err(e)) => println!("Publish error: {e:?}"),
Err(e) => println!("Publish error: {e:?}"),
Ok(_) => {}
}
}
}
});
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
// on gossipsub outgoing
match swarm.next().await {
// on gossipsub incoming
Some(FromSwarm::Discovered { peer_id }) => {
println!("\n\nconnected to {peer_id}\n\n")
}
event = swarm.next() => match event {
// on gossipsub incoming
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
Some(FromSwarm::Expired { peer_id }) => {
println!("\n\ndisconnected from {peer_id}\n\n")
}
Some(FromSwarm::Message { from, topic, data }) => {
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
}
None => {}
}
}
}

View File

@@ -1,9 +1,11 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
use std::pin::Pin;
pub type Swarm = libp2p::Swarm<Behaviour>;
use crate::swarm::transport::tcp_transport;
use crate::{alias, discovery};
pub use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::{Stream, StreamExt};
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::{mpsc, oneshot};
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
@@ -15,8 +17,136 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
// Uses oneshot senders to emulate function calling apis while avoiding requiring unique ownership
// of the Swarm.
pub enum ToSwarm {
Unsubscribe {
topic: String,
result_sender: oneshot::Sender<bool>,
},
Subscribe {
topic: String,
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
},
Publish {
topic: String,
data: Vec<u8>,
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
},
}
pub enum FromSwarm {
Message {
from: PeerId,
topic: String,
data: Vec<u8>,
},
Discovered {
peer_id: PeerId,
},
Expired {
peer_id: PeerId,
},
}
pub struct Swarm {
swarm: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
}
impl Swarm {
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
let Swarm {
mut swarm,
mut from_client,
} = self;
let stream = async_stream::stream! {
loop {
tokio::select! {
msg = from_client.recv() => {
let Some(msg) = msg else { break };
on_message(&mut swarm, msg);
}
event = swarm.next() => {
let Some(event) = event else { break };
if let Some(item) = filter_swarm_event(event) {
yield item;
}
}
}
}
};
Box::pin(stream)
}
}
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
match message {
ToSwarm::Subscribe {
topic,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic));
_ = result_sender.send(result);
}
ToSwarm::Unsubscribe {
topic,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
_ = result_sender.send(result);
}
ToSwarm::Publish {
topic,
data,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.publish(gossipsub::IdentTopic::new(topic), data);
_ = result_sender.send(result);
}
}
}
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
match event {
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message:
gossipsub::Message {
source: Some(peer_id),
topic,
data,
..
},
..
})) => Some(FromSwarm::Message {
from: peer_id,
topic: topic.into_string(),
data,
}),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionEstablished { peer_id, .. },
)) => Some(FromSwarm::Discovered { peer_id }),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
peer_id,
..
})) => Some(FromSwarm::Expired { peer_id }),
_ => None,
}
}
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
pub fn create_swarm(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
@@ -25,7 +155,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
Ok(Swarm { swarm, from_client })
}
mod transport {

View File

@@ -1,6 +1,5 @@
import asyncio
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import current_time
@@ -21,13 +20,9 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
EventId,
# TODO(evan): just for acks, should delete this ASAP
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -38,40 +33,28 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
from exo.utils.task_group import TaskGroup
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[LocalForwarderEvent]
# ack stuff
_global_event_receiver: Receiver[GlobalForwarderEvent]
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
event_sender: Sender[Event]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -123,10 +106,7 @@ class DownloadCoordinator:
try:
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._clear_ofd)
finally:
for task in self.active_downloads.values():
task.cancel()
@@ -134,20 +114,6 @@ class DownloadCoordinator:
def shutdown(self) -> None:
self._tg.cancel_tasks()
# directly copied from worker
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self._out_for_delivery.copy().values():
await self.local_event_sender.send(event)
async def _clear_ofd(self) -> None:
with self._global_event_receiver as events:
async for event in events:
self._out_for_delivery.pop(event.event.event_id, None)
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
@@ -320,23 +286,6 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
self._out_for_delivery[event.event_id] = fe
async def _emit_existing_download_progress(self) -> None:
try:
while True:

View File

@@ -314,9 +314,13 @@ async def fetch_file_list_with_cache(
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
logger.opt(exception=e).warning(
"Ran into exception when fetching file list from HF."
)
if await aios.path.exists(cache_file):
logger.warning(
f"No internet and no cached file list for {model_id} - using local file list"
f"No cached file list for {model_id} - using local file list"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())

View File

@@ -1,98 +0,0 @@
from typing import Any
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import NoopShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadPending,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.utils.channels import channel
# Use the builtin NoopShardDownloader directly it already implements the required abstract interface.
# No additional subclass is needed for this test.
@pytest.mark.anyio
async def test_ack_behaviour():
# Create channels (type Any for simplicity)
_, command_receiver = channel[Any]()
local_sender, _ = channel[Any]()
global_sender, global_receiver = channel[Any]()
# Minimal identifiers
node_id = NodeId()
session_id = SessionId(master_node_id=node_id, election_clock=0)
# Create a dummy model card and shard metadata
model_id = ModelId("test/model")
model_card = ModelCard(
model_id=model_id,
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
# Instantiate the coordinator with the dummy downloader
coord = DownloadCoordinator(
node_id=node_id,
session_id=session_id,
shard_downloader=NoopShardDownloader(),
download_command_receiver=command_receiver,
local_event_sender=local_sender,
_global_event_receiver=global_receiver,
)
async with anyio.create_task_group() as tg:
# Start the forwarding and ackclearing loops
tg.start_soon(coord._forward_events) # pyright: ignore[reportPrivateUsage]
tg.start_soon(coord._clear_ofd) # pyright: ignore[reportPrivateUsage]
# Send a pending download progress event via the internal event sender
pending = DownloadPending(
node_id=node_id,
shard_metadata=shard,
model_directory="/tmp/model",
)
await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
# Allow the forwarder to process the event
await anyio.sleep(0.1)
# There should be exactly one entry awaiting ACK
assert len(coord._out_for_delivery) == 1 # pyright: ignore[reportPrivateUsage]
# Retrieve the stored LocalForwarderEvent
stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values())) # pyright: ignore[reportPrivateUsage]
# Simulate receiving a global ack for this event
ack = GlobalForwarderEvent(
origin_idx=0,
origin=node_id,
session=session_id,
event=stored_fe.event,
)
await global_sender.send(ack)
# Give the clearofd task a moment to process the ack
await anyio.sleep(0.1)
# The outfordelivery map should now be empty
assert len(coord._out_for_delivery) == 0 # pyright: ignore[reportPrivateUsage]
# Cancel background tasks
tg.cancel_scope.cancel()

View File

@@ -15,6 +15,7 @@ from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.event_router import EventRouter
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
@@ -24,11 +25,13 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
from exo.worker.runner.runner_opts import RunnerOpts
@dataclass
class Node:
router: Router
event_router: EventRouter
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
@@ -38,10 +41,11 @@ class Node:
node_id: NodeId
offline: bool
runner_opts: RunnerOpts
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@classmethod
async def create(cls, args: "Args") -> Self:
@staticmethod
async def create(args: "Args") -> "Node":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -52,20 +56,37 @@ class Node:
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
event_router = EventRouter(
session_id,
command_sender=router.sender(topics.COMMANDS),
external_outbound=router.sender(topics.LOCAL_EVENTS),
external_inbound=router.receiver(topics.GLOBAL_EVENTS),
)
logger.info(f"Starting node {node_id}")
if args.fast_synch is True:
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
logger.info("FAST_SYNCH forced OFF")
runner_opts = RunnerOpts(
fast_synch_override=args.fast_synch,
trust_remote_code_override=args.trust_remote_code,
)
if offline := args.offline:
logger.info(
"Running in OFFLINE mode — no internet checks, local models only"
)
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(offline=args.offline),
exo_shard_downloader(offline=offline),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
offline=args.offline,
# TODO(evan): remove
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
offline=offline,
)
else:
download_coordinator = None
@@ -73,9 +94,8 @@ class Node:
if args.spawn_api:
api = API(
node_id,
session_id,
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
event_receiver=event_router.receiver(),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
@@ -86,9 +106,9 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
session_id,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
runner_opts,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
@@ -99,6 +119,7 @@ class Node:
master = Master(
node_id,
session_id,
event_sender=event_router.sender(),
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
@@ -119,8 +140,9 @@ class Node:
election_result_sender=er_send,
)
return cls(
return Node(
router,
event_router,
download_coordinator,
worker,
election,
@@ -129,6 +151,7 @@ class Node:
api,
node_id,
args.offline,
runner_opts,
)
async def run(self):
@@ -136,6 +159,7 @@ class Node:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.event_router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
@@ -183,6 +207,7 @@ class Node:
self.master = Master(
self.node_id,
result.session_id,
event_sender=self.event_router.sender(),
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
@@ -206,21 +231,24 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
self.event_router.shutdown()
self.event_router = EventRouter(
result.session_id,
self.router.sender(topics.COMMANDS),
self.router.receiver(topics.GLOBAL_EVENTS),
self.router.sender(topics.LOCAL_EVENTS),
)
self._tg.start_soon(self.event_router.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(offline=self.offline),
event_sender=self.event_router.sender(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
offline=self.offline,
# TODO(evan): remove
_global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -228,11 +256,9 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
self.runner_opts,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
@@ -240,7 +266,7 @@ class Node:
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset(result.session_id, result.won_clock)
self.api.reset(result.won_clock, self.event_router.receiver())
else:
if self.api:
self.api.unpause(result.won_clock)
@@ -258,17 +284,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
try:
anyio.run(node.run)
@@ -290,8 +305,11 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool | None = (
None # None = auto, True = force on, False = force off
)
@classmethod
def parse(cls) -> Self:
@@ -358,6 +376,20 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
trust_remote_code_group = parser.add_mutually_exclusive_group()
trust_remote_code_group.add_argument(
"--trust-remote-code",
action="store_true",
dest="trust_remote_code",
default=None,
help="Allow all models to execute custom code",
)
trust_remote_code_group.add_argument(
"--never-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Deny all models from execute custom code",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
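
The mutually exclusive pair added above implements a tri-state override: no flag means auto, one flag forces the behaviour on, the other forces it off. A minimal, standalone sketch of that argparse pattern, assuming nothing beyond the flags shown in the diff:

```python
import argparse

# Standalone sketch of the tri-state override: no flag -> None (auto),
# --trust-remote-code -> True, --never-trust-remote-code -> False.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--trust-remote-code",
    action="store_true",
    dest="trust_remote_code",
    default=None,
    help="Allow all models to execute custom code",
)
group.add_argument(
    "--never-trust-remote-code",
    action="store_false",
    dest="trust_remote_code",
    help="Deny all models from executing custom code",
)

assert parser.parse_args([]).trust_remote_code is None
assert parser.parse_args(["--trust-remote-code"]).trust_remote_code is True
assert parser.parse_args(["--never-trust-remote-code"]).trust_remote_code is False
```

Because the first action sets `default=None` on the shared `dest`, omitting both flags leaves the value as `None`; later in this diff, `load_mlx_items` falls back to the model card's `trust_remote_code` value in that case.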

View File

@@ -140,11 +140,10 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
GlobalForwarderEvent,
IndexedEvent,
TracesMerged,
)
@@ -166,20 +165,12 @@ from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.settings import (
ExoSettings,
load_settings,
)
from exo.shared.types.settings import (
save_settings as save_settings_to_file,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
@@ -203,10 +194,9 @@ class API:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
port: int,
global_event_receiver: Receiver[GlobalForwarderEvent],
event_receiver: Receiver[IndexedEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -217,11 +207,9 @@ class API:
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.event_receiver = event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.last_completed_election: int = 0
self.port = port
@@ -261,17 +249,18 @@ class API:
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup = TaskGroup()
def reset(self, new_session_id: SessionId, result_clock: int):
def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
logger.info("Resetting API State")
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
self.event_receiver.close()
self.event_receiver = event_receiver
self._tg.start_soon(self._apply_state)
def unpause(self, result_clock: int):
logger.info("Unpausing API")
@@ -356,8 +345,6 @@ class API:
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
self.app.get("/onboarding")(self.get_onboarding)
self.app.post("/onboarding")(self.complete_onboarding)
self.app.get("/settings")(self.get_settings)
self.app.post("/settings")(self.save_settings)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -1615,7 +1602,7 @@ class API:
finally:
self._event_log.close()
self.command_sender.close()
self.global_event_receiver.close()
self.event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
@@ -1632,38 +1619,31 @@ class API:
)
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
with self.event_receiver as events:
async for i_event in events:
self._event_log.append(i_event.event)
self.state = apply(self.state, i_event)
event = i_event.event
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(event.command_id, None)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
def _save_merged_trace(self, event: TracesMerged) -> None:
traces = [
@@ -1834,13 +1814,3 @@ class API:
ONBOARDING_COMPLETE_FILE.parent.mkdir(parents=True, exist_ok=True)
ONBOARDING_COMPLETE_FILE.write_text("true")
return JSONResponse({"completed": True})
async def get_settings(self) -> JSONResponse:
settings = load_settings()
return JSONResponse(settings.model_dump())
async def save_settings(self, request: Request) -> JSONResponse:
body = cast(object, await request.json())
settings = ExoSettings.model_validate(body)
save_settings_to_file(settings)
return JSONResponse(settings.model_dump())

View File

@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import MultiSourceBuffer
from exo.utils.task_group import TaskGroup
@@ -72,25 +72,21 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
event_sender: Sender[Event],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = TaskGroup()
self.node_id = node_id
self.session_id = session_id
self.state = State()
self._tg: TaskGroup = TaskGroup()
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self.event_sender = event_sender
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
@@ -104,15 +100,12 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -409,22 +402,6 @@ class Master:
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly

View File

@@ -17,6 +17,7 @@ from exo.shared.types.commands import (
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
GlobalForwarderEvent,
IndexedEvent,
InstanceCreated,
@@ -50,6 +51,22 @@ async def test_master():
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
ev_send, ev_recv = channel[Event]()
async def mock_event_router():
idx = 0
sid = SystemId()
with ev_recv as master_events:
async for event in master_events:
await local_event_sender.send(
LocalForwarderEvent(
origin=sid,
origin_idx=idx,
session=session_id,
event=event,
)
)
idx += 1
all_events: list[IndexedEvent] = []
@@ -67,6 +84,7 @@ async def test_master():
master = Master(
node_id,
session_id,
event_sender=ev_send,
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
@@ -75,6 +93,7 @@ async def test_master():
logger.info("run the master")
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
tg.start_soon(mock_event_router)
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
@@ -197,4 +216,5 @@ async def test_master():
input=[InputMessage(role="user", content="Hello, how are you?")],
)
ev_send.close()
await master.shutdown()

View File

@@ -1,6 +1,4 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo_pyo3_bindings import PyFromSwarm
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
connected: bool
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
return cls(node_id=NodeId(update.peer_id), connected=update.connected)

View File

@@ -0,0 +1,161 @@
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import BrokenResourceError, ClosedResourceError
from anyio.abc import CancelScope
from loguru import logger
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
LocalForwarderEvent,
)
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
@dataclass
class EventRouter:
session_id: SessionId
command_sender: Sender[ForwarderCommand]
external_inbound: Receiver[GlobalForwarderEvent]
external_outbound: Sender[LocalForwarderEvent]
_system_id: SystemId = field(init=False, default_factory=SystemId)
internal_outbound: list[Sender[IndexedEvent]] = field(
init=False, default_factory=list
)
event_buffer: OrderedBuffer[Event] = field(
init=False, default_factory=OrderedBuffer
)
out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
init=False, default_factory=dict
)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_nack_cancel_scope: CancelScope | None = field(init=False, default=None)
_nack_attempts: int = field(init=False, default=0)
_nack_base_seconds: float = field(init=False, default=0.5)
_nack_cap_seconds: float = field(init=False, default=10.0)
async def run(self):
try:
async with self._tg as tg:
tg.start_soon(self._run_ext_in)
tg.start_soon(self._simple_retry)
finally:
self.external_outbound.close()
for send in self.internal_outbound:
send.close()
# can make this better in future
async def _simple_retry(self):
while True:
await anyio.sleep(1 + random())
# list() takes a shallow copy so the dict can be mutated while we iterate
for e_id, (time, event) in list(self.out_for_delivery.items()):
if anyio.current_time() > time + 5:
self.out_for_delivery[e_id] = (anyio.current_time(), event)
await self.external_outbound.send(event)
def sender(self) -> Sender[Event]:
send, recv = channel[Event]()
if self._tg.is_running():
self._tg.start_soon(self._ingest, SystemId(), recv)
else:
self._tg.queue(self._ingest, SystemId(), recv)
return send
def receiver(self) -> Receiver[IndexedEvent]:
send, recv = channel[IndexedEvent]()
self.internal_outbound.append(send)
return recv
def shutdown(self) -> None:
self._tg.cancel_tasks()
async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
idx = 0
with recv as events:
async for event in events:
f_ev = LocalForwarderEvent(
origin_idx=idx,
origin=system_id,
session=self.session_id,
event=event,
)
idx += 1
await self.external_outbound.send(f_ev)
self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
async def _run_ext_in(self):
buf = OrderedBuffer[Event]()
with self.external_inbound as events:
async for event in events:
if event.session != self.session_id:
continue
if event.origin != self.session_id.master_node_id:
continue
buf.ingest(event.origin_idx, event.event)
event_id = event.event.event_id
if event_id in self.out_for_delivery:
self.out_for_delivery.pop(event_id)
drained = buf.drain_indexed()
if drained:
self._nack_attempts = 0
if self._nack_cancel_scope:
self._nack_cancel_scope.cancel()
if not drained and (
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
# Request the next index.
self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
continue
for idx, event in drained:
to_clear = set[int]()
for i, sender in enumerate(self.internal_outbound):
try:
await sender.send(IndexedEvent(idx=idx, event=event))
except (ClosedResourceError, BrokenResourceError):
to_clear.add(i)
for i in sorted(to_clear, reverse=True):
self.internal_outbound.pop(i)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receive an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
delay = min(self._nack_cap_seconds, delay)
self._nack_attempts += 1
try:
await anyio.sleep(delay)
logger.info(
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None
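
Two timing rules govern redelivery in the new EventRouter: `_simple_retry` wakes every one to two seconds and resends anything that has sat unacknowledged for more than five seconds, while `_nack_request` backs off exponentially from 0.5s up to a 10s cap between requests for the missing portion of the event log. A standalone sketch of that backoff schedule (the `nack_delay` helper is illustrative, not part of the codebase):

```python
# Values mirror the defaults in the diff (_nack_base_seconds=0.5, _nack_cap_seconds=10.0);
# the nack_delay helper itself is illustrative and not part of the codebase.
def nack_delay(attempt: int, base: float = 0.5, cap: float = 10.0) -> float:
    """Delay before the given nack attempt (0-indexed): doubles each time, capped."""
    return min(cap, base * (2.0**attempt))


schedule = [nack_delay(a) for a in range(7)]
# 0.5s, 1s, 2s, 4s, 8s, then clamped at the 10s cap.
assert schedule == [0.5, 1.0, 2.0, 4.0, 8.0, 10.0, 10.0]
```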

View File

@@ -17,6 +17,7 @@ from exo_pyo3_bindings import (
MessageTooLargeError,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyFromSwarm,
)
from filelock import FileLock
from loguru import logger
@@ -121,7 +122,8 @@ class Router:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
if self._tg.is_running():
await self._networking_subscribe(topic.topic)
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
router = self.topic_routers.get(topic.topic, None)
@@ -152,8 +154,10 @@ class Router:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# subscribe to pending topics
for topic in self.topic_routers:
await self._networking_subscribe(topic)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
@@ -176,41 +180,40 @@ class Router:
async def _networking_recv(self):
try:
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(
f"Received message on unknown or inactive topic {topic}"
)
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
from_swarm = await self._net.recv()
logger.debug(from_swarm)
match from_swarm:
case PyFromSwarm.Message(origin, topic, data):
logger.trace(
f"Received message on {topic} from {origin} with payload {data}"
)
if topic not in self.topic_routers:
logger.warning(
f"Received message on unknown or inactive topic {topic}"
)
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
case PyFromSwarm.Connection():
message = ConnectionMessage.from_update(from_swarm)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
case _:
logger.critical(
"failed to exhaustively check FromSwarm messages - logic error"
)
except Exception as exception:
logger.opt(exception=exception).error(
"Gossipsub receive loop terminated unexpectedly"
)
raise
async def _networking_recv_connection_messages(self):
try:
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
except Exception as exception:
logger.opt(exception=exception).error(
"Connection update receive loop terminated unexpectedly"
)
raise
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:

View File

@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
tg.start_soon(election.run)
# Send any connection message object; we close quickly to cancel before result creation
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
)
)
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
# Expect a broadcast for the new round at clock=1
while True:

View File

@@ -1,11 +1,6 @@
import ctypes
import sys
from math import ceil
from typing import Self, overload
import psutil
from exo.shared.logging import logger
from exo.utils.pydantic_ext import FrozenModel
@@ -154,67 +149,3 @@ class Memory(FrozenModel):
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
def _load_memory_settings() -> tuple[float, "Memory"]:
"""Load memory threshold and floor from settings (lazy import to avoid circular dep)."""
from exo.shared.types.settings import load_settings
s = load_settings()
return s.memory.memory_threshold, Memory.from_gb(s.memory.memory_floor_gb)
_libc: ctypes.CDLL | None = None
def _macos_memorystatus_level() -> int:
global _libc # noqa: PLW0603
if _libc is None:
_libc = ctypes.CDLL("/usr/lib/libSystem.B.dylib")
level = ctypes.c_int(0)
size = ctypes.c_size_t(ctypes.sizeof(ctypes.c_int))
ret: int = _libc.sysctlbyname( # pyright: ignore[reportAny]
b"kern.memorystatus_level",
ctypes.byref(level),
ctypes.byref(size),
None,
ctypes.c_size_t(0),
)
if ret != 0:
raise OSError("sysctlbyname kern.memorystatus_level failed")
return level.value
def _get_macos_memory_pressure() -> float:
try:
return 1.0 - _macos_memorystatus_level() / 100.0
except (OSError, FileNotFoundError):
logger.warning("Using fallback memory pressure")
return _fallback_memory_pressure()
def _fallback_memory_pressure() -> float:
vm = psutil.virtual_memory()
return 1.0 - vm.available / vm.total
def get_memory_pressure() -> float:
if sys.platform == "darwin":
return _get_macos_memory_pressure()
return _fallback_memory_pressure()
def get_memory_limit() -> Memory:
threshold, floor = _load_memory_settings()
total = psutil.virtual_memory().total
safety = min(int(total * (1 - threshold)), floor.in_bytes)
return Memory.from_bytes(total - safety)
def get_memory_available_locally() -> Memory:
total = Memory.from_bytes(psutil.virtual_memory().total)
return get_memory_limit() - total * get_memory_pressure()
def get_memory_pressure_threshold() -> float:
total = psutil.virtual_memory().total
return get_memory_limit().in_bytes / total

View File

@@ -2,6 +2,8 @@
from collections.abc import Sequence
from mlx import core as mx
from mlx import nn as nn
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]
# Model is a wrapper class to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example, MLX has no guarantee of the interface that nn.Module will expose, but we need a guarantee that it has a __call__() function.
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: KVCacheType | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -1,121 +0,0 @@
import os
import tomllib
from typing import Literal
import psutil
from pydantic import ConfigDict, Field, ValidationError
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.logging import logger
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
class MemorySettings(CamelCaseModel):
model_config = ConfigDict(
alias_generator=None,
validate_by_name=True,
extra="forbid",
strict=False,
)
oom_prevention: bool = False
memory_threshold: float = Field(default_factory=_default_memory_threshold, ge=0.0, le=1.0)
memory_floor_gb: float = Field(default=5.0, ge=0.0)
class GenerationSettings(CamelCaseModel):
model_config = ConfigDict(
alias_generator=None,
validate_by_name=True,
extra="forbid",
strict=False,
)
prefill_step_size: int = Field(default=4096, ge=1)
max_tokens: int = Field(default=32168, ge=1)
kv_cache_bits: Literal[4, 8] | None = None
class ExoSettings(CamelCaseModel):
model_config = ConfigDict(
alias_generator=None,
validate_by_name=True,
extra="ignore",
strict=False,
)
memory: MemorySettings = Field(default_factory=MemorySettings)
generation: GenerationSettings = Field(default_factory=GenerationSettings)
_cached_settings: ExoSettings | None = None
_cached_mtime: float = 0.0
def load_settings() -> ExoSettings:
global _cached_settings, _cached_mtime # noqa: PLW0603
try:
mtime = EXO_CONFIG_FILE.stat().st_mtime
if _cached_settings is not None and mtime == _cached_mtime:
return _cached_settings
with open(EXO_CONFIG_FILE, "rb") as f:
data = tomllib.load(f)
settings = ExoSettings.model_validate(data)
_cached_mtime = mtime
except FileNotFoundError:
settings = ExoSettings()
except (tomllib.TOMLDecodeError, ValidationError) as e:
logger.warning(f"Invalid config file {EXO_CONFIG_FILE}: {e}")
settings = ExoSettings()
# Env vars override config file for backward compat.
env_threshold = os.environ.get("EXO_MEMORY_THRESHOLD")
if env_threshold is not None:
settings = settings.model_copy(
update={"memory": settings.memory.model_copy(update={"memory_threshold": float(env_threshold)})}
)
env_floor = os.environ.get("EXO_MEMORY_FLOOR")
if env_floor is not None:
settings = settings.model_copy(
update={"memory": settings.memory.model_copy(update={"memory_floor_gb": float(env_floor)})}
)
_cached_settings = settings
return settings
def save_settings(settings: ExoSettings) -> None:
global _cached_settings, _cached_mtime # noqa: PLW0603
EXO_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
lines = [
"[memory]",
f"oom_prevention = {'true' if settings.memory.oom_prevention else 'false'}",
f"memory_threshold = {settings.memory.memory_threshold}",
f"memory_floor_gb = {settings.memory.memory_floor_gb}",
"",
"[generation]",
f"prefill_step_size = {settings.generation.prefill_step_size}",
f"max_tokens = {settings.generation.max_tokens}",
]
if settings.generation.kv_cache_bits is not None:
lines.append(f"kv_cache_bits = {settings.generation.kv_cache_bits}")
EXO_CONFIG_FILE.write_text("\n".join(lines) + "\n")
_cached_settings = settings
_cached_mtime = EXO_CONFIG_FILE.stat().st_mtime

View File

@@ -12,7 +12,7 @@ from anyio import fail_after, open_process, to_thread
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from pydantic import ConfigDict, ValidationError
from pydantic import ValidationError
from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR
from exo.shared.types.memory import Memory
@@ -295,8 +295,6 @@ class ThunderboltBridgeInfo(TaggedModel):
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
model_config = ConfigDict(extra="ignore")
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)

View File

@@ -1,17 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -49,6 +49,21 @@ TimeoutCallback = Callable[[], None]
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
def flush_prefill_sends() -> None:
for output, dst, group in _pending_prefill_sends:
sent = mx.distributed.send(output, dst, group=group)
mx.async_eval(sent)
_pending_prefill_sends.clear()
def clear_prefill_sends() -> None:
# Discard pending sends (e.g. on cancellation).
_pending_prefill_sends.clear()
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
@@ -150,6 +165,7 @@ class PipelineLastLayer(CustomMlxLayer):
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
self.is_prefill: bool = False
self.queue_sends: bool = False
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
@@ -163,9 +179,14 @@ class PipelineLastLayer(CustomMlxLayer):
mx.eval(output)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if self.queue_sends:
_pending_prefill_sends.append(
(output, (self.r + 1) % self.s, self.group)
)
else:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
@@ -190,6 +211,12 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
layer.is_prefill = is_prefill
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
for layer in model.layers: # type: ignore
if isinstance(layer, PipelineLastLayer):
layer.queue_sends = queue_sends
def get_inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
if isinstance(inner, nn.Module):
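
The new `queue_sends` path defers the last pipeline layer's cross-rank send into the module-level `_pending_prefill_sends` list, and the prefill driver later flushes the queue (or clears it on cancellation). A minimal sketch of that defer-then-flush pattern, with plain callables standing in for `mx.distributed.send` / `mx.async_eval` and illustrative names throughout:

```python
from typing import Callable

# Plain callables stand in for mx.distributed.send / mx.async_eval; all names here are
# illustrative only.
_pending_sends: list[Callable[[], None]] = []


def queue_send(send: Callable[[], None]) -> None:
    # During prefill the last pipeline layer defers its cross-rank send instead of
    # issuing it inline.
    _pending_sends.append(send)


def flush_sends() -> None:
    # The prefill driver releases every deferred send once the chunk has been computed.
    for send in _pending_sends:
        send()
    _pending_sends.clear()


def clear_sends() -> None:
    # On cancellation the queued sends are discarded without ever being issued.
    _pending_sends.clear()


sent: list[str] = []
queue_send(lambda: sent.append("chunk 0 -> next rank"))
queue_send(lambda: sent.append("chunk 1 -> next rank"))
flush_sends()
assert sent == ["chunk 0 -> next rank", "chunk 1 -> next rank"]
assert not _pending_sends
```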

View File

@@ -1,6 +1,8 @@
import os
from copy import deepcopy
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
@@ -10,14 +12,30 @@ from mlx_lm.models.cache import (
)
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory, get_memory_pressure
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.settings import load_settings
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType, Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
class CacheSnapshot:
"""Snapshot of states at a known token position."""
@@ -73,15 +91,6 @@ class KVPrefixCache:
self._snapshots.clear()
self._last_used.clear()
def force_evict_all(self) -> int:
count = len(self.caches)
self.clear()
if count > 0:
logger.info(
f"Force-evicted all {count} prefix cache entries due to memory pressure"
)
return count
def add_kv_cache(
self,
prompt_tokens: mx.array,
@@ -207,7 +216,7 @@ class KVPrefixCache:
# Evict LRU entries until below threshold
while (
len(self.caches) > 0
and self.get_memory_used_percentage() > load_settings().memory.memory_threshold
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
@@ -220,7 +229,7 @@ class KVPrefixCache:
)
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_pressure()
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
@@ -244,9 +253,9 @@ def trim_cache(
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
c.state = [None] * len(c.state)
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
c.trim(num_tokens)
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -289,47 +298,15 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
def _measure_single_cache_bytes(
entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
if isinstance(entry, CacheList):
return sum(
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
for c in entry.caches
)
total = 0
if isinstance(entry, ArraysCache):
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
for arr in state: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
total = 0
for attr_name in ("keys", "values"):
val: object = getattr(entry, attr_name, None)
if val is None:
continue
if isinstance(val, mx.array):
total += val.nbytes
elif isinstance(val, (tuple, list)):
for arr in val: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def measure_cache_bytes(cache: KVCacheType) -> int:
return sum(_measure_single_cache_bytes(c) for c in cache)
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> Memory:
offset = cache_length(cache)
if offset == 0:
return Memory.from_bytes(0)
return Memory.from_bytes(measure_cache_bytes(cache) // offset)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
@@ -342,14 +319,13 @@ def make_kv_cache(
return model.make_cache() # type: ignore
if max_kv_size is None:
kv_cache_bits = load_settings().generation.kv_cache_bits
if kv_cache_bits is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=kv_cache_bits)
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
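
The eviction threshold above is resolved from an optional `EXO_MEMORY_THRESHOLD` environment override, falling back to a tier keyed on total RAM. A small sketch of that resolution, assuming only what the diff shows (the `effective_memory_threshold` helper is ours, not part of the codebase):

```python
import os
from collections.abc import Mapping


# Sketch of the threshold resolution; mirrors _default_memory_threshold and the
# EXO_MEMORY_THRESHOLD override shown in the diff, but the helper name is ours.
def effective_memory_threshold(total_gb: float, env: Mapping[str, str] | None = None) -> float:
    env = os.environ if env is None else env
    override = env.get("EXO_MEMORY_THRESHOLD")
    if override is not None:
        return float(override)
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
        return 0.80
    if total_gb >= 32:
        return 0.75
    return 0.70


assert effective_memory_threshold(192, env={}) == 0.85
assert effective_memory_threshold(36, env={}) == 0.75
assert effective_memory_threshold(16, env={"EXO_MEMORY_THRESHOLD": "0.9"}) == 0.9
```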

View File

@@ -1,10 +1,14 @@
import functools
import math
import time
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.generate import (
maybe_quantize_kv_cache,
stream_generate,
)
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -18,33 +22,37 @@ from exo.shared.types.api import (
Usage,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory, get_memory_available_locally
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.settings import load_settings
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType, Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
clear_prefill_sends,
flush_prefill_sends,
set_pipeline_prefill,
set_pipeline_queue_sends,
)
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
encode_prompt,
has_non_kv_caches,
make_kv_cache,
measure_kv_cache_bytes_per_token,
snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
DEFAULT_TOP_LOGPROBS,
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
fix_unmatched_think_end_tokens,
mx_any,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -58,6 +66,130 @@ class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def _has_pipeline_communication_layer(model: Model):
for layer in model.layers:
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
return True
return False
def pipeline_parallel_prefill(
model: Model,
prompt: mx.array,
prompt_cache: KVCacheType,
prefill_step_size: int,
kv_group_size: int | None,
kv_bits: int | None,
prompt_progress_callback: Callable[[int, int], None],
distributed_prompt_progress_callback: Callable[[], None] | None,
group: mx.distributed.Group,
) -> None:
"""Prefill the KV cache for pipeline parallel with overlapping stages.
Each rank processes the full prompt through its real cache, offset by leading
and trailing dummy iterations.
Total iterations per rank = N_real_chunks + world_size - 1:
- rank r leading dummies (skip_pipeline_io, throwaway cache)
- N_real_chunks real (pipeline IO active, real cache)
- (world_size-1-r) trailing dummies (skip_pipeline_io, throwaway cache)
e.g.
Timeline (2 ranks, a 10240-token prompt split into 3 chunks @ step=4096):
iter 0: R0 real[0:4096] R1 dummy
iter 1: R0 real[4096:8192] R1 real[0:4096]
iter 2: R0 real[8192:10240] R1 real[4096:8192]
iter 3: R0 dummy R1 real[8192:10240]
This function is designed to match mlx_lm's stream_generate exactly in terms of
side effects (given the same prefill step size)
"""
prefill_step_size = prefill_step_size // min(4, group.size())
quantize_cache_fn: Callable[..., None] = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=0,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
_prompt_cache: KVCacheType = prompt_cache
rank = group.rank()
world_size = group.size()
# Build list of real prompt chunk sizes
total = len(prompt)
real_chunk_sizes: list[int] = []
remaining = total - 1
while remaining:
n = min(prefill_step_size, remaining)
real_chunk_sizes.append(n)
remaining -= n
n_real = len(real_chunk_sizes)
# Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]
n_leading = rank
n_trailing = world_size - 1 - rank
n_total = n_leading + n_real + n_trailing
t_start = time.perf_counter()
processed = 0
logger.info(
f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
)
clear_prefill_sends()
# Initial callback matching generate_step
prompt_progress_callback(0, total)
try:
with mx.stream(generation_stream):
for _ in range(n_leading):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
for i in range(n_real):
chunk_size = real_chunk_sizes[i]
model(
prompt[processed : processed + chunk_size][None],
cache=_prompt_cache,
)
quantize_cache_fn(_prompt_cache)
processed += chunk_size
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
flush_prefill_sends()
prompt_progress_callback(processed, total)
for _ in range(n_trailing):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
finally:
clear_prefill_sends()
# Post-loop: process the remaining 1 token and add a +1 entry to match stream_generate.
for _ in range(2):
with mx.stream(generation_stream):
model(prompt[-1:][None], cache=_prompt_cache)
quantize_cache_fn(_prompt_cache)
flush_prefill_sends()
assert _prompt_cache is not None
mx.eval([c.state for c in _prompt_cache]) # type: ignore
# Final callback matching generate_step
prompt_progress_callback(total, total)
logger.info(
f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
@@ -66,6 +198,7 @@ def prefill(
cache: KVCacheType,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
distributed_prompt_progress_callback: Callable[[], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -97,31 +230,57 @@ def prefill(
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
def combined_progress_callback(processed: int, total: int) -> None:
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
progress_callback(processed, total)
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
is_pipeline = _has_pipeline_communication_layer(model)
prefill_step_size = 4096
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=load_settings().generation.prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
if is_pipeline and num_tokens >= prefill_step_size:
set_pipeline_queue_sends(model, queue_sends=True)
assert group is not None, "Pipeline prefill requires a distributed group"
pipeline_parallel_prefill(
model=model,
prompt=prompt_tokens,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)
else:
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=combined_progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
raise
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
# stream_generate added 1 extra generated token to the cache, so we should trim it.
@@ -134,7 +293,7 @@ def prefill(
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
c.trim(2)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -150,8 +309,7 @@ def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None,
) -> tuple[int, Memory]:
"""Run warmup inference and tokens_generated and bytes_per_token"""
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
warmup_prompt = apply_chat_template(
@@ -190,12 +348,9 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
logger.info(f"Measured KV cache cost: {bytes_per_token} per token")
mx_barrier(group)
return tokens_generated, bytes_per_token
return tokens_generated
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
@@ -273,33 +428,6 @@ def extract_top_logprobs(
return selected_logprob, top_logprob_items
def _check_memory_budget(
bytes_per_token: Memory,
total_sequence_tokens: int,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
) -> str | None:
if bytes_per_token.in_bytes == 0:
return None
estimated = bytes_per_token * total_sequence_tokens
over_budget = estimated > get_memory_available_locally()
if not mx_any(over_budget, group):
return None
if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
mx.clear_cache()
over_budget = estimated > get_memory_available_locally()
if not mx_any(over_budget, group):
return None
return (
"Not enough memory for this conversation. "
"Please start a new conversation or compact your messages."
)
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -308,10 +436,8 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
bytes_per_token: Memory | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
if bytes_per_token is None:
bytes_per_token = Memory()
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
@@ -343,23 +469,6 @@ def mlx_generate(
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
)
if bytes_per_token.in_bytes > 0 and load_settings().memory.oom_prevention:
oom_error = _check_memory_budget(
bytes_per_token=bytes_per_token,
total_sequence_tokens=len(all_prompt_tokens),
kv_prefix_cache=kv_prefix_cache,
group=group,
)
if oom_error is not None:
logger.warning(f"OOM prevention (prefill): {oom_error}")
yield GenerationResponse(
text=oom_error,
token=0,
finish_reason="error",
usage=None,
)
return
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -389,13 +498,14 @@ def mlx_generate(
caches,
group,
on_prefill_progress,
distributed_prompt_progress_callback,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
# stream_generate starts from the last token
last_token = prompt_tokens[-2:]
max_tokens = task.max_output_tokens or load_settings().generation.max_tokens
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()

View File

@@ -2,6 +2,7 @@ import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, cast
@@ -39,6 +40,7 @@ from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -51,7 +53,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
LayerLoadedCallback,
TimeoutCallback,
@@ -98,14 +99,13 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
coordination_file = None
try:
with tempfile.TemporaryDirectory() as tmpdir:
coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,9 +128,6 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
@@ -150,10 +147,6 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(
@@ -174,10 +167,12 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
trust_remote_code: bool | None,
) -> tuple[Model, TokenizerWrapper]:
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
@@ -196,12 +191,10 @@ def load_mlx_items(
mx.eval(model)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
model = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
@@ -212,6 +205,14 @@ def load_mlx_items(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
tokenizer = load_tokenizer_for_model_id(
bound_instance.bound_shard.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code
if trust_remote_code is not None
else bound_instance.bound_shard.model_card.trust_remote_code,
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
@@ -224,9 +225,8 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> tuple[nn.Module, TokenizerWrapper]:
) -> nn.Module:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
@@ -248,8 +248,6 @@ def shard_and_load(
assert isinstance(model, nn.Module)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
@@ -288,16 +286,7 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
return model
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:

View File

@@ -1,9 +1,9 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from random import random
import anyio
from anyio import CancelScope, fail_after
from anyio import fail_after
from loguru import logger
from exo.download.download_utils import resolve_model_in_path
@@ -13,17 +13,13 @@ from exo.shared.types.api import ImageEditsTaskParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
LocalForwarderEvent,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
@@ -46,56 +42,39 @@ from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_opts import RunnerOpts
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@dataclass
class Worker:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
global_event_receiver: Receiver[GlobalForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
node_id: NodeId
runner_opts: RunnerOpts
event_receiver: Receiver[IndexedEvent]
event_sender: Sender[Event]
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand]
download_command_sender: Sender[ForwarderDownloadCommand]
state: State = field(init=False, default_factory=State)
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_system_id: SystemId = field(init=False, default_factory=SystemId)
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
# Buffer for input image chunks (for image editing)
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
init=False, default_factory=dict
)
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
_download_backoff: KeyedBackoff[ModelId] = field(
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
)
async def run(self):
logger.info("Starting Worker")
@@ -108,14 +87,12 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
@@ -133,47 +110,22 @@ class Worker:
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
if event_id in self.out_for_delivery:
del self.out_for_delivery[event_id]
with self.event_receiver as events:
async for event in events:
# 2. for each event, apply it to the state
indexed_events = self.event_buffer.drain_indexed()
if indexed_events:
self._nack_attempts = 0
self.state = apply(self.state, event=event)
event = event.event
if not indexed_events and (
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
continue
elif indexed_events and self._nack_cancel_scope:
self._nack_cancel_scope.cancel()
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
@@ -325,46 +277,10 @@ class Worker:
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receive an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
delay = min(self._nack_cap_seconds, delay)
self._nack_attempts += 1
try:
await anyio.sleep(delay)
logger.info(
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened; we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
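For reference, both the removed nack timer above and the retained KeyedBackoff download backoff use a capped doubling schedule with base 0.5 s and cap 10 s. A standalone sketch of that arithmetic, assuming KeyedBackoff computes its delays the same way the removed code did:

```python
def backoff_delay(attempt: int, base: float = 0.5, cap: float = 10.0) -> float:
    """Capped exponential backoff, as in the removed nack logic above."""
    return min(cap, base * (2.0 ** attempt))

# attempt: 0     1    2    3    4    5     6
# delay:   0.5   1.0  2.0  4.0  8.0  10.0  10.0  (seconds)
```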
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
runner_opts=self.runner_opts,
bound_instance=task.bound_instance,
event_sender=self.event_sender.clone(),
)
@@ -372,21 +288,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
async def _poll_connection_updates(self):
while True:
edges = set(

View File

@@ -1,4 +1,5 @@
import os
import resource
import loguru
@@ -8,10 +9,13 @@ from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from .runner_opts import RunnerOpts
logger: "loguru.Logger" = loguru.logger
def entrypoint(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
@@ -20,12 +24,17 @@ def entrypoint(
) -> None:
global logger
logger = _logger
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
fast_synch_override = runner_opts.fast_synch_override
if fast_synch_override is not None:
if fast_synch_override:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
@@ -36,7 +45,7 @@ def entrypoint(
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -1,5 +1,4 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
@@ -66,6 +65,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
@@ -183,14 +183,12 @@ def _send_image_chunk(
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,

View File

@@ -1,5 +1,4 @@
import math
import resource
import time
from collections.abc import Generator
from functools import cache
@@ -31,12 +30,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.memory import (
Memory,
get_memory_pressure,
get_memory_pressure_threshold,
)
from exo.shared.types.settings import load_settings
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -69,7 +63,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
@@ -85,19 +78,18 @@ from exo.worker.engines.mlx.utils_mlx import (
mx_any,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
from .tool_parsers import ToolParser, make_mlx_parser
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
@@ -120,7 +112,6 @@ def main(
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
bytes_per_token = Memory.from_bytes(0)
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -201,6 +192,7 @@ def main(
group,
on_timeout=on_model_load_timeout,
on_layer_loaded=on_layer_loaded,
trust_remote_code=runner_opts.trust_remote_code_override,
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
@@ -232,14 +224,12 @@ def main(
assert tokenizer
t = time.monotonic()
toks, bytes_per_token = warmup_inference(
toks = warmup_inference(
model=cast(Model, inference_model),
tokenizer=tokenizer,
group=group,
)
logger.info(
f"warmed up by generating {toks} tokens, {bytes_per_token}/token for KV cache"
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
@@ -283,8 +273,6 @@ def main(
def on_prefill_progress(
processed: int,
total: int,
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
@@ -297,6 +285,11 @@ def main(
),
)
)
def distributed_prompt_progress_callback(
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
@@ -318,8 +311,8 @@ def main(
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
bytes_per_token=bytes_per_token,
)
if tokenizer.has_thinking:
@@ -346,7 +339,6 @@ def main(
completion_tokens = 0
tokens_since_last_cancel_check = check_for_cancel_every
oom_stopped = False
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -355,15 +347,7 @@ def main(
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
oom_local = (
load_settings().memory.oom_prevention
and bytes_per_token.in_bytes > 0
and get_memory_pressure()
> get_memory_pressure_threshold()
)
if mx_any(want_to_cancel or oom_local, group):
if not want_to_cancel:
oom_stopped = True
if mx_any(want_to_cancel, group):
break
match response:
@@ -419,21 +403,6 @@ def main(
)
)
if oom_stopped and device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=model_id,
error_message=(
"Generation stopped: running out of memory. "
"Please start a new conversation or compact "
"your messages."
),
),
)
)
except PrefillCancelled:
logger.info(f"Prefill cancelled for task {task.task_id}")
# can we make this more explicit?

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class RunnerOpts:
fast_synch_override: bool | None
trust_remote_code_override: bool | None
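The new RunnerOpts dataclass carries per-runner overrides that previously came from environment variables. A minimal sketch of how it is constructed and threaded through, based only on the hunks in this diff; the None values are placeholders, not recommended settings:

```python
from exo.worker.runner.runner_opts import RunnerOpts

# None means "no override": the runner keeps its built-in MLX_METAL_FAST_SYNCH
# default, and trust_remote_code falls back to the model card's setting.
opts = RunnerOpts(
    fast_synch_override=None,
    trust_remote_code_override=None,
)

# RunnerSupervisor.create(runner_opts=opts, ...) forwards the options as the
# first positional argument of entrypoint(), which applies fast_synch_override
# before the inference runner starts; the runner then passes
# trust_remote_code_override into the model/tokenizer loading path.
```

Keeping both fields optional lets callers pass `RunnerOpts(None, None)` (as the patched tests below do) without changing runner behaviour.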

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.runner_opts import RunnerOpts
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -62,6 +63,7 @@ class RunnerSupervisor:
def create(
cls,
*,
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: Sender[Event],
initialize_timeout: float = 400,
@@ -73,6 +75,7 @@ class RunnerSupervisor:
runner_process = mp.Process(
target=entrypoint,
args=(
runner_opts,
bound_instance,
ev_send,
task_recv,

View File

@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load

View File

@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
@@ -143,7 +143,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
# Cache should now hold the prompt tokens minus one
@@ -164,7 +171,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -200,7 +214,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
short_tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -245,7 +266,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -285,7 +313,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -513,7 +548,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -538,7 +582,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)

View File

@@ -0,0 +1,512 @@
# type: ignore
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
then verifies that the prompt_progress_callback sequences are identical
and that generated text matches.
"""
import json
import multiprocessing as mp
import os
import tempfile
import traceback
from typing import Any, cast
import pytest
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
TOTAL_LAYERS = 24
MAX_TOKENS = 10
SEED = 42
TEMPERATURE = 0.0
def _model_card() -> ModelCard:
return ModelCard(
model_id=ModelId(MODEL_ID),
storage_size=Memory.from_gb(12),
n_layers=TOTAL_LAYERS,
hidden_size=2880,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
)
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
"""Build a prompt with the given number of user-content tokens, return (chat_prompt, task)."""
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
base_text = "The quick brown fox jumps over the lazy dog. "
base_toks = tokenizer.encode(base_text)
repeats = (prompt_tokens // len(base_toks)) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
model=MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
seed=SEED,
)
prompt = apply_chat_template(tokenizer, task)
return prompt, task
# ---------------------------------------------------------------------------
# Single-device process: uses stream_generate path (no pipeline layers)
# ---------------------------------------------------------------------------
def _run_single_device(
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load full model without pipeline sharding, run mlx_generate, record callbacks."""
try:
import mlx.core as mx
from mlx_lm.utils import load_model
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
build_model_path,
get_tokenizer,
)
model_path = build_model_path(ModelId(MODEL_ID))
model, _ = load_model(model_path, lazy=True, strict=False)
mx.eval(model)
# Use PipelineShardMetadata just for get_tokenizer (needs model_card), but
# do NOT apply pipeline sharding — the model keeps all layers unwrapped.
dummy_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=TOTAL_LAYERS,
n_layers=TOTAL_LAYERS,
)
tokenizer = get_tokenizer(model_path, dummy_meta)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=None,
on_prefill_progress=on_progress,
):
generated_text += response.text
if response.finish_reason is not None:
break
# Also record the token count that prefill() received (prompt_tokens[:-1])
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Pipeline device process: uses _pipeline_prefill_cache path
# ---------------------------------------------------------------------------
def _run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load model with pipeline sharding, run mlx_generate, record callbacks."""
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
try:
import mlx.core as mx
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
group = mx.distributed.init(backend="ring", strict=True)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=TOTAL_LAYERS,
)
model, tokenizer = shard_and_load(
shard_meta, group, on_timeout=None, on_layer_loaded=None
)
model = cast(Any, model)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
def distributed_prompt_progress_callback(_group: Any = group) -> None:
from exo.worker.engines.mlx.utils_mlx import mx_any
mx_any(False, _group)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
on_prefill_progress=on_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
):
generated_text += response.text
if response.finish_reason is not None:
break
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
rank,
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _create_hostfile(world_size: int, base_port: int) -> str:
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
return f.name
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
"""Run single-device (stream_generate) prefill and return results."""
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.terminate()
p.join(timeout=5)
pytest.fail("Single-device process timed out")
assert not result_queue.empty(), "Single-device process produced no result"
success, data = result_queue.get()
assert success, f"Single-device process failed:\n{data}"
return data
def _run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
base_port: int,
timeout: int = 120,
) -> dict[int, dict[str, Any]]:
"""Run pipeline prefill across ranks and return per-rank results."""
world_size = len(layer_splits)
hostfile_path = _create_hostfile(world_size, base_port)
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
try:
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=_run_pipeline_device,
args=(
rank,
world_size,
hostfile_path,
layer_splits,
prompt_tokens,
result_queue,
),
)
p.start()
processes.append(p)
for p in processes:
p.join(timeout=timeout)
timed_out = any(p.is_alive() for p in processes)
for p in processes:
if p.is_alive():
p.terminate()
p.join(timeout=5)
assert not timed_out, "Pipeline processes timed out"
results: dict[int, dict[str, Any]] = {}
while not result_queue.empty():
rank, success, data = result_queue.get()
assert success, f"Pipeline rank {rank} failed:\n{data}"
results[rank] = data
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
)
return results
finally:
os.unlink(hostfile_path)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not MODEL_PATH.exists(),
reason=f"GPT-OSS model not found at {MODEL_PATH}",
),
]
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
class TestPipelineNoDeadlock:
"""Pipeline prefill must not deadlock at any rank count or prompt length."""
@pytest.mark.parametrize(
"layer_splits,prompt_tokens",
[
(LAYER_SPLITS_2WAY, 128),
(LAYER_SPLITS_2WAY, 4096),
(LAYER_SPLITS_2WAY, 8192),
(LAYER_SPLITS_2WAY, 16384),
(LAYER_SPLITS_4WAY, 128),
(LAYER_SPLITS_4WAY, 4096),
(LAYER_SPLITS_4WAY, 8192),
(LAYER_SPLITS_4WAY, 16384),
],
ids=[
"2rank_128tok",
"2rank_4096tok",
"2rank_8192tok",
"2rank_16384tok",
"4rank_128tok",
"4rank_4096tok",
"4rank_8192tok",
"4rank_16384tok",
],
)
def test_no_deadlock(
self,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
) -> None:
"""Pipeline must complete without deadlock at various prompt lengths."""
pipeline_results = _run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=prompt_tokens,
base_port=29650,
timeout=60,
)
# If we get here, no deadlock. Verify all ranks produced output.
for rank, pipe_data in sorted(pipeline_results.items()):
assert pipe_data["text"], f"Rank {rank} produced no output text"
class TestPipelinePrefillCallbacks:
"""Verify that pipeline prefill callbacks exactly match stream_generate callbacks."""
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500, 5000],
ids=["short_50", "medium_500", "long_5000"],
)
def test_callbacks_match(self, prompt_tokens: int) -> None:
"""All pipeline ranks must produce identical callback sequences."""
# Run 4-rank pipeline
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29700,
timeout=180,
)
# All ranks must agree on prefill token count and callback sequence
rank0_data = pipeline_results[0]
rank0_callbacks = rank0_data["callbacks"]
prefill_count = rank0_data["prefill_token_count"]
for rank, pipe_data in sorted(pipeline_results.items()):
pipe_callbacks = pipe_data["callbacks"]
assert pipe_data["prefill_token_count"] == prefill_count, (
f"Rank {rank} prefill token count mismatch: "
f"{pipe_data['prefill_token_count']} vs {prefill_count}"
)
assert pipe_callbacks == rank0_callbacks, (
f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
f"(prefill M={prefill_count}):\n"
f" pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\n"
f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
)
# Structural checks: starts with (0, M), ends with (M, M), monotonically increasing
assert rank0_callbacks[0] == (0, prefill_count), (
f"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}"
)
assert rank0_callbacks[-1] == (prefill_count, prefill_count), (
f"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}"
)
for i in range(1, len(rank0_callbacks)):
assert rank0_callbacks[i][0] >= rank0_callbacks[i - 1][0], (
f"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}"
)
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500],
ids=["short_50", "medium_500"],
)
def test_output_matches(self, prompt_tokens: int) -> None:
"""Pipeline-generated text must match single-device output."""
single = _run_single_device_test(prompt_tokens, timeout=180)
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29800,
timeout=180,
)
single_text = single["text"]
# The last rank produces the final logits, so its output should match.
# Due to SDPA tiling non-determinism, allow minor differences in text.
last_rank = max(pipeline_results.keys())
pipe_text = pipeline_results[last_rank]["text"]
# For deterministic sampling (temp=0.0), outputs should match exactly
# or be very close. Log both for debugging even if they match.
if single_text != pipe_text:
# Find first divergence point
min_len = min(len(single_text), len(pipe_text))
diverge_idx = next(
(i for i in range(min_len) if single_text[i] != pipe_text[i]),
min_len,
)
pytest.fail(
f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
f" single-device: {single_text!r}\n"
f" pipeline R{last_rank}: {pipe_text!r}"
)
class TestPipelineCallbacksStructure:
"""Verify structural properties of callbacks independent of model output."""
def test_callback_structure_matches_generate_step(self) -> None:
"""Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M)."""
prompt_tokens = 200
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29900,
timeout=180,
)
for rank, pipe_data in sorted(pipeline_results.items()):
callbacks = pipe_data["callbacks"]
m = pipe_data["prefill_token_count"]
assert m > 0, f"Rank {rank}: prefill token count is 0"
assert callbacks[0] == (0, m), (
f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
)
assert callbacks[-1] == (m, m), (
f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
)
if len(callbacks) > 2:
second_to_last = callbacks[-2]
assert second_to_last[0] < m, (
f"Rank {rank}: second-to-last callback should report < {m}, "
f"got {second_to_last}"
)
# All callbacks must have total == M
for i, (_, total) in enumerate(callbacks):
assert total == m, (
f"Rank {rank}: callback {i} has total={total}, expected {m}"
)
# processed values must be non-decreasing
processed_vals = [p for p, _ in callbacks]
for i in range(1, len(processed_vals)):
assert processed_vals[i] >= processed_vals[i - 1], (
f"Rank {rank}: callbacks not non-decreasing at index {i}: "
f"{processed_vals}"
)
# No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)
for i in range(1, len(callbacks)):
assert callbacks[i] != callbacks[i - 1], (
f"Rank {rank}: duplicate consecutive callback at index {i}: "
f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
)
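The structural checks above pin down the callback contract: the sequence starts at (0, M), ends at (M, M), increases monotonically, and never repeats. A minimal sketch, not the mlx_lm or exo implementation, that enumerates the expected (processed, total) sequence for a prefill of M tokens; `step` stands in for whatever prefill chunk size the generator uses:

```python
def expected_prefill_callbacks(m: int, step: int) -> list[tuple[int, int]]:
    """Enumerate the (processed, total) reports the tests above assert on."""
    callbacks = [(0, m)]          # always starts at (0, M)
    processed = 0
    while processed + step < m:   # intermediate chunks stay strictly below M
        processed += step
        callbacks.append((processed, m))
    callbacks.append((m, m))      # always ends at (M, M)
    return callbacks

# e.g. expected_prefill_callbacks(200, 64)
#   == [(0, 200), (64, 200), (128, 200), (192, 200), (200, 200)]
```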

View File

@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (

View File

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -41,6 +40,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.runner.runner_opts import RunnerOpts
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -115,9 +115,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(
mlx_runner, "warmup_inference", make_nothin((1, Memory.from_bytes(0)))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -187,6 +185,7 @@ def _run(tasks: Iterable[Task]):
make_nothin(mx.array([1])),
):
mlx_runner.main(
RunnerOpts(None, None),
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,

uv.lock generated
View File

@@ -469,7 +469,7 @@ requires-dist = [
[[package]]
name = "exo-pyo3-bindings"
version = "0.1.0"
version = "0.2.0"
source = { editable = "rust/exo_pyo3_bindings" }
[package.dev-dependencies]