mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-27 11:46:14 -05:00
Compare commits
1 Commits
leo/prepar
...
remove-cus
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93ac180120 |
@@ -73,11 +73,9 @@ class GenerationResponse:
|
||||
finish_reason: Optional[str] = ...
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: Any,
|
||||
quantized_kv_start: int | None,
|
||||
kv_group_size: int | None,
|
||||
kv_bits: int | None,
|
||||
) -> None: ...
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
): # -> None:
|
||||
...
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
|
||||
@@ -16,7 +16,7 @@ class Cache(Protocol):
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
|
||||
@@ -92,14 +92,13 @@ class _BaseCache(Cache):
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
@property
|
||||
def meta_state(self) -> Literal[""]: ...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v) -> None: ...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def is_trimmable(self) -> Literal[False]: ...
|
||||
@classmethod
|
||||
def from_state(cls, state, meta_state) -> Self: ...
|
||||
@@ -115,13 +114,15 @@ class ConcatenateKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
|
||||
...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
def state(self): # -> tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -131,7 +132,10 @@ class QuantizedKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> Any:
|
||||
...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -143,7 +147,8 @@ class QuantizedKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -155,12 +160,13 @@ class KVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
) -> tuple[array, array]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -177,7 +183,8 @@ class RotatingKVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -189,7 +196,8 @@ class RotatingKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -204,7 +212,8 @@ class ArraysCache(_BaseCache):
|
||||
...
|
||||
def __getitem__(self, idx): ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
def state(self): # -> list[Any | array] | list[array]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -230,7 +239,8 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@@ -243,9 +253,10 @@ class CacheList(_BaseCache):
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def trim(self, n): ...
|
||||
@property
|
||||
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
|
||||
def state(self): # -> list[Any]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
|
||||
@@ -1,382 +0,0 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures_lite::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{
|
||||
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
|
||||
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
|
||||
THandlerOutEvent, ToSwarm, dummy,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(
|
||||
ping::Config::new()
|
||||
.with_timeout(PING_TIMEOUT)
|
||||
.with_interval(PING_INTERVAL),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => {
|
||||
self.managed
|
||||
.on_connection_handler_event(peer_id, connection_id, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
}
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
pub mod discovery;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use std::collections::HashSet;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use libp2p::mdns;
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
@@ -68,7 +70,7 @@ impl Swarm {
|
||||
}
|
||||
event = swarm.next() => {
|
||||
let Some(event) = event else { break };
|
||||
if let Some(item) = filter_swarm_event(event) {
|
||||
for item in filter_swarm_event(event) {
|
||||
yield item;
|
||||
}
|
||||
}
|
||||
@@ -115,7 +117,7 @@ fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Vec<FromSwarm> {
|
||||
match event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
@@ -126,19 +128,28 @@ fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Some(FromSwarm::Message {
|
||||
})) => vec![FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
}),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Some(FromSwarm::Discovered { peer_id }),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
|
||||
peer_id,
|
||||
..
|
||||
})) => Some(FromSwarm::Expired { peer_id }),
|
||||
_ => None,
|
||||
}],
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(mdns::Event::Discovered(peer_id))) => {
|
||||
peer_id
|
||||
.into_iter()
|
||||
.map(|(pid, _)| pid)
|
||||
.collect::<HashSet<PeerId>>()
|
||||
.into_iter()
|
||||
.map(|peer_id| FromSwarm::Discovered { peer_id })
|
||||
.collect()
|
||||
}
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(mdns::Event::Expired(peer_id))) => peer_id
|
||||
.into_iter()
|
||||
.map(|(pid, _)| pid)
|
||||
.collect::<HashSet<PeerId>>()
|
||||
.into_iter()
|
||||
.map(|peer_id| FromSwarm::Discovered { peer_id })
|
||||
.collect(),
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,27 +244,34 @@ mod transport {
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use crate::alias;
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
use libp2p::{gossipsub, identity, mdns};
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub discovery: mdns::tokio::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair)?,
|
||||
discovery: mdns_behaviour(keypair)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> alias::AnyResult<mdns::tokio::Behaviour> {
|
||||
Ok(mdns::tokio::Behaviour::new(
|
||||
mdns::Config::default(),
|
||||
keypair.public().to_peer_id(),
|
||||
)?)
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
|
||||
@@ -314,13 +314,9 @@ async def fetch_file_list_with_cache(
|
||||
_fetched_file_lists_this_session.add(cache_key)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
"Ran into exception when fetching file list from HF."
|
||||
)
|
||||
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"No cached file list for {model_id} - using local file list"
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
|
||||
@@ -258,6 +258,6 @@ def get_node_id_keypair(
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate()
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx import core as mx
|
||||
from mlx import nn as nn
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
@@ -16,16 +14,3 @@ from mlx_lm.models.cache import (
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
|
||||
|
||||
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: KVCacheType | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from math import inf
|
||||
from multiprocessing.synchronize import Event
|
||||
@@ -283,54 +282,6 @@ class MpReceiver[T]:
|
||||
return d
|
||||
|
||||
|
||||
class NonBlockingGenerator[T](Generator[T | None, None, None]):
|
||||
def __init__(self, source: MpReceiver[T] | Generator[T | None, None, None]) -> None:
|
||||
self._receiver: MpReceiver[T] | None = None
|
||||
self._inner: Generator[T | None, None, None] | None = None
|
||||
if isinstance(source, MpReceiver):
|
||||
self._receiver = source
|
||||
else:
|
||||
self._inner = source
|
||||
self._exhausted = False
|
||||
|
||||
def send(self, value: None, /) -> T | None:
|
||||
if self._exhausted:
|
||||
raise StopIteration
|
||||
if self._inner is not None:
|
||||
try:
|
||||
return next(self._inner)
|
||||
except (StopIteration, ClosedResourceError):
|
||||
self._exhausted = True
|
||||
raise StopIteration from None
|
||||
assert self._receiver is not None
|
||||
try:
|
||||
return self._receiver.receive_nowait()
|
||||
except WouldBlock:
|
||||
return None
|
||||
except (EndOfStream, ClosedResourceError):
|
||||
self._exhausted = True
|
||||
raise StopIteration from None
|
||||
|
||||
def throw(
|
||||
self,
|
||||
typ: type[BaseException] | BaseException,
|
||||
val: BaseException | object = None,
|
||||
tb: TracebackType | None = None,
|
||||
/,
|
||||
) -> T | None:
|
||||
raise StopIteration
|
||||
|
||||
@property
|
||||
def is_exhausted(self) -> bool:
|
||||
return self._exhausted
|
||||
|
||||
def try_receive(self) -> T | None:
|
||||
try:
|
||||
return next(self)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
class channel[T]: # noqa: N801
|
||||
"""Create a pair of asynchronous channels for communicating within the same process"""
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: list[KVCache] | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -49,21 +49,6 @@ TimeoutCallback = Callable[[], None]
|
||||
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
|
||||
|
||||
|
||||
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
|
||||
|
||||
|
||||
def flush_prefill_sends() -> None:
|
||||
for output, dst, group in _pending_prefill_sends:
|
||||
sent = mx.distributed.send(output, dst, group=group)
|
||||
mx.async_eval(sent)
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def clear_prefill_sends() -> None:
|
||||
# Discard pending sends (e.g. on cancellation).
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
@@ -165,7 +150,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
self.is_prefill: bool = False
|
||||
self.queue_sends: bool = False
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
@@ -179,14 +163,9 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
if self.queue_sends:
|
||||
_pending_prefill_sends.append(
|
||||
(output, (self.r + 1) % self.s, self.group)
|
||||
)
|
||||
else:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
@@ -211,12 +190,6 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
|
||||
layer.is_prefill = is_prefill
|
||||
|
||||
|
||||
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
|
||||
for layer in model.layers: # type: ignore
|
||||
if isinstance(layer, PipelineLastLayer):
|
||||
layer.queue_sends = queue_sends
|
||||
|
||||
|
||||
def get_inner_model(model: nn.Module) -> nn.Module:
|
||||
inner = getattr(model, "model", None)
|
||||
if isinstance(inner, nn.Module):
|
||||
|
||||
@@ -13,7 +13,8 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType, Model
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -253,9 +254,9 @@ def trim_cache(
|
||||
if snapshot is not None and snapshot.states[i] is not None:
|
||||
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
|
||||
else:
|
||||
c.state = [None] * len(c.state)
|
||||
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
else:
|
||||
c.trim(num_tokens)
|
||||
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import functools
|
||||
import math
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import (
|
||||
maybe_quantize_kv_cache,
|
||||
stream_generate,
|
||||
)
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -23,19 +19,13 @@ from exo.shared.types.api import (
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType, Model
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
clear_prefill_sends,
|
||||
flush_prefill_sends,
|
||||
set_pipeline_prefill,
|
||||
set_pipeline_queue_sends,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
@@ -66,130 +56,6 @@ class PrefillCancelled(BaseException):
|
||||
"""Raised when prefill is cancelled via the progress callback."""
|
||||
|
||||
|
||||
def _has_pipeline_communication_layer(model: Model):
|
||||
for layer in model.layers:
|
||||
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def pipeline_parallel_prefill(
    model: Model,
    prompt: mx.array,
    prompt_cache: KVCacheType,
    prefill_step_size: int,
    kv_group_size: int | None,
    kv_bits: int | None,
    prompt_progress_callback: Callable[[int, int], None],
    distributed_prompt_progress_callback: Callable[[], None] | None,
    group: mx.distributed.Group,
) -> None:
    """Prefill the KV cache for pipeline parallel with overlapping stages.

    Each rank processes the prompt (all but its last token) through its real
    cache in chunks, offset by leading and trailing dummy iterations.

    Total iterations per rank = N_real_chunks + world_size - 1:
      - rank r leading dummies
      - N_real_chunks real (model is run on the real cache)
      - (world_size-1-r) trailing dummies

    Dummy iterations do not run the model; they only invoke
    ``distributed_prompt_progress_callback`` (when provided), so every rank
    invokes that callback the same number of times.

    e.g.
    Timeline (2 ranks, 3 chunks of 10240 tokens @ effective step=4096):
        iter 0: R0 real[0:4096]     R1 dummy
        iter 1: R0 real[4096:8192]  R1 real[0:4096]
        iter 2: R0 real[8192:10240] R1 real[4096:8192]
        iter 3: R0 dummy            R1 real[8192:10240]

    This function is designed to match mlx_lm's stream_generate exactly in terms of
    side effects (given the same prefill step size)

    Args:
        model: Model whose forward pass fills ``prompt_cache``.
        prompt: Prompt token array; indexed along its first axis.
        prompt_cache: Cache passed to every model call and quantized after each.
        prefill_step_size: Requested chunk size; divided by
            ``min(4, group.size())`` and clamped to at least 1.
        kv_group_size: Forwarded to ``maybe_quantize_kv_cache``.
        kv_bits: Forwarded to ``maybe_quantize_kv_cache``.
        prompt_progress_callback: Called with ``(processed, total)`` before the
            first chunk, after each real chunk, and once more at the end.
        distributed_prompt_progress_callback: Optional callback invoked once per
            iteration (dummy or real).
        group: Distributed group supplying this rank and the world size.
    """
    # Clamp to >= 1: a step of 0 would make the chunking loop below append
    # zero-sized chunks forever.
    prefill_step_size = max(1, prefill_step_size // min(4, group.size()))

    quantize_cache_fn: Callable[..., None] = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=0,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    _prompt_cache: KVCacheType = prompt_cache
    rank = group.rank()
    world_size = group.size()

    # Build list of real prompt chunk sizes (the last token is handled post-loop)
    total = len(prompt)
    real_chunk_sizes: list[int] = []
    remaining = total - 1
    # `> 0` rather than truthiness so an empty prompt (remaining == -1) does
    # not produce a negative chunk size.
    while remaining > 0:
        n = min(prefill_step_size, remaining)
        real_chunk_sizes.append(n)
        remaining -= n
    n_real = len(real_chunk_sizes)

    # Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]
    n_leading = rank
    n_trailing = world_size - 1 - rank
    n_total = n_leading + n_real + n_trailing

    t_start = time.perf_counter()
    processed = 0
    logger.info(
        f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
    )
    clear_prefill_sends()

    # Initial callback matching generate_step
    prompt_progress_callback(0, total)

    try:
        with mx.stream(generation_stream):
            for _ in range(n_leading):
                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

            for i in range(n_real):
                chunk_size = real_chunk_sizes[i]
                model(
                    prompt[processed : processed + chunk_size][None],
                    cache=_prompt_cache,
                )
                quantize_cache_fn(_prompt_cache)
                processed += chunk_size

                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

                flush_prefill_sends()

                prompt_progress_callback(processed, total)

            for _ in range(n_trailing):
                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

    finally:
        # Always drop queued sends, including when a callback raises.
        clear_prefill_sends()

    # Post-loop: process remaining 1 token + add +1 entry to match stream_generate.
    for _ in range(2):
        with mx.stream(generation_stream):
            model(prompt[-1:][None], cache=_prompt_cache)
            quantize_cache_fn(_prompt_cache)
        flush_prefill_sends()

    assert _prompt_cache is not None
    mx.eval([c.state for c in _prompt_cache])  # type: ignore

    # Final callback matching generate_step
    prompt_progress_callback(total, total)

    logger.info(
        f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
        f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
    )
|
||||
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -198,7 +64,6 @@ def prefill(
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None,
|
||||
distributed_prompt_progress_callback: Callable[[], None] | None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -230,57 +95,31 @@ def prefill(
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
def combined_progress_callback(processed: int, total: int) -> None:
|
||||
if distributed_prompt_progress_callback is not None:
|
||||
distributed_prompt_progress_callback()
|
||||
progress_callback(processed, total)
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
is_pipeline = _has_pipeline_communication_layer(model)
|
||||
|
||||
prefill_step_size = 4096
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
try:
|
||||
if is_pipeline and num_tokens >= prefill_step_size:
|
||||
set_pipeline_queue_sends(model, queue_sends=True)
|
||||
assert group is not None, "Pipeline prefill requires a distributed group"
|
||||
pipeline_parallel_prefill(
|
||||
model=model,
|
||||
prompt=prompt_tokens,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
|
||||
group=group,
|
||||
)
|
||||
else:
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=combined_progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=4096,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
except PrefillCancelled:
|
||||
set_pipeline_queue_sends(model, queue_sends=False)
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
raise
|
||||
|
||||
set_pipeline_queue_sends(model, queue_sends=False)
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
@@ -293,7 +132,7 @@ def prefill(
|
||||
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2)
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
@@ -436,8 +275,6 @@ def mlx_generate(
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
distributed_prompt_progress_callback: Callable[[], None] | None = None,
|
||||
on_generation_token: Callable[[], None] | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -499,7 +336,6 @@ def mlx_generate(
|
||||
caches,
|
||||
group,
|
||||
on_prefill_progress,
|
||||
distributed_prompt_progress_callback,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
@@ -645,9 +481,6 @@ def mlx_generate(
|
||||
full_prompt_tokens, caches, cache_snapshots
|
||||
)
|
||||
|
||||
if on_generation_token is not None:
|
||||
on_generation_token()
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
|
||||
@@ -40,7 +40,6 @@ from pydantic import RootModel
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
@@ -53,6 +52,7 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
LayerLoadedCallback,
|
||||
TimeoutCallback,
|
||||
|
||||
@@ -297,10 +297,10 @@ def _pending_tasks(
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
if task.task_id in runner.completed or task.task_id in runner.pending:
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
|
||||
@@ -33,15 +33,10 @@ def entrypoint(
|
||||
try:
|
||||
if bound_instance.is_image_model:
|
||||
from exo.worker.runner.image_models.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
else:
|
||||
from exo.worker.runner.llm_inference.runner import Runner
|
||||
from exo.worker.runner.llm_inference.runner import main
|
||||
|
||||
runner = Runner(
|
||||
bound_instance, event_sender, task_receiver, cancel_receiver
|
||||
)
|
||||
runner.main()
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.tasks import TaskId, TextGeneration
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import PrefillCancelled, mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
mx_any,
|
||||
)
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
|
||||
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
    """Act on debug trigger phrases found in the first input message.

    Only ``task_params.input[0].content`` is inspected. Depending on which
    marker constant it contains, this raises an artificial exception, calls
    ``mlx_force_oom()``, or sleeps for 100 seconds. Missing or empty input is
    ignored.
    """
    import time

    from exo.worker.engines.mlx.utils_mlx import mlx_force_oom

    messages = task_params.input
    if not messages:
        return

    first_content = messages[0].content
    if not first_content:
        return

    if EXO_RUNNER_MUST_FAIL in first_content:
        raise Exception("Artificial runner exception - for testing purposes only.")

    if EXO_RUNNER_MUST_OOM in first_content:
        mlx_force_oom()

    if EXO_RUNNER_MUST_TIMEOUT in first_content:
        time.sleep(100)
|
||||
|
||||
|
||||
@dataclass(eq=False)
class BatchGenerator:
    """Run submitted text-generation tasks one at a time.

    Tasks are queued by ``submit``; each call to ``step`` pulls one response
    from the active task's generator and forwards it to that task's sender.
    When the active generator finishes (or is cancelled) the next queued task
    is started.
    """

    model: Model
    tokenizer: TokenizerWrapper
    group: mx.distributed.Group | None
    kv_prefix_cache: KVPrefixCache | None
    model_id: ModelId
    device_rank: int
    cancel_receiver: MpReceiver[TaskId]
    cancelled_tasks: set[TaskId]
    event_sender: MpSender[Event]
    check_for_cancel_every: int

    # Submitted tasks that have not been started yet, in FIFO order.
    _queue: deque[tuple[TextGeneration, MpSender[GenerationResponse]]] = field(
        default_factory=deque, init=False
    )
    # The task currently generating, with its response sender and generator.
    _active: (
        tuple[
            TextGeneration,
            MpSender[GenerationResponse],
            Generator[GenerationResponse],
        ]
        | None
    ) = field(default=None, init=False)

    def submit(
        self,
        task: TextGeneration,
        sender: MpSender[GenerationResponse],
    ) -> None:
        """Queue ``task`` and start it immediately if nothing is active."""
        self._queue.append((task, sender))
        if self._active is None:
            self._start_next()

    def step(self) -> None:
        """Advance the active task by one response.

        Starts the next queued task first when none is active. On
        ``StopIteration`` or ``PrefillCancelled`` the sender is closed and the
        next queued task (if any) is started. Any other ``Exception`` is
        reported via ``_send_error``, the sender is closed, and the exception
        is re-raised.
        """
        if self._active is None:
            if not self._queue:
                return
            self._start_next()

        if self._active is None:
            return

        task, sender, gen = self._active
        try:
            response = next(gen)
            sender.send(response)
        except (StopIteration, PrefillCancelled):
            sender.close()
            self._active = None
            if self._queue:
                self._start_next()
        except Exception as e:
            self._send_error(task, e)
            sender.close()
            self._active = None
            raise

    def _start_next(self) -> None:
        """Pop the oldest queued task and make it the active one.

        If building its generator fails, the error is reported, the sender is
        closed, and the exception is re-raised.
        """
        task, sender = self._queue.popleft()
        try:
            gen = self._build_generator(task)
        except Exception as e:
            self._send_error(task, e)
            sender.close()
            raise
        self._active = (task, sender, gen)

    def _send_error(self, task: TextGeneration, e: Exception) -> None:
        """Emit an ``ErrorChunk`` event for ``task``; only device rank 0 sends."""
        if self.device_rank == 0:
            self.event_sender.send(
                ChunkGenerated(
                    command_id=task.command_id,
                    chunk=ErrorChunk(
                        model=self.model_id,
                        finish_reason="error",
                        error_message=str(e),
                    ),
                )
            )

    def _cancel_requested(self, task: TextGeneration) -> bool:
        """Collect pending cancellations and report whether ``task`` should stop.

        ``mx_any`` is always called with the local decision and ``self.group``,
        whatever the local decision is.
        """
        self.cancelled_tasks.update(self.cancel_receiver.collect())
        want_to_cancel = (task.task_id in self.cancelled_tasks) or (
            TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
        )
        return mx_any(want_to_cancel, self.group)

    def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
        """Create the ``mlx_generate`` generator for ``task`` with its callbacks."""
        _check_for_debug_prompts(task.task_params)
        prompt = apply_chat_template(self.tokenizer, task.task_params)

        def on_prefill_progress(processed: int, total: int) -> None:
            # Only device rank 0 reports prefill progress.
            if self.device_rank == 0:
                self.event_sender.send(
                    ChunkGenerated(
                        command_id=task.command_id,
                        chunk=PrefillProgressChunk(
                            model=self.model_id,
                            processed_tokens=processed,
                            total_tokens=total,
                        ),
                    )
                )

        def distributed_prompt_progress_callback() -> None:
            if self._cancel_requested(task):
                raise PrefillCancelled()

        # Start at the threshold so the very first generated token triggers a
        # cancellation check.
        tokens_since_cancel_check = self.check_for_cancel_every

        def on_generation_token() -> None:
            nonlocal tokens_since_cancel_check
            tokens_since_cancel_check += 1
            if tokens_since_cancel_check < self.check_for_cancel_every:
                return
            tokens_since_cancel_check = 0
            if self._cancel_requested(task):
                raise PrefillCancelled()

        return mlx_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            task=task.task_params,
            prompt=prompt,
            kv_prefix_cache=self.kv_prefix_cache,
            on_prefill_progress=on_prefill_progress,
            distributed_prompt_progress_callback=distributed_prompt_progress_callback,
            on_generation_token=on_generation_token,
            group=self.group,
        )
|
||||
@@ -1,341 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.llm_inference.tool_parsers import ToolParser
|
||||
|
||||
|
||||
@cache
def get_gpt_oss_encoding():
    """Load the Harmony GPT-OSS encoding.

    Wrapped in ``functools.cache``, so the encoding is loaded on the first call
    and the same object is returned afterwards.
    """
    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def parse_gpt_oss(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Translate a raw GPT-OSS token stream into text and tool-call responses.

    Every token is fed to a Harmony ``StreamableParser``. While the parser
    reports a recipient, content deltas are collected as tool-call arguments;
    when the recipient changes, the collected call is yielded as a
    ``ToolCallResponse`` (a leading ``functions.`` is stripped from its name).
    Outside tool calls, deltas are yielded as text with ``is_thinking`` set
    while the current channel is ``"analysis"``. ``None`` items pass through
    unchanged, and a ``HarmonyError`` ends the stream early.
    """
    stream = StreamableParser(get_gpt_oss_encoding(), role=Role.ASSISTANT)
    current_tool_name: str | None = None
    tool_arg_parts: list[str] = []

    for response in responses:
        if response is None:
            yield None
            continue

        try:
            stream.process(response.token)
        except HarmonyError:
            logger.error("Encountered critical Harmony Error, returning early")
            return

        delta = stream.last_content_delta
        ch = stream.current_channel
        recipient = stream.current_recipient

        # Debug: log every token with state
        logger.debug(
            f"parse_gpt_oss token={response.token} text={response.text!r} "
            f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
            f"state={stream.state} current_tool={current_tool_name!r}"
        )

        if recipient != current_tool_name:
            # The recipient changed: emit the tool call collected so far.
            if current_tool_name is not None:
                current_tool_name = current_tool_name.removeprefix("functions.")
                logger.info(
                    f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
                )
                finished_call = ToolCallItem(
                    name=current_tool_name,
                    arguments="".join(tool_arg_parts).strip(),
                )
                yield ToolCallResponse(
                    tool_calls=[finished_call],
                    usage=response.usage,
                )
                tool_arg_parts = []
            current_tool_name = recipient

        # Inside a tool call: collect arguments and emit nothing for this token.
        if current_tool_name is not None:
            if delta:
                tool_arg_parts.append(delta)
            continue

        is_thinking = ch == "analysis"

        if delta:
            yield response.model_copy(update={"text": delta, "is_thinking": is_thinking})

        if response.finish_reason is not None:
            yield response
|
||||
|
||||
|
||||
def parse_deepseek_v32(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Parse DeepSeek V3.2 DSML tool calls from the generation stream.

    Uses accumulated-text matching (not per-token marker checks) because
    DSML markers like <|DSML|function_calls> may span multiple tokens.
    Also handles <think>...</think> blocks for thinking mode.

    Yields:
        ``None`` items unchanged; ``GenerationResponse`` items for ordinary and
        thinking text; a ``ToolCallResponse`` for each DSML block that parses.
        A block that fails to parse, or is cut off by a response carrying a
        ``finish_reason``, is yielded back as plain text.
    """
    from exo.worker.engines.mlx.dsml_encoding import (
        THINKING_END,
        THINKING_START,
        TOOL_CALLS_END,
        TOOL_CALLS_START,
        parse_dsml_output,
    )

    def _finish_tool_call(
        tool_call_text: str, response: GenerationResponse
    ) -> GenerationResponse | ToolCallResponse:
        # Turn a complete DSML block into a tool call, or fall back to raw text.
        parsed = parse_dsml_output(tool_call_text)
        if parsed is not None:
            logger.info(f"parsed DSML tool calls: {parsed}")
            return ToolCallResponse(
                tool_calls=parsed,
                usage=response.usage,
                stats=response.stats,
            )
        logger.warning(f"DSML tool call parsing failed for: {tool_call_text}")
        return response.model_copy(update={"text": tool_call_text})

    accumulated = ""
    in_tool_call = False
    thinking = False
    # Tokens buffered while we detect the start of a DSML block
    pending_buffer: list[GenerationResponse] = []
    # Text accumulated during a tool call block
    tool_call_text = ""

    for response in responses:
        if response is None:
            yield None
            continue

        # ── Handle thinking tags ──
        if not thinking and THINKING_START in response.text:
            thinking = True
            # Yield any text before the <think> tag
            before = response.text[: response.text.index(THINKING_START)]
            if before:
                yield response.model_copy(update={"text": before})
            continue

        if thinking and THINKING_END in response.text:
            thinking = False
            # Yield any text after the </think> tag
            after = response.text[
                response.text.index(THINKING_END) + len(THINKING_END) :
            ]
            if after:
                yield response.model_copy(update={"text": after, "is_thinking": False})
            continue

        if thinking:
            yield response.model_copy(update={"is_thinking": True})
            continue

        # ── Handle tool call accumulation ──
        if in_tool_call:
            tool_call_text += response.text
            if TOOL_CALLS_END in tool_call_text:
                # Parse the accumulated DSML block
                yield _finish_tool_call(tool_call_text, response)
                in_tool_call = False
                tool_call_text = ""
                continue

            # EOS reached before end marker — yield buffered text as-is
            if response.finish_reason is not None:
                logger.info("DSML tool call parsing interrupted by EOS")
                yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
            continue

        # ── Detect start of tool call block ──
        accumulated += response.text

        if TOOL_CALLS_START in accumulated:
            # The start marker might be split across pending_buffer + current token
            start_idx = accumulated.index(TOOL_CALLS_START)
            # `accumulated` is the buffered tokens' text followed by the current
            # token's text, so pre_text is consumed in that same order.
            pre_text = accumulated[:start_idx]
            for buf_resp in pending_buffer:
                if not pre_text:
                    break
                chunk = buf_resp.text
                if len(chunk) <= len(pre_text):
                    yield buf_resp
                    pre_text = pre_text[len(chunk) :]
                else:
                    yield buf_resp.model_copy(update={"text": pre_text})
                    pre_text = ""
            if pre_text:
                # Whatever is left came from the current token, ahead of the
                # marker; without this it would be dropped. finish_reason is
                # cleared so this text chunk does not end the stream early.
                yield response.model_copy(
                    update={"text": pre_text, "finish_reason": None}
                )
            pending_buffer = []
            tool_call_text = accumulated[start_idx:]
            accumulated = ""

            # Check if the end marker is already present (entire tool call in one token)
            if TOOL_CALLS_END in tool_call_text:
                yield _finish_tool_call(tool_call_text, response)
                tool_call_text = ""
            else:
                in_tool_call = True
            continue

        # Check if accumulated text might be the start of a DSML marker
        # Buffer tokens if we see a partial match at the end
        if _could_be_dsml_prefix(accumulated):
            pending_buffer.append(response)
            continue

        # No partial match — flush all pending tokens and the current one
        for buf_resp in pending_buffer:
            yield buf_resp
        pending_buffer = []
        accumulated = ""
        yield response

    # Flush any remaining pending buffer at generator end
    for buf_resp in pending_buffer:
        yield buf_resp
|
||||
|
||||
|
||||
def _could_be_dsml_prefix(text: str) -> bool:
    """Tell whether ``text`` ends with a prefix of the TOOL_CALLS_START marker.

    True when some non-empty suffix of ``text`` is a prefix of
    ``TOOL_CALLS_START``; the caller uses this to hold tokens back until it is
    clear whether a tool-call block is starting. Empty ``text`` gives False.
    """
    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START

    # A suffix longer than the marker can never be a prefix of it, so only the
    # last len(marker) characters need checking.
    window = text[-len(TOOL_CALLS_START) :]
    return any(
        TOOL_CALLS_START.startswith(window[start:]) for start in range(len(window))
    )
|
||||
|
||||
|
||||
def parse_thinking_models(
    responses: Generator[GenerationResponse | None],
    tokenizer: TokenizerWrapper,
    starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
    """Tag each response with ``is_thinking`` based on think start/end tokens.

    A response whose text equals ``tokenizer.think_start`` or
    ``tokenizer.think_end`` switches the thinking state and is not forwarded,
    unless it carries a ``finish_reason``; in that case it is forwarded with
    empty text so the chunk stream can still terminate. ``None`` and
    ``ToolCallResponse`` items pass through unchanged.
    """
    in_thinking = starts_in_thinking
    for response in responses:
        if response is None:
            yield None
            continue
        if isinstance(response, ToolCallResponse):
            yield response
            continue

        closes_thinking = (
            tokenizer.think_end is not None and response.text == tokenizer.think_end
        )
        opens_thinking = (
            tokenizer.think_start is not None and response.text == tokenizer.think_start
        )

        if not (closes_thinking or opens_thinking):
            yield response.model_copy(update={"is_thinking": in_thinking})
            continue

        in_thinking = response.text != tokenizer.think_end
        # Never swallow finish_reason — the chunk stream needs it to terminate.
        if response.finish_reason is not None:
            yield response.model_copy(update={"text": "", "is_thinking": False})
|
||||
|
||||
|
||||
def parse_tool_calls(
    responses: Generator[GenerationResponse | None], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Replace tool-call text in the response stream with ``ToolCallResponse``s.

    A tool call starts at a response whose text starts with
    ``tool_parser.start_parsing`` and ends at one whose text ends with
    ``tool_parser.end_parsing``. Responses in between are held back and their
    text is joined and handed to ``tool_parser.parse_tool_calls``. If parsing
    fails, or a response with a ``finish_reason`` arrives before the end
    marker, the collected text is yielded as an ordinary response instead.
    ``None`` items and responses outside a tool call pass through unchanged.
    """
    in_tool_call = False
    tool_call_text_parts: list[str] = []
    for response in responses:
        if response is None:
            yield None
            continue
        if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
            in_tool_call = True

        if in_tool_call:
            tool_call_text_parts.append(response.text)
            if response.text.endswith(tool_parser.end_parsing):
                # parse the actual tool calls from the tool call text
                parsed = tool_parser.parse_tool_calls(
                    "".join(tool_call_text_parts).strip()
                )
                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
                if parsed is not None:
                    yield ToolCallResponse(
                        tool_calls=parsed, usage=response.usage, stats=response.stats
                    )
                else:
                    logger.warning(
                        f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
                    )
                    # Yield a copy rather than mutating the caller's response
                    # object, matching the other fallback paths.
                    yield response.model_copy(
                        update={"text": "".join(tool_call_text_parts)}
                    )

                in_tool_call = False
                tool_call_text_parts = []
                continue

            if response.finish_reason is not None:
                logger.info(
                    "tool call parsing interrupted, yield partial tool call as text"
                )
                response = response.model_copy(
                    update={
                        "text": "".join(tool_call_text_parts),
                        "token": 0,
                    }
                )
                yield response

        else:
            # fallthrough
            yield response
|
||||
File diff suppressed because it is too large
Load Diff
@@ -172,7 +172,7 @@ class RunnerSupervisor:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending[event.task_id].set()
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
@@ -190,7 +190,6 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
self.pending.pop(event.task_id, None)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
pending: dict[TaskId, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
cache_length,
|
||||
@@ -143,14 +143,7 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
@@ -171,14 +164,7 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -214,14 +200,7 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
short_tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -266,14 +245,7 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -313,14 +285,7 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -548,16 +513,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
@@ -582,16 +538,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
|
||||
@@ -1,512 +0,0 @@
|
||||
# type: ignore
|
||||
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
|
||||
|
||||
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
|
||||
then verifies that the prompt_progress_callback sequences are identical
|
||||
and that generated text matches.
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
|
||||
TOTAL_LAYERS = 24
|
||||
MAX_TOKENS = 10
|
||||
SEED = 42
|
||||
TEMPERATURE = 0.0
|
||||
|
||||
|
||||
def _model_card() -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(MODEL_ID),
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=TOTAL_LAYERS,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
|
||||
"""Build a prompt with the given number of user-content tokens, return (chat_prompt, task)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_toks = tokenizer.encode(base_text)
|
||||
repeats = (prompt_tokens // len(base_toks)) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
task = TextGenerationTaskParams(
|
||||
model=MODEL_ID,
|
||||
input=[InputMessage(role="user", content=prompt_text)],
|
||||
max_output_tokens=MAX_TOKENS,
|
||||
temperature=TEMPERATURE,
|
||||
seed=SEED,
|
||||
)
|
||||
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
return prompt, task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-device process: uses stream_generate path (no pipeline layers)
|
||||
# ---------------------------------------------------------------------------
|
||||
def _run_single_device(
|
||||
prompt_tokens: int,
|
||||
result_queue: Any,
|
||||
) -> None:
|
||||
"""Load full model without pipeline sharding, run mlx_generate, record callbacks."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load_model
|
||||
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx.cache import encode_prompt
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
build_model_path,
|
||||
get_tokenizer,
|
||||
)
|
||||
|
||||
model_path = build_model_path(ModelId(MODEL_ID))
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
mx.eval(model)
|
||||
|
||||
# Use PipelineShardMetadata just for get_tokenizer (needs model_card), but
|
||||
# do NOT apply pipeline sharding — the model keeps all layers unwrapped.
|
||||
dummy_meta = PipelineShardMetadata(
|
||||
model_card=_model_card(),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=TOTAL_LAYERS,
|
||||
n_layers=TOTAL_LAYERS,
|
||||
)
|
||||
tokenizer = get_tokenizer(model_path, dummy_meta)
|
||||
|
||||
prompt, task = _build_prompt(tokenizer, prompt_tokens)
|
||||
|
||||
callbacks: list[tuple[int, int]] = []
|
||||
|
||||
def on_progress(processed: int, total: int) -> None:
|
||||
callbacks.append((processed, total))
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=None,
|
||||
on_prefill_progress=on_progress,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
# Also record the token count that prefill() received (prompt_tokens[:-1])
|
||||
all_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefill_token_count = len(all_tokens) - 1
|
||||
|
||||
result_queue.put(
|
||||
(
|
||||
True,
|
||||
{
|
||||
"callbacks": callbacks,
|
||||
"text": generated_text,
|
||||
"prefill_token_count": prefill_token_count,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline device process: uses _pipeline_prefill_cache path
|
||||
# ---------------------------------------------------------------------------
|
||||
def _run_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
result_queue: Any,
|
||||
) -> None:
|
||||
"""Load model with pipeline sharding, run mlx_generate, record callbacks."""
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
try:
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx.cache import encode_prompt
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_card=_model_card(),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=TOTAL_LAYERS,
|
||||
)
|
||||
|
||||
model, tokenizer = shard_and_load(
|
||||
shard_meta, group, on_timeout=None, on_layer_loaded=None
|
||||
)
|
||||
model = cast(Any, model)
|
||||
|
||||
prompt, task = _build_prompt(tokenizer, prompt_tokens)
|
||||
|
||||
callbacks: list[tuple[int, int]] = []
|
||||
|
||||
def on_progress(processed: int, total: int) -> None:
|
||||
callbacks.append((processed, total))
|
||||
|
||||
def distributed_prompt_progress_callback(_group: Any = group) -> None:
|
||||
from exo.worker.engines.mlx.utils_mlx import mx_any
|
||||
|
||||
mx_any(False, _group)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=None,
|
||||
group=group,
|
||||
on_prefill_progress=on_progress,
|
||||
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
all_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefill_token_count = len(all_tokens) - 1
|
||||
|
||||
result_queue.put(
|
||||
(
|
||||
rank,
|
||||
True,
|
||||
{
|
||||
"callbacks": callbacks,
|
||||
"text": generated_text,
|
||||
"prefill_token_count": prefill_token_count,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _create_hostfile(world_size: int, base_port: int) -> str:
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
return f.name
|
||||
|
||||
|
||||
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
|
||||
"""Run single-device (stream_generate) prefill and return results."""
|
||||
ctx = mp.get_context("spawn")
|
||||
result_queue: Any = ctx.Queue()
|
||||
|
||||
p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))
|
||||
p.start()
|
||||
p.join(timeout=timeout)
|
||||
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=5)
|
||||
pytest.fail("Single-device process timed out")
|
||||
|
||||
assert not result_queue.empty(), "Single-device process produced no result"
|
||||
success, data = result_queue.get()
|
||||
assert success, f"Single-device process failed:\n{data}"
|
||||
return data
|
||||
|
||||
|
||||
def _run_pipeline_test(
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
base_port: int,
|
||||
timeout: int = 120,
|
||||
) -> dict[int, dict[str, Any]]:
|
||||
"""Run pipeline prefill across ranks and return per-rank results."""
|
||||
world_size = len(layer_splits)
|
||||
hostfile_path = _create_hostfile(world_size, base_port)
|
||||
ctx = mp.get_context("spawn")
|
||||
result_queue: Any = ctx.Queue()
|
||||
|
||||
try:
|
||||
processes: list[Any] = []
|
||||
for rank in range(world_size):
|
||||
p = ctx.Process(
|
||||
target=_run_pipeline_device,
|
||||
args=(
|
||||
rank,
|
||||
world_size,
|
||||
hostfile_path,
|
||||
layer_splits,
|
||||
prompt_tokens,
|
||||
result_queue,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join(timeout=timeout)
|
||||
|
||||
timed_out = any(p.is_alive() for p in processes)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=5)
|
||||
|
||||
assert not timed_out, "Pipeline processes timed out"
|
||||
|
||||
results: dict[int, dict[str, Any]] = {}
|
||||
while not result_queue.empty():
|
||||
rank, success, data = result_queue.get()
|
||||
assert success, f"Pipeline rank {rank} failed:\n{data}"
|
||||
results[rank] = data
|
||||
|
||||
assert len(results) == world_size, (
|
||||
f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
|
||||
)
|
||||
return results
|
||||
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not MODEL_PATH.exists(),
|
||||
reason=f"GPT-OSS model not found at {MODEL_PATH}",
|
||||
),
|
||||
]
|
||||
|
||||
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
|
||||
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
|
||||
|
||||
|
||||
class TestPipelineNoDeadlock:
|
||||
"""Pipeline prefill must not deadlock at any rank count or prompt length."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"layer_splits,prompt_tokens",
|
||||
[
|
||||
(LAYER_SPLITS_2WAY, 128),
|
||||
(LAYER_SPLITS_2WAY, 4096),
|
||||
(LAYER_SPLITS_2WAY, 8192),
|
||||
(LAYER_SPLITS_2WAY, 16384),
|
||||
(LAYER_SPLITS_4WAY, 128),
|
||||
(LAYER_SPLITS_4WAY, 4096),
|
||||
(LAYER_SPLITS_4WAY, 8192),
|
||||
(LAYER_SPLITS_4WAY, 16384),
|
||||
],
|
||||
ids=[
|
||||
"2rank_128tok",
|
||||
"2rank_4096tok",
|
||||
"2rank_8192tok",
|
||||
"2rank_16384tok",
|
||||
"4rank_128tok",
|
||||
"4rank_4096tok",
|
||||
"4rank_8192tok",
|
||||
"4rank_16384tok",
|
||||
],
|
||||
)
|
||||
def test_no_deadlock(
|
||||
self,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
"""Pipeline must complete without deadlock at various prompt lengths."""
|
||||
pipeline_results = _run_pipeline_test(
|
||||
layer_splits=layer_splits,
|
||||
prompt_tokens=prompt_tokens,
|
||||
base_port=29650,
|
||||
timeout=60,
|
||||
)
|
||||
# If we get here, no deadlock. Verify all ranks produced output.
|
||||
for rank, pipe_data in sorted(pipeline_results.items()):
|
||||
assert pipe_data["text"], f"Rank {rank} produced no output text"
|
||||
|
||||
|
||||
class TestPipelinePrefillCallbacks:
|
||||
"""Verify that pipeline prefill callbacks exactly match stream_generate callbacks."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt_tokens",
|
||||
[50, 500, 5000],
|
||||
ids=["short_50", "medium_500", "long_5000"],
|
||||
)
|
||||
def test_callbacks_match(self, prompt_tokens: int) -> None:
|
||||
"""All pipeline ranks must produce identical callback sequences."""
|
||||
# Run 4-rank pipeline
|
||||
pipeline_results = _run_pipeline_test(
|
||||
layer_splits=LAYER_SPLITS_4WAY,
|
||||
prompt_tokens=prompt_tokens,
|
||||
base_port=29700,
|
||||
timeout=180,
|
||||
)
|
||||
|
||||
# All ranks must agree on prefill token count and callback sequence
|
||||
rank0_data = pipeline_results[0]
|
||||
rank0_callbacks = rank0_data["callbacks"]
|
||||
prefill_count = rank0_data["prefill_token_count"]
|
||||
|
||||
for rank, pipe_data in sorted(pipeline_results.items()):
|
||||
pipe_callbacks = pipe_data["callbacks"]
|
||||
|
||||
assert pipe_data["prefill_token_count"] == prefill_count, (
|
||||
f"Rank {rank} prefill token count mismatch: "
|
||||
f"{pipe_data['prefill_token_count']} vs {prefill_count}"
|
||||
)
|
||||
|
||||
assert pipe_callbacks == rank0_callbacks, (
|
||||
f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
|
||||
f"(prefill M={prefill_count}):\n"
|
||||
f" pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\n"
|
||||
f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
|
||||
)
|
||||
|
||||
# Structural checks: starts with (0, M), ends with (M, M), monotonically increasing
|
||||
assert rank0_callbacks[0] == (0, prefill_count), (
|
||||
f"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}"
|
||||
)
|
||||
assert rank0_callbacks[-1] == (prefill_count, prefill_count), (
|
||||
f"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}"
|
||||
)
|
||||
for i in range(1, len(rank0_callbacks)):
|
||||
assert rank0_callbacks[i][0] >= rank0_callbacks[i - 1][0], (
|
||||
f"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt_tokens",
|
||||
[50, 500],
|
||||
ids=["short_50", "medium_500"],
|
||||
)
|
||||
def test_output_matches(self, prompt_tokens: int) -> None:
|
||||
"""Pipeline-generated text must match single-device output."""
|
||||
single = _run_single_device_test(prompt_tokens, timeout=180)
|
||||
|
||||
pipeline_results = _run_pipeline_test(
|
||||
layer_splits=LAYER_SPLITS_4WAY,
|
||||
prompt_tokens=prompt_tokens,
|
||||
base_port=29800,
|
||||
timeout=180,
|
||||
)
|
||||
|
||||
single_text = single["text"]
|
||||
|
||||
# The last rank produces the final logits, so its output should match.
|
||||
# Due to SDPA tiling non-determinism, allow minor differences in text.
|
||||
last_rank = max(pipeline_results.keys())
|
||||
pipe_text = pipeline_results[last_rank]["text"]
|
||||
|
||||
# For deterministic sampling (temp=0.0), outputs should match exactly
|
||||
# or be very close. Log both for debugging even if they match.
|
||||
if single_text != pipe_text:
|
||||
# Find first divergence point
|
||||
min_len = min(len(single_text), len(pipe_text))
|
||||
diverge_idx = next(
|
||||
(i for i in range(min_len) if single_text[i] != pipe_text[i]),
|
||||
min_len,
|
||||
)
|
||||
pytest.fail(
|
||||
f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
|
||||
f" single-device: {single_text!r}\n"
|
||||
f" pipeline R{last_rank}: {pipe_text!r}"
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineCallbacksStructure:
|
||||
"""Verify structural properties of callbacks independent of model output."""
|
||||
|
||||
def test_callback_structure_matches_generate_step(self) -> None:
|
||||
"""Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M)."""
|
||||
prompt_tokens = 200
|
||||
pipeline_results = _run_pipeline_test(
|
||||
layer_splits=LAYER_SPLITS_4WAY,
|
||||
prompt_tokens=prompt_tokens,
|
||||
base_port=29900,
|
||||
timeout=180,
|
||||
)
|
||||
|
||||
for rank, pipe_data in sorted(pipeline_results.items()):
|
||||
callbacks = pipe_data["callbacks"]
|
||||
m = pipe_data["prefill_token_count"]
|
||||
assert m > 0, f"Rank {rank}: prefill token count is 0"
|
||||
|
||||
assert callbacks[0] == (0, m), (
|
||||
f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
|
||||
)
|
||||
|
||||
assert callbacks[-1] == (m, m), (
|
||||
f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
|
||||
)
|
||||
|
||||
if len(callbacks) > 2:
|
||||
second_to_last = callbacks[-2]
|
||||
assert second_to_last[0] < m, (
|
||||
f"Rank {rank}: second-to-last callback should report < {m}, "
|
||||
f"got {second_to_last}"
|
||||
)
|
||||
|
||||
# All callbacks must have total == M
|
||||
for i, (_, total) in enumerate(callbacks):
|
||||
assert total == m, (
|
||||
f"Rank {rank}: callback {i} has total={total}, expected {m}"
|
||||
)
|
||||
|
||||
# processed values must be non-decreasing
|
||||
processed_vals = [p for p, _ in callbacks]
|
||||
for i in range(1, len(processed_vals)):
|
||||
assert processed_vals[i] >= processed_vals[i - 1], (
|
||||
f"Rank {rank}: callbacks not non-decreasing at index {i}: "
|
||||
f"{processed_vals}"
|
||||
)
|
||||
|
||||
# No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)
|
||||
for i in range(1, len(callbacks)):
|
||||
assert callbacks[i] != callbacks[i - 1], (
|
||||
f"Rank {rank}: duplicate consecutive callback at index {i}: "
|
||||
f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
|
||||
)
|
||||
@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
|
||||
@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
|
||||
encode_messages,
|
||||
parse_dsml_output,
|
||||
)
|
||||
from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
|
||||
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
|
||||
|
||||
# ── Shared fixtures ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Callable
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
|
||||
import exo.worker.runner.llm_inference.runner as mlx_runner
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
@@ -116,20 +115,17 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
monkeypatch.setattr(
|
||||
mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
|
||||
)
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||
|
||||
monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
|
||||
# Use a fake event_sender to remove test flakiness.
|
||||
@@ -187,13 +183,12 @@ def _run(tasks: Iterable[Task]):
|
||||
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
|
||||
make_nothin(mx.array([1])),
|
||||
):
|
||||
runner = mlx_runner.Runner(
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
runner.main()
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
|
||||
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
|
||||
|
||||
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
|
||||
# These are stable since they come from the model's vocabulary.
|
||||
@@ -107,7 +107,7 @@ def _collect(
|
||||
def _gen() -> Generator[GenerationResponse, None, None]:
|
||||
yield from _make_gen_responses(tokens)
|
||||
|
||||
return list(x for x in parse_gpt_oss(_gen()) if x is not None)
|
||||
return list(parse_gpt_oss(_gen()))
|
||||
|
||||
|
||||
def _get_tool_call(
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
|
||||
from exo.worker.runner.llm_inference.runner import parse_tool_calls
|
||||
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user