Compare commits

..

3 Commits

Author SHA1 Message Date
Evan
b0da9dd56b runner opts 2026-02-26 17:51:31 +00:00
rltakashige
152a27ea5d Fix pipeline mismatched send after 1587 (#1629)
## Motivation

Tests caught a bug. It was a real bug.
2026-02-26 16:48:34 +00:00
rltakashige
db36bd5ac6 Add custom prefill for pipeline (#1587)
## Motivation

Since we need to do distributed communications between prefill step
sizes, the out-of-the-box stream_generate that we currently use prevents
pipeline parallel models from doing overlapped computation. While this
was technically a regression, this communication is necessary for
cancellation, and we will need various distributed communications in the
future (e.g. for coordinating batching).

500 lines are for one testing file, so the diffs aren't as bad as they
look!

## Changes

Added a special prefill function for pipeline parallel models
Edited the model to handle 
Added a test to verify this new prefill and the original prefill produce
identical results
Improved type stubs to remove some type: ignores 

## Why It Works
<img width="768" height="1246" alt="image"
src="https://github.com/user-attachments/assets/8986ff17-ac23-4a02-9bd7-e6253a0ca799"
/>

## Test Plan

### Manual Testing
Needs more testing, but seems good so far.

### Automated Testing
Passes CI, considerable speedup seen in benchmarks (up to 1.98x) on
prefill speed.

Before:
<img width="3280" height="1238" alt="image"
src="https://github.com/user-attachments/assets/9abc1cbc-ecdb-4e48-a675-2c4cb04a32a0"
/>


After:
<img width="3344" height="1236" alt="image"
src="https://github.com/user-attachments/assets/e03c7987-41b4-4950-9ac3-2840e774ce30"
/>
2026-02-26 16:00:38 +00:00
24 changed files with 1350 additions and 197 deletions

View File

@@ -73,9 +73,11 @@ class GenerationResponse:
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
prompt_cache: Any,
quantized_kv_start: int | None,
kv_group_size: int | None,
kv_bits: int | None,
) -> None: ...
def generate_step(
prompt: mx.array,
model: nn.Module,

View File

@@ -16,7 +16,7 @@ class Cache(Protocol):
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@property
def meta_state(self) -> Literal[""]: ...
@meta_state.setter
def meta_state(self, v) -> None: ...
def trim(self, n: int) -> int: ...
def is_trimmable(self) -> Literal[False]: ...
@classmethod
def from_state(cls, state, meta_state) -> Self: ...
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
@property
def state(
self,
) -> tuple[array, array]: ...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@@ -253,10 +243,9 @@ class CacheList(_BaseCache):
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
def trim(self, n: int) -> int: ...
@property
def state(self): # -> list[Any]:
...
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
@state.setter
def state(self, v): # -> None:
...

View File

@@ -0,0 +1,382 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
use std::io;
use std::time::Duration;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
ping: ping::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
mdns: mdns_behaviour(keypair)?,
ping: ping_behaviour(),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {
ttl: MDNS_RECORD_TTL,
query_interval: MDNS_QUERY_INTERVAL,
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(
ping::Config::new()
.with_timeout(PING_TIMEOUT)
.with_interval(PING_INTERVAL),
)
}
}
/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
ConnectionEstablished {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
ConnectionClosed {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
/// to the swarm.
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
/// immediately, and expired but connected peers are disconnected from immediately.
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
/// connected peers are disconnected from.
pub struct Behaviour {
// state-tracking for managed behaviors & mDNS-discovered peers
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
retry_delay: Delay, // retry interval
// pending events to emmit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
// get peer's multi-addresses or insert if missing
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
continue;
};
// multiaddress should never already be present - else something has gone wrong
let is_new_addr = mas.insert(ma);
assert!(is_new_addr, "cannot discover a discovered peer");
}
}
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
// at this point, we *must* have the peer
let mas = self
.mdns_discovered
.get_mut(&p)
.expect("nonexistent peer cannot expire");
// at this point, we *must* have the multiaddress
let was_present = mas.remove(&ma);
assert!(was_present, "nonexistent multiaddress cannot expire");
// if empty, remove the peer-id entirely
if mas.is_empty() {
self.mdns_discovered.remove(&p);
}
}
}
fn on_connection_established(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out connected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
fn on_connection_closed(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out disconnected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
}
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
delegate! {
to self.managed {
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
}
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)?,
))
}
#[allow(clippy::needless_question_mark)]
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
port_use: PortUse,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
port_use,
)?,
))
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => {
self.managed
.on_connection_handler_event(peer_id, connection_id, ev)
}
}
}
// hook into these methods to drive behavior
fn on_swarm_event(&mut self, event: FromSwarm) {
self.managed.on_swarm_event(event); // let mDNS handle swarm events
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
}
// since we are running TCP/IP transport layer, we are assuming that
// no address changes can occur, hence encountering one is a fatal error
FromSwarm::AddressChange(a) => {
unreachable!("unhandlable: address change encountered: {:?}", a)
}
_ => {}
}
}
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// delegate to managed behaviors for any behaviors they need to perform
match self.managed.poll(cx) {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
},
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
}
}
}
// since we just consumed an event, we should immediately wake just in case
// there are more events to come where that came from
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(
e.map_out(|_| unreachable!("events returning to swarm already handled"))
.map_in(Either::Right),
);
}
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}
// send out any pending events from our own service
if let Some(e) = self.pending_events.pop_front(cx) {
return Poll::Ready(e.map_in(Either::Left));
}
// wait for pending events
Poll::Pending
}
}

View File

@@ -3,6 +3,7 @@
//! this is here as a placeholder documentation
//!
//!
pub mod discovery;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.

View File

@@ -1,11 +1,9 @@
use std::collections::HashSet;
use std::pin::Pin;
use crate::alias;
use crate::swarm::transport::tcp_transport;
use crate::{alias, discovery};
pub use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::{Stream, StreamExt};
use libp2p::mdns;
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::{mpsc, oneshot};
@@ -70,7 +68,7 @@ impl Swarm {
}
event = swarm.next() => {
let Some(event) = event else { break };
for item in filter_swarm_event(event) {
if let Some(item) = filter_swarm_event(event) {
yield item;
}
}
@@ -117,7 +115,7 @@ fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
}
}
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Vec<FromSwarm> {
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
match event {
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message:
@@ -128,28 +126,19 @@ fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Vec<FromSwarm> {
..
},
..
})) => vec![FromSwarm::Message {
})) => Some(FromSwarm::Message {
from: peer_id,
topic: topic.into_string(),
data,
}],
SwarmEvent::Behaviour(BehaviourEvent::Discovery(mdns::Event::Discovered(peer_id))) => {
peer_id
.into_iter()
.map(|(pid, _)| pid)
.collect::<HashSet<PeerId>>()
.into_iter()
.map(|peer_id| FromSwarm::Discovered { peer_id })
.collect()
}
SwarmEvent::Behaviour(BehaviourEvent::Discovery(mdns::Event::Expired(peer_id))) => peer_id
.into_iter()
.map(|(pid, _)| pid)
.collect::<HashSet<PeerId>>()
.into_iter()
.map(|peer_id| FromSwarm::Discovered { peer_id })
.collect(),
_ => vec![],
}),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionEstablished { peer_id, .. },
)) => Some(FromSwarm::Discovered { peer_id }),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
peer_id,
..
})) => Some(FromSwarm::Expired { peer_id }),
_ => None,
}
}
@@ -244,34 +233,27 @@ mod transport {
}
mod behaviour {
use crate::alias;
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity, mdns};
use libp2p::{gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
#[derive(NetworkBehaviour)]
pub struct Behaviour {
pub discovery: mdns::tokio::Behaviour,
pub discovery: discovery::Behaviour,
pub gossipsub: gossipsub::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
Ok(Self {
discovery: mdns_behaviour(keypair)?,
discovery: discovery::Behaviour::new(keypair)?,
gossipsub: gossipsub_behaviour(keypair),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> alias::AnyResult<mdns::tokio::Behaviour> {
Ok(mdns::tokio::Behaviour::new(
mdns::Config::default(),
keypair.public().to_peer_id(),
)?)
}
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};

View File

@@ -314,9 +314,13 @@ async def fetch_file_list_with_cache(
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
logger.opt(exception=e).warning(
"Ran into exception when fetching file list from HF."
)
if await aios.path.exists(cache_file):
logger.warning(
f"No internet and no cached file list for {model_id} - using local file list"
f"No cached file list for {model_id} - using local file list"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())

View File

@@ -25,6 +25,7 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
from exo.worker.runner.runner_opts import RunnerOpts
@dataclass
@@ -40,10 +41,11 @@ class Node:
node_id: NodeId
offline: bool
runner_opts: RunnerOpts
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@classmethod
async def create(cls, args: "Args") -> Self:
@staticmethod
async def create(args: "Args") -> "Node":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -63,14 +65,28 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.fast_synch is True:
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
logger.info("FAST_SYNCH forced OFF")
runner_opts = RunnerOpts(
fast_synch_override=args.fast_synch,
trust_remote_code_override=args.trust_remote_code,
)
if offline := args.offline:
logger.info(
"Running in OFFLINE mode — no internet checks, local models only"
)
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(offline=args.offline),
exo_shard_downloader(offline=offline),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
offline=args.offline,
offline=offline,
)
else:
download_coordinator = None
@@ -90,6 +106,7 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
runner_opts,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
command_sender=router.sender(topics.COMMANDS),
@@ -123,7 +140,7 @@ class Node:
election_result_sender=er_send,
)
return cls(
return Node(
router,
event_router,
download_coordinator,
@@ -134,6 +151,7 @@ class Node:
api,
node_id,
args.offline,
runner_opts,
)
async def run(self):
@@ -238,6 +256,7 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
self.runner_opts,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
command_sender=self.router.sender(topics.COMMANDS),
@@ -265,17 +284,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
try:
anyio.run(node.run)
@@ -297,8 +305,11 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool | None = (
None # None = auto, True = force on, False = force off
)
@classmethod
def parse(cls) -> Self:
@@ -365,6 +376,20 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
trust_remote_code_group = parser.add_mutually_exclusive_group()
trust_remote_code_group.add_argument(
"--trust-remote-code",
action="store_true",
dest="trust_remote_code",
default=None,
help="Allow all models to execute custom code",
)
trust_remote_code_group.add_argument(
"--never-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Deny all models from execute custom code",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -2,6 +2,8 @@
from collections.abc import Sequence
from mlx import core as mx
from mlx import nn as nn
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: KVCacheType | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -1,17 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -49,6 +49,21 @@ TimeoutCallback = Callable[[], None]
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
def flush_prefill_sends() -> None:
for output, dst, group in _pending_prefill_sends:
sent = mx.distributed.send(output, dst, group=group)
mx.async_eval(sent)
_pending_prefill_sends.clear()
def clear_prefill_sends() -> None:
# Discard pending sends (e.g. on cancellation).
_pending_prefill_sends.clear()
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
@@ -150,6 +165,7 @@ class PipelineLastLayer(CustomMlxLayer):
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
self.is_prefill: bool = False
self.queue_sends: bool = False
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
@@ -163,9 +179,14 @@ class PipelineLastLayer(CustomMlxLayer):
mx.eval(output)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if self.queue_sends:
_pending_prefill_sends.append(
(output, (self.r + 1) % self.s, self.group)
)
else:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
@@ -190,6 +211,12 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
layer.is_prefill = is_prefill
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
for layer in model.layers: # type: ignore
if isinstance(layer, PipelineLastLayer):
layer.queue_sends = queue_sends
def get_inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
if isinstance(inner, nn.Module):

View File

@@ -13,8 +13,7 @@ from mlx_lm.models.cache import (
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.shared.types.mlx import KVCacheType, Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
@@ -254,9 +253,9 @@ def trim_cache(
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
c.state = [None] * len(c.state)
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
c.trim(num_tokens)
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:

View File

@@ -1,10 +1,14 @@
import functools
import math
import time
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.generate import (
maybe_quantize_kv_cache,
stream_generate,
)
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -19,13 +23,19 @@ from exo.shared.types.api import (
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.mlx import KVCacheType, Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
clear_prefill_sends,
flush_prefill_sends,
set_pipeline_prefill,
set_pipeline_queue_sends,
)
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
@@ -56,6 +66,130 @@ class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def _has_pipeline_communication_layer(model: Model):
for layer in model.layers:
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
return True
return False
def pipeline_parallel_prefill(
model: Model,
prompt: mx.array,
prompt_cache: KVCacheType,
prefill_step_size: int,
kv_group_size: int | None,
kv_bits: int | None,
prompt_progress_callback: Callable[[int, int], None],
distributed_prompt_progress_callback: Callable[[], None] | None,
group: mx.distributed.Group,
) -> None:
"""Prefill the KV cache for pipeline parallel with overlapping stages.
Each rank processes the full prompt through its real cache, offset by leading
and trailing dummy iterations.
Total iterations per rank = N_real_chunks + world_size - 1:
- rank r leading dummies (skip_pipeline_io, throwaway cache)
- N_real_chunks real (pipeline IO active, real cache)
- (world_size-1-r) trailing dummies (skip_pipeline_io, throwaway cache)
e.g.
Timeline (2 ranks, 3 chunks of 10240 tokens @ step=4096):
iter 0: R0 real[0:4096] R1 dummy
iter 1: R0 real[4096:8192] R1 real[0:4096]
iter 2: R0 real[8192:10240] R1 real[4096:8192]
iter 3: R0 dummy R1 real[8192:10240]
This function is designed to match mlx_lm's stream_generate exactly in terms of
side effects (given the same prefill step size)
"""
prefill_step_size = prefill_step_size // min(4, group.size())
quantize_cache_fn: Callable[..., None] = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=0,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
_prompt_cache: KVCacheType = prompt_cache
rank = group.rank()
world_size = group.size()
# Build list of real prompt chunk sizes
total = len(prompt)
real_chunk_sizes: list[int] = []
remaining = total - 1
while remaining:
n = min(prefill_step_size, remaining)
real_chunk_sizes.append(n)
remaining -= n
n_real = len(real_chunk_sizes)
# Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]
n_leading = rank
n_trailing = world_size - 1 - rank
n_total = n_leading + n_real + n_trailing
t_start = time.perf_counter()
processed = 0
logger.info(
f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
)
clear_prefill_sends()
# Initial callback matching generate_step
prompt_progress_callback(0, total)
try:
with mx.stream(generation_stream):
for _ in range(n_leading):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
for i in range(n_real):
chunk_size = real_chunk_sizes[i]
model(
prompt[processed : processed + chunk_size][None],
cache=_prompt_cache,
)
quantize_cache_fn(_prompt_cache)
processed += chunk_size
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
flush_prefill_sends()
prompt_progress_callback(processed, total)
for _ in range(n_trailing):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
finally:
clear_prefill_sends()
# Post-loop: process remaining 1 token + add +1 entry to match stream_generate.
for _ in range(2):
with mx.stream(generation_stream):
model(prompt[-1:][None], cache=_prompt_cache)
quantize_cache_fn(_prompt_cache)
flush_prefill_sends()
assert _prompt_cache is not None
mx.eval([c.state for c in _prompt_cache]) # type: ignore
# Final callback matching generate_step
prompt_progress_callback(total, total)
logger.info(
f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
@@ -64,6 +198,7 @@ def prefill(
cache: KVCacheType,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
distributed_prompt_progress_callback: Callable[[], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -95,31 +230,57 @@ def prefill(
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
def combined_progress_callback(processed: int, total: int) -> None:
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
progress_callback(processed, total)
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
is_pipeline = _has_pipeline_communication_layer(model)
prefill_step_size = 4096
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
if is_pipeline and num_tokens >= prefill_step_size:
set_pipeline_queue_sends(model, queue_sends=True)
assert group is not None, "Pipeline prefill requires a distributed group"
pipeline_parallel_prefill(
model=model,
prompt=prompt_tokens,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)
else:
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=combined_progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
raise
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
# stream_generate added 1 extra generated token to the cache, so we should trim it.
@@ -132,7 +293,7 @@ def prefill(
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
c.trim(2)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -275,6 +436,7 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -336,6 +498,7 @@ def mlx_generate(
caches,
group,
on_prefill_progress,
distributed_prompt_progress_callback,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

View File

@@ -40,6 +40,7 @@ from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -52,7 +53,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
LayerLoadedCallback,
TimeoutCallback,
@@ -167,10 +167,12 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
trust_remote_code: bool | None,
) -> tuple[Model, TokenizerWrapper]:
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
@@ -189,12 +191,10 @@ def load_mlx_items(
mx.eval(model)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
model = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
@@ -205,6 +205,14 @@ def load_mlx_items(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
tokenizer = load_tokenizer_for_model_id(
bound_instance.bound_shard.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code
if trust_remote_code is not None
else bound_instance.bound_shard.model_card.trust_remote_code,
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
@@ -217,9 +225,8 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> tuple[nn.Module, TokenizerWrapper]:
) -> nn.Module:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
@@ -241,8 +248,6 @@ def shard_and_load(
assert isinstance(model, nn.Module)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
@@ -281,16 +286,7 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
return model
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
import anyio
@@ -46,38 +47,34 @@ from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_opts import RunnerOpts
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@dataclass
class Worker:
def __init__(
self,
node_id: NodeId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
node_id: NodeId
runner_opts: RunnerOpts
event_receiver: Receiver[IndexedEvent]
event_sender: Sender[Event]
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand]
download_command_sender: Sender[ForwarderDownloadCommand]
state: State = field(init=False, default_factory=State)
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_system_id: SystemId = field(init=False, default_factory=SystemId)
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
# Buffer for input image chunks (for image editing)
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
init=False, default_factory=dict
)
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
self._system_id = SystemId()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
_download_backoff: KeyedBackoff[ModelId] = field(
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
)
async def run(self):
logger.info("Starting Worker")
@@ -283,6 +280,7 @@ class Worker:
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
runner_opts=self.runner_opts,
bound_instance=task.bound_instance,
event_sender=self.event_sender.clone(),
)

View File

@@ -1,4 +1,5 @@
import os
import resource
import loguru
@@ -8,10 +9,13 @@ from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from .runner_opts import RunnerOpts
logger: "loguru.Logger" = loguru.logger
def entrypoint(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
@@ -20,12 +24,17 @@ def entrypoint(
) -> None:
global logger
logger = _logger
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
fast_synch_override = runner_opts.fast_synch_override
if fast_synch_override is not None:
if fast_synch_override:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
@@ -36,7 +45,7 @@ def entrypoint(
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -1,5 +1,4 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
@@ -66,6 +65,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
@@ -183,14 +183,12 @@ def _send_image_chunk(
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,

View File

@@ -1,5 +1,4 @@
import math
import resource
import time
from collections.abc import Generator
from functools import cache
@@ -31,6 +30,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -63,7 +63,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
@@ -79,19 +78,18 @@ from exo.worker.engines.mlx.utils_mlx import (
mx_any,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
from .tool_parsers import ToolParser, make_mlx_parser
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
@@ -194,6 +192,7 @@ def main(
group,
on_timeout=on_model_load_timeout,
on_layer_loaded=on_layer_loaded,
trust_remote_code=runner_opts.trust_remote_code_override,
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
@@ -274,8 +273,6 @@ def main(
def on_prefill_progress(
processed: int,
total: int,
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
@@ -288,6 +285,11 @@ def main(
),
)
)
def distributed_prompt_progress_callback(
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
@@ -309,6 +311,7 @@ def main(
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class RunnerOpts:
fast_synch_override: bool | None
trust_remote_code_override: bool | None

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.runner_opts import RunnerOpts
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -62,6 +63,7 @@ class RunnerSupervisor:
def create(
cls,
*,
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: Sender[Event],
initialize_timeout: float = 400,
@@ -73,6 +75,7 @@ class RunnerSupervisor:
runner_process = mp.Process(
target=entrypoint,
args=(
runner_opts,
bound_instance,
ev_send,
task_recv,

View File

@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load

View File

@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
@@ -143,7 +143,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
# Cache should now hold the prompt tokens minus one
@@ -164,7 +171,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -200,7 +214,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
short_tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -245,7 +266,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -285,7 +313,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -513,7 +548,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -538,7 +582,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)

View File

@@ -0,0 +1,512 @@
# type: ignore
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
then verifies that the prompt_progress_callback sequences are identical
and that generated text matches.
"""
import json
import multiprocessing as mp
import os
import tempfile
import traceback
from typing import Any, cast
import pytest
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
TOTAL_LAYERS = 24
MAX_TOKENS = 10
SEED = 42
TEMPERATURE = 0.0
def _model_card() -> ModelCard:
return ModelCard(
model_id=ModelId(MODEL_ID),
storage_size=Memory.from_gb(12),
n_layers=TOTAL_LAYERS,
hidden_size=2880,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
)
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
"""Build a prompt with the given number of user-content tokens, return (chat_prompt, task)."""
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
base_text = "The quick brown fox jumps over the lazy dog. "
base_toks = tokenizer.encode(base_text)
repeats = (prompt_tokens // len(base_toks)) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
model=MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
seed=SEED,
)
prompt = apply_chat_template(tokenizer, task)
return prompt, task
# ---------------------------------------------------------------------------
# Single-device process: uses stream_generate path (no pipeline layers)
# ---------------------------------------------------------------------------
def _run_single_device(
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load full model without pipeline sharding, run mlx_generate, record callbacks."""
try:
import mlx.core as mx
from mlx_lm.utils import load_model
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
build_model_path,
get_tokenizer,
)
model_path = build_model_path(ModelId(MODEL_ID))
model, _ = load_model(model_path, lazy=True, strict=False)
mx.eval(model)
# Use PipelineShardMetadata just for get_tokenizer (needs model_card), but
# do NOT apply pipeline sharding — the model keeps all layers unwrapped.
dummy_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=TOTAL_LAYERS,
n_layers=TOTAL_LAYERS,
)
tokenizer = get_tokenizer(model_path, dummy_meta)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=None,
on_prefill_progress=on_progress,
):
generated_text += response.text
if response.finish_reason is not None:
break
# Also record the token count that prefill() received (prompt_tokens[:-1])
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Pipeline device process: uses _pipeline_prefill_cache path
# ---------------------------------------------------------------------------
def _run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load model with pipeline sharding, run mlx_generate, record callbacks."""
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
try:
import mlx.core as mx
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
group = mx.distributed.init(backend="ring", strict=True)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=TOTAL_LAYERS,
)
model, tokenizer = shard_and_load(
shard_meta, group, on_timeout=None, on_layer_loaded=None
)
model = cast(Any, model)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
def distributed_prompt_progress_callback(_group: Any = group) -> None:
from exo.worker.engines.mlx.utils_mlx import mx_any
mx_any(False, _group)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
on_prefill_progress=on_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
):
generated_text += response.text
if response.finish_reason is not None:
break
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
rank,
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _create_hostfile(world_size: int, base_port: int) -> str:
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
return f.name
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
"""Run single-device (stream_generate) prefill and return results."""
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.terminate()
p.join(timeout=5)
pytest.fail("Single-device process timed out")
assert not result_queue.empty(), "Single-device process produced no result"
success, data = result_queue.get()
assert success, f"Single-device process failed:\n{data}"
return data
def _run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
base_port: int,
timeout: int = 120,
) -> dict[int, dict[str, Any]]:
"""Run pipeline prefill across ranks and return per-rank results."""
world_size = len(layer_splits)
hostfile_path = _create_hostfile(world_size, base_port)
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
try:
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=_run_pipeline_device,
args=(
rank,
world_size,
hostfile_path,
layer_splits,
prompt_tokens,
result_queue,
),
)
p.start()
processes.append(p)
for p in processes:
p.join(timeout=timeout)
timed_out = any(p.is_alive() for p in processes)
for p in processes:
if p.is_alive():
p.terminate()
p.join(timeout=5)
assert not timed_out, "Pipeline processes timed out"
results: dict[int, dict[str, Any]] = {}
while not result_queue.empty():
rank, success, data = result_queue.get()
assert success, f"Pipeline rank {rank} failed:\n{data}"
results[rank] = data
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
)
return results
finally:
os.unlink(hostfile_path)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not MODEL_PATH.exists(),
reason=f"GPT-OSS model not found at {MODEL_PATH}",
),
]
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
class TestPipelineNoDeadlock:
"""Pipeline prefill must not deadlock at any rank count or prompt length."""
@pytest.mark.parametrize(
"layer_splits,prompt_tokens",
[
(LAYER_SPLITS_2WAY, 128),
(LAYER_SPLITS_2WAY, 4096),
(LAYER_SPLITS_2WAY, 8192),
(LAYER_SPLITS_2WAY, 16384),
(LAYER_SPLITS_4WAY, 128),
(LAYER_SPLITS_4WAY, 4096),
(LAYER_SPLITS_4WAY, 8192),
(LAYER_SPLITS_4WAY, 16384),
],
ids=[
"2rank_128tok",
"2rank_4096tok",
"2rank_8192tok",
"2rank_16384tok",
"4rank_128tok",
"4rank_4096tok",
"4rank_8192tok",
"4rank_16384tok",
],
)
def test_no_deadlock(
self,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
) -> None:
"""Pipeline must complete without deadlock at various prompt lengths."""
pipeline_results = _run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=prompt_tokens,
base_port=29650,
timeout=60,
)
# If we get here, no deadlock. Verify all ranks produced output.
for rank, pipe_data in sorted(pipeline_results.items()):
assert pipe_data["text"], f"Rank {rank} produced no output text"
class TestPipelinePrefillCallbacks:
"""Verify that pipeline prefill callbacks exactly match stream_generate callbacks."""
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500, 5000],
ids=["short_50", "medium_500", "long_5000"],
)
def test_callbacks_match(self, prompt_tokens: int) -> None:
"""All pipeline ranks must produce identical callback sequences."""
# Run 4-rank pipeline
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29700,
timeout=180,
)
# All ranks must agree on prefill token count and callback sequence
rank0_data = pipeline_results[0]
rank0_callbacks = rank0_data["callbacks"]
prefill_count = rank0_data["prefill_token_count"]
for rank, pipe_data in sorted(pipeline_results.items()):
pipe_callbacks = pipe_data["callbacks"]
assert pipe_data["prefill_token_count"] == prefill_count, (
f"Rank {rank} prefill token count mismatch: "
f"{pipe_data['prefill_token_count']} vs {prefill_count}"
)
assert pipe_callbacks == rank0_callbacks, (
f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
f"(prefill M={prefill_count}):\n"
f" pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\n"
f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
)
# Structural checks: starts with (0, M), ends with (M, M), monotonically increasing
assert rank0_callbacks[0] == (0, prefill_count), (
f"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}"
)
assert rank0_callbacks[-1] == (prefill_count, prefill_count), (
f"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}"
)
for i in range(1, len(rank0_callbacks)):
assert rank0_callbacks[i][0] >= rank0_callbacks[i - 1][0], (
f"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}"
)
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500],
ids=["short_50", "medium_500"],
)
def test_output_matches(self, prompt_tokens: int) -> None:
"""Pipeline-generated text must match single-device output."""
single = _run_single_device_test(prompt_tokens, timeout=180)
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29800,
timeout=180,
)
single_text = single["text"]
# The last rank produces the final logits, so its output should match.
# Due to SDPA tiling non-determinism, allow minor differences in text.
last_rank = max(pipeline_results.keys())
pipe_text = pipeline_results[last_rank]["text"]
# For deterministic sampling (temp=0.0), outputs should match exactly
# or be very close. Log both for debugging even if they match.
if single_text != pipe_text:
# Find first divergence point
min_len = min(len(single_text), len(pipe_text))
diverge_idx = next(
(i for i in range(min_len) if single_text[i] != pipe_text[i]),
min_len,
)
pytest.fail(
f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
f" single-device: {single_text!r}\n"
f" pipeline R{last_rank}: {pipe_text!r}"
)
class TestPipelineCallbacksStructure:
"""Verify structural properties of callbacks independent of model output."""
def test_callback_structure_matches_generate_step(self) -> None:
"""Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M)."""
prompt_tokens = 200
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29900,
timeout=180,
)
for rank, pipe_data in sorted(pipeline_results.items()):
callbacks = pipe_data["callbacks"]
m = pipe_data["prefill_token_count"]
assert m > 0, f"Rank {rank}: prefill token count is 0"
assert callbacks[0] == (0, m), (
f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
)
assert callbacks[-1] == (m, m), (
f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
)
if len(callbacks) > 2:
second_to_last = callbacks[-2]
assert second_to_last[0] < m, (
f"Rank {rank}: second-to-last callback should report < {m}, "
f"got {second_to_last}"
)
# All callbacks must have total == M
for i, (_, total) in enumerate(callbacks):
assert total == m, (
f"Rank {rank}: callback {i} has total={total}, expected {m}"
)
# processed values must be non-decreasing
processed_vals = [p for p, _ in callbacks]
for i in range(1, len(processed_vals)):
assert processed_vals[i] >= processed_vals[i - 1], (
f"Rank {rank}: callbacks not non-decreasing at index {i}: "
f"{processed_vals}"
)
# No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)
for i in range(1, len(callbacks)):
assert callbacks[i] != callbacks[i - 1], (
f"Rank {rank}: duplicate consecutive callback at index {i}: "
f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
)

View File

@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (

View File

@@ -40,6 +40,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.runner.runner_opts import RunnerOpts
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -184,6 +185,7 @@ def _run(tasks: Iterable[Task]):
make_nothin(mx.array([1])),
):
mlx_runner.main(
RunnerOpts(None, None),
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,