mirror of https://github.com/exo-explore/exo.git
synced 2026-01-27 15:33:26 -05:00

Compare commits: main...rust-explo (1 commit, 3029dc6653)

Cargo.lock (generated, 24 changed lines)
@@ -514,6 +514,20 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
 
+[[package]]
+name = "cluster_membership"
+version = "0.0.1"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "futures-lite",
+ "futures-timer",
+ "libp2p",
+ "log",
+ "tokio",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "colorchoice"
 version = "1.0.4"

@@ -998,6 +1012,7 @@ dependencies = [
 name = "exo_pyo3_bindings"
 version = "0.0.1"
 dependencies = [
+ "cluster_membership",
 "delegate",
 "derive_more",
 "env_logger",

@@ -1030,6 +1045,12 @@ dependencies = [
 "syn 2.0.111",
 ]
 
+[[package]]
+name = "fastrand"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
+
 [[package]]
 name = "ff"
 version = "0.13.1"

@@ -1138,7 +1159,10 @@ version = "2.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
 dependencies = [
+ "fastrand",
 "futures-core",
+ "futures-io",
+ "parking",
 "pin-project-lite",
 ]
 
Workspace Cargo.toml:

@@ -4,6 +4,7 @@ members = [
     "rust/networking",
     "rust/exo_pyo3_bindings",
     "rust/util",
+    "rust/cluster_membership",
 ]
 
 [workspace.package]

@@ -25,6 +26,7 @@ opt-level = 3
 ## Crate members as common dependencies
 networking = { path = "rust/networking" }
 util = { path = "rust/util" }
+cluster_membership = { path = "rust/cluster_membership" }
 
 # Proc-macro authoring tools
 syn = "2.0"

@@ -62,6 +64,7 @@ frunk-enum-core = "0.3"
 # Async dependencies
 tokio = "1.46"
 futures = "0.3"
+futures-lite = "2.6.1"
 futures-util = "0.3"
 futures-timer = "3.0"
 
justfile (2 changed lines)

@@ -1,7 +1,7 @@
 export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
 
 fmt:
-    treefmt || nix fmt
+    nix fmt
 
 lint:
     uv run ruff check --fix
pyproject.toml:

@@ -19,7 +19,7 @@ dependencies = [
     "anyio==4.11.0",
     "mlx==0.30.3; sys_platform == 'darwin'",
     "mlx[cpu]==0.30.3; sys_platform == 'linux'",
-    "mlx-lm==0.30.5",
+    "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
rust/cluster_membership/Cargo.toml (new file, 23 lines)

[package]
name = "cluster_membership"
version.workspace = true
edition.workspace = true
publish = false

[dependencies]
# util
anyhow.workspace = true
log.workspace = true
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }

# async
tokio = { workspace = true, features = ["full"] }
futures-timer = { workspace = true }
futures-lite = "2.6.1"

# networking
libp2p = { workspace = true, features = ["full"] }
async-trait = "0.1.89"

[lints]
workspace = true
rust/cluster_membership/examples/chatroom.rs (new file, 30 lines)

use cluster_membership::Peer;
use libp2p::identity::ed25519::SecretKey;
use tokio::io::{self, AsyncBufReadExt};
use tracing_subscriber::{EnvFilter, filter::LevelFilter};

#[tokio::main]
async fn main() {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
        .try_init();

    let (mut peer, send, mut recv) =
        Peer::new(SecretKey::generate(), "hello".to_string()).expect("peer should always build");

    let ch = peer.subscribe("chatroom".to_string());
    let jh = tokio::spawn(async move { peer.run().await });

    let mut stdin = io::BufReader::new(io::stdin()).lines();
    loop {
        tokio::select! {
            Ok(Some(line)) = stdin.next_line() => {send.send((ch.clone(), line.into_bytes())).await.expect("example");}
            Some(r) = recv.recv() => match r {
                Ok((_, id, line)) => println!("{:?}:{:?}", id, String::from_utf8_lossy(&line)),
                Err(e) => eprintln!("{e:?}"),
            },
            else => break
        }
    }
    jh.await.expect("task failure");
}
rust/cluster_membership/src/lib.rs (new file, 227 lines)

use libp2p::{
    Multiaddr, PeerId, Swarm, SwarmBuilder,
    futures::StreamExt,
    gossipsub::{self, PublishError, Sha256Topic, TopicHash},
    identify,
    identity::{Keypair, ed25519},
    mdns,
    swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
};
use std::{
    collections::HashMap,
    time::{Duration, Instant},
};
use tokio::{select, sync::mpsc};

const DEFAULT_BUFFER_SIZE: usize = 10;
const MDNS_IGNORE_DURATION_SECS: u64 = 30;

impl Peer {
    pub fn new(
        identity: ed25519::SecretKey,
        namespace: String,
    ) -> anyhow::Result<(
        Self,
        mpsc::Sender<(TopicHash, Vec<u8>)>,
        mpsc::Receiver<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
    )> {
        let mut id_bytes = identity.as_ref().to_vec();

        let mut swarm =
            SwarmBuilder::with_existing_identity(Keypair::ed25519_from_bytes(&mut id_bytes)?)
                .with_tokio()
                .with_quic()
                // TODO(evan): .with_bandwidth_metrics();
                .with_behaviour(|kp| Behaviour::new(kp, namespace.clone()))?
                .build();

        swarm.listen_on("/ip6/::/udp/0/quic-v1".parse()?)?;
        swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?;
        let (to_swarm, from_client) = mpsc::channel(DEFAULT_BUFFER_SIZE);
        let (to_client, from_swarm) = mpsc::channel(DEFAULT_BUFFER_SIZE);
        Ok((
            Self {
                swarm,
                namespace,
                denied: HashMap::new(),
                from_client,
                to_client,
            },
            to_swarm,
            from_swarm,
        ))
    }

    pub fn subscribe(&mut self, topic: String) -> TopicHash {
        let topic = Sha256Topic::new(topic);
        self.swarm
            .behaviour_mut()
            .gossipsub
            .subscribe(&topic)
            .expect("topic filtered");
        topic.hash()
    }

    pub async fn run(&mut self) {
        loop {
            select! {
                ev = self.swarm.select_next_some() => {
                    let Ok(()) = self.handle_swarm_event(ev).await else {
                        return
                    };
                },
                Some(msg) = self.from_client.recv() => {
                    if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish(msg.0, msg.1) {
                        let Ok(()) = self.to_client.send(Err(e)).await else {
                            return
                        };
                    }
                },
            }
        }
    }

    async fn handle_swarm_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
        let SwarmEvent::Behaviour(event) = event else {
            if let SwarmEvent::NewListenAddr {
                listener_id: _,
                address,
            } = event
            {
                log::info!("new listen address {address}")
            }
            return Ok(());
        };
        match event {
            BehaviourEvent::Mdns(mdns_event) => match mdns_event {
                mdns::Event::Discovered(vec) => {
                    // Dial everyone
                    let mut addrs = HashMap::<PeerId, Vec<Multiaddr>>::new();
                    vec.into_iter()
                        .filter(|(peer_id, _)| {
                            self.denied.get(peer_id).is_none_or(|t| {
                                t.elapsed() > Duration::from_secs(MDNS_IGNORE_DURATION_SECS)
                            })
                        })
                        .for_each(|(peer_id, addr)| addrs.entry(peer_id).or_default().push(addr));
                    addrs.into_iter().for_each(|(peer_id, addrs)| {
                        let _ = self
                            .swarm
                            .dial(DialOpts::peer_id(peer_id).addresses(addrs).build());
                    });
                }
                mdns::Event::Expired(vec) => {
                    vec.iter().for_each(|(peer_id, _)| {
                        log::debug!("{peer_id} no longer reachable on mDNS");
                        self.swarm
                            .behaviour_mut()
                            .gossipsub
                            .remove_explicit_peer(peer_id);
                    });
                }
            },
            BehaviourEvent::Identify(identify::Event::Received {
                connection_id: _,
                peer_id,
                info,
            }) => {
                if info
                    .protocols
                    .iter()
                    .any(|p| p.as_ref().contains(&self.namespace))
                {
                    self.passed_namespace(peer_id);
                } else {
                    self.failed_namespace(peer_id);
                }
            }
            BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                propagation_source: _,
                message_id: _,
                message:
                    gossipsub::Message {
                        topic,
                        data,
                        source: Some(source_peer),
                        ..
                    },
            }) => {
                let Ok(()) = self.to_client.send(Ok((topic, source_peer, data))).await else {
                    return Err(());
                };
            }
            _ => {}
        }
        Ok(())
    }

    fn passed_namespace(&mut self, peer: PeerId) {
        log::info!("new peer {peer:?}");
        self.denied.remove(&peer);
        self.swarm
            .behaviour_mut()
            .gossipsub
            .remove_blacklisted_peer(&peer);
        self.swarm
            .behaviour_mut()
            .gossipsub
            .add_explicit_peer(&peer);
    }

    fn failed_namespace(&mut self, peer: PeerId) {
        log::debug!("{peer} failed handshake");
        self.denied.insert(peer, Instant::now());
        self.swarm.behaviour_mut().gossipsub.blacklist_peer(&peer);
        // we don't care if disconnect fails
        let _ = self.swarm.disconnect_peer_id(peer);
    }
}

pub struct Peer {
    pub swarm: Swarm<Behaviour>,
    denied: HashMap<PeerId, Instant>,
    namespace: String,
    to_client: mpsc::Sender<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
    from_client: mpsc::Receiver<(TopicHash, Vec<u8>)>,
}

#[test]
fn foo() {
    fn bar<T: Send>(t: T) {}
    let p: Peer = unimplemented!();
    bar(p);
}

#[derive(NetworkBehaviour)]
pub struct Behaviour {
    mdns: mdns::tokio::Behaviour,
    pub gossipsub: gossipsub::Behaviour,
    identify: identify::Behaviour,
}

impl Behaviour {
    fn new(kp: &Keypair, namespace: String) -> Self {
        let mdns = mdns::tokio::Behaviour::new(Default::default(), kp.public().to_peer_id())
            .expect("implementation is infallible");
        let gossipsub = gossipsub::Behaviour::new(
            gossipsub::MessageAuthenticity::Signed(kp.clone()),
            gossipsub::ConfigBuilder::default()
                .max_transmit_size(1024 * 1024)
                .protocol_id_prefix(format!("/exo/gossip/{namespace}/v1"))
                .build()
                .expect("fixed gossipsub config should always build"),
        )
        .expect("fixed gossipsub init should always build");

        let identify = identify::Behaviour::new(
            identify::Config::new_with_signed_peer_record(format!("/exo/identity/v1"), kp)
                .with_push_listen_addr_updates(true),
        );

        Behaviour {
            mdns,
            gossipsub,
            identify,
        }
    }
}
exo_pyo3_bindings Cargo.toml:

@@ -22,6 +22,7 @@ doc = false
 workspace = true
 
 [dependencies]
+cluster_membership.workspace = true
 networking = { workspace = true }
 
 # interop
exo_pyo3_bindings src/lib.rs:

@@ -6,3 +6,41 @@
 
 pub mod ident;
 pub mod multiaddr;
+
+use std::sync::Mutex;
+
+use cluster_membership::Peer;
+use libp2p::identity::ed25519::Keypair;
+use pyo3::prelude::*;
+use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
+
+#[gen_stub_pyclass]
+#[pyclass]
+#[derive(Clone)]
+pub struct PyKeypair(Keypair);
+
+#[gen_stub_pymethods]
+#[pymethods]
+impl PyKeypair {
+    #[staticmethod]
+    fn generate() -> Self {
+        Self(Keypair::generate())
+    }
+}
+
+#[gen_stub_pyclass]
+#[pyclass]
+pub struct PyPeer(Mutex<Peer>);
+
+#[gen_stub_pymethods]
+#[pymethods]
+impl PyPeer {
+    #[staticmethod]
+    fn init(kp: PyKeypair, namespace: String) -> PyResult<Self> {
+        Ok(PyPeer(Mutex::new(
+            Peer::new(kp.0.secret(), namespace)
+                .map_err(|e| e.pyerr())?
+                .0,
+        )))
+    }
+}
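The classes above are exposed to Python through pyo3. Below is a minimal sketch of how they might be driven from the Python side once the extension is built; only the PyKeypair.generate and PyPeer.init entry points come from the diff, while the module name exo_pyo3_bindings and the script itself are assumptions for illustration.

```python
# Hypothetical usage of the new bindings; the module name is an assumption.
from exo_pyo3_bindings import PyKeypair, PyPeer


def make_peer(namespace: str) -> PyPeer:
    # Generates a fresh ed25519 keypair on the Rust side.
    kp = PyKeypair.generate()
    # PyPeer.init wraps cluster_membership::Peer::new and raises on failure.
    return PyPeer.init(kp, namespace)


if __name__ == "__main__":
    peer = make_peer("hello")
    print(peer)
```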
exo/download/download_utils.py:

@@ -121,20 +121,11 @@ async def ensure_models_dir() -> Path:
 
 
 async def delete_model(model_id: ModelId) -> bool:
-    models_dir = await ensure_models_dir()
-    model_dir = models_dir / model_id.normalize()
-    cache_dir = models_dir / "caches" / model_id.normalize()
-
-    deleted = False
-    if await aios.path.exists(model_dir):
-        await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
-        deleted = True
-
-    # Also clear cache
-    if await aios.path.exists(cache_dir):
-        await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
-
-    return deleted
+    model_dir = await ensure_models_dir() / model_id.normalize()
+    if not await aios.path.exists(model_dir):
+        return False
+    await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
+    return True
 
 
 async def seed_models(seed_dir: str | Path):
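Since the commit also deletes the download test module further down, here is a hedged sketch of what a minimal replacement test for the simplified delete_model could look like. It borrows the fixtures and the EXO_MODELS_DIR patch target from the deleted tests; the async-test plumbing (pytest asyncio/anyio auto mode) is assumed rather than taken from the repo.

```python
# Sketch only, not part of the commit: mirrors the structure of the deleted tests.
from pathlib import Path
from unittest.mock import patch

import aiofiles.os as aios

from exo.download.download_utils import delete_model
from exo.shared.types.common import ModelId


async def test_delete_model_roundtrip(tmp_path: Path) -> None:
    model_id = ModelId("test-org/test-model")
    models_dir = tmp_path / "models"
    await aios.makedirs(models_dir / model_id.normalize(), exist_ok=True)

    with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
        assert await delete_model(model_id) is True   # existing dir is removed
        assert await delete_model(model_id) is False  # nothing left to remove
```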
@@ -160,28 +151,16 @@ async def fetch_file_list_with_cache(
     target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
     await aios.makedirs(target_dir, exist_ok=True)
     cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
-
-    # Always try fresh first
-    try:
-        file_list = await fetch_file_list_with_retry(
-            model_id, revision, recursive=recursive
-        )
-        # Update cache with fresh data
-        async with aiofiles.open(cache_file, "w") as f:
-            await f.write(
-                TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
-            )
-        return file_list
-    except Exception as e:
-        # Fetch failed - try cache fallback
-        if await aios.path.exists(cache_file):
-            logger.warning(
-                f"Failed to fetch file list for {model_id}, using cached data: {e}"
-            )
-            async with aiofiles.open(cache_file, "r") as f:
-                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
-        # No cache available, propagate the error
-        raise
+    if await aios.path.exists(cache_file):
+        async with aiofiles.open(cache_file, "r") as f:
+            return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
+    file_list = await fetch_file_list_with_retry(
+        model_id, revision, recursive=recursive
+    )
+    await aios.makedirs(cache_file.parent, exist_ok=True)
+    async with aiofiles.open(cache_file, "w") as f:
+        await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
+    return file_list
 
 
 async def fetch_file_list_with_retry(
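The cache file read and written above is just a pydantic TypeAdapter round-trip over the file list. A self-contained sketch of that mechanism, using a simplified stand-in for the project's FileListEntry model:

```python
# Standalone sketch of the TypeAdapter round-trip the cache relies on.
# `Entry` is a simplified stand-in for exo's FileListEntry.
from pydantic import BaseModel, TypeAdapter


class Entry(BaseModel):
    type: str
    path: str
    size: int


adapter = TypeAdapter(list[Entry])
file_list = [Entry(type="file", path="model.safetensors", size=1000)]

raw = adapter.dump_json(file_list).decode()     # what gets written to the cache file
assert adapter.validate_json(raw) == file_list  # what a cache hit reads back
```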
@@ -353,28 +332,8 @@ async def _download_file(
     target_dir: Path,
     on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
-    target_path = target_dir / path
-
-    if await aios.path.exists(target_path):
-        local_size = (await aios.stat(target_path)).st_size
-
-        # Try to verify against remote, but allow offline operation
-        try:
-            remote_size, _ = await file_meta(model_id, revision, path)
-            if local_size != remote_size:
-                logger.info(
-                    f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
-                )
-                await aios.remove(target_path)
-            else:
-                return target_path
-        except Exception as e:
-            # Offline or network error - trust local file
-            logger.debug(
-                f"Could not verify {path} against remote (offline?): {e}, using local file"
-            )
-            return target_path
-
+    if await aios.path.exists(target_dir / path):
+        return target_dir / path
     await aios.makedirs((target_dir / path).parent, exist_ok=True)
     length, etag = await file_meta(model_id, revision, path)
     remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -583,26 +542,17 @@ async def download_shard(
     async def on_progress_wrapper(
         file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
     ) -> None:
-        previous_progress = file_progress.get(file.path)
-
-        # Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
-        is_redownload = (
-            previous_progress is not None
-            and curr_bytes < previous_progress.downloaded.in_bytes
-        )
-
-        if is_redownload or previous_progress is None:
-            # Fresh download or re-download: reset tracking
-            start_time = time.time()
-            downloaded_this_session = curr_bytes
-        else:
-            # Continuing download: accumulate
-            start_time = previous_progress.start_time
-            downloaded_this_session = (
-                previous_progress.downloaded_this_session.in_bytes
-                + (curr_bytes - previous_progress.downloaded.in_bytes)
-            )
-
+        start_time = (
+            file_progress[file.path].start_time
+            if file.path in file_progress
+            else time.time()
+        )
+        downloaded_this_session = (
+            file_progress[file.path].downloaded_this_session.in_bytes
+            + (curr_bytes - file_progress[file.path].downloaded.in_bytes)
+            if file.path in file_progress
+            else curr_bytes
+        )
         speed = (
             downloaded_this_session / (time.time() - start_time)
             if time.time() - start_time > 0
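The two conditional expressions on the new side reduce to the small helper below. The numbers are invented and a plain tuple stands in for the fields read from RepoFileDownloadProgress, so this is a sketch of the bookkeeping, not project code.

```python
import time


def session_progress(
    progress: dict[str, tuple[float, int, int]],  # path -> (start_time, downloaded, this_session)
    path: str,
    curr_bytes: int,
) -> tuple[float, int]:
    """Return (start_time, downloaded_this_session) as the reverted wrapper computes them."""
    if path in progress:
        start_time, downloaded, this_session = progress[path]
        # Continue the running session; unlike the removed code, a re-download
        # (curr_bytes < downloaded) is not treated specially.
        return start_time, this_session + (curr_bytes - downloaded)
    return time.time(), curr_bytes  # first callback for this file


start, session_bytes = session_progress({}, "model.safetensors", 100_000)
speed = session_bytes / (time.time() - start) if time.time() - start > 0 else 0
```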
Deleted test file (451 lines; tests for download verification and cache behavior):

"""Tests for download verification and cache behavior."""

import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter

from exo.download.download_utils import (
    delete_model,
    fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress


@pytest.fixture
def model_id() -> ModelId:
    return ModelId("test-org/test-model")


@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
    """Set up a temporary models directory for testing."""
    models_dir = tmp_path / "models"
    await aios.makedirs(models_dir, exist_ok=True)
    with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
        yield models_dir


class TestFileVerification:
    """Tests for file size verification in _download_file."""

    async def test_redownload_when_file_size_changes_upstream(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that files with mismatched sizes are re-downloaded."""
        # Import inside test to allow patching
        from exo.download.download_utils import (
            _download_file,  # pyright: ignore[reportPrivateUsage]
        )

        target_dir = tmp_path / "downloads"
        await aios.makedirs(target_dir, exist_ok=True)

        # Create a local file with wrong size
        local_file = target_dir / "test.safetensors"
        async with aiofiles.open(local_file, "wb") as f:
            await f.write(b"local content")  # 13 bytes

        remote_size = 1000  # Different from local
        remote_hash = "abc123"

        with (
            patch(
                "exo.download.download_utils.file_meta",
                new_callable=AsyncMock,
                return_value=(remote_size, remote_hash),
            ) as mock_file_meta,
            patch(
                "exo.download.download_utils.create_http_session"
            ) as mock_session_factory,
        ):
            # Set up mock HTTP response for re-download
            mock_response = MagicMock()
            mock_response.status = 200
            mock_response.content.read = AsyncMock(  # pyright: ignore[reportAny]
                side_effect=[b"x" * remote_size, b""]
            )

            mock_session = MagicMock()
            mock_session.get.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
                return_value=mock_response
            )
            mock_session.get.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]
                return_value=None
            )
            mock_session_factory.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
                return_value=mock_session
            )
            mock_session_factory.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]
                return_value=None
            )

            # Mock calc_hash to return the expected hash
            with patch(
                "exo.download.download_utils.calc_hash",
                new_callable=AsyncMock,
                return_value=remote_hash,
            ):
                await _download_file(model_id, "main", "test.safetensors", target_dir)

            # file_meta should be called twice: once for verification, once for download
            assert mock_file_meta.call_count == 2

    async def test_skip_download_when_file_size_matches(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that files with matching sizes are not re-downloaded."""
        from exo.download.download_utils import (
            _download_file,  # pyright: ignore[reportPrivateUsage]
        )

        target_dir = tmp_path / "downloads"
        await aios.makedirs(target_dir, exist_ok=True)

        # Create a local file
        local_file = target_dir / "test.safetensors"
        local_content = b"local content"
        async with aiofiles.open(local_file, "wb") as f:
            await f.write(local_content)

        remote_size = len(local_content)  # Same as local
        remote_hash = "abc123"

        with (
            patch(
                "exo.download.download_utils.file_meta",
                new_callable=AsyncMock,
                return_value=(remote_size, remote_hash),
            ) as mock_file_meta,
            patch(
                "exo.download.download_utils.create_http_session"
            ) as mock_session_factory,
        ):
            result = await _download_file(
                model_id, "main", "test.safetensors", target_dir
            )

            # Should return immediately without downloading
            assert result == local_file
            mock_file_meta.assert_called_once()
            mock_session_factory.assert_not_called()

    async def test_offline_fallback_uses_local_file(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that local files are used when network is unavailable."""
        from exo.download.download_utils import (
            _download_file,  # pyright: ignore[reportPrivateUsage]
        )

        target_dir = tmp_path / "downloads"
        await aios.makedirs(target_dir, exist_ok=True)

        # Create a local file
        local_file = target_dir / "test.safetensors"
        async with aiofiles.open(local_file, "wb") as f:
            await f.write(b"local content")

        with (
            patch(
                "exo.download.download_utils.file_meta",
                new_callable=AsyncMock,
                side_effect=Exception("Network error"),
            ),
            patch(
                "exo.download.download_utils.create_http_session"
            ) as mock_session_factory,
        ):
            result = await _download_file(
                model_id, "main", "test.safetensors", target_dir
            )

            # Should return local file without attempting download
            assert result == local_file
            mock_session_factory.assert_not_called()


class TestFileListCache:
    """Tests for file list caching behavior."""

    async def test_fetch_fresh_and_update_cache(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that fresh data is fetched and cache is updated."""
        models_dir = tmp_path / "models"

        file_list = [
            FileListEntry(type="file", path="model.safetensors", size=1000),
            FileListEntry(type="file", path="config.json", size=100),
        ]

        with (
            patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
            patch(
                "exo.download.download_utils.fetch_file_list_with_retry",
                new_callable=AsyncMock,
                return_value=file_list,
            ) as mock_fetch,
        ):
            result = await fetch_file_list_with_cache(model_id, "main")

            assert result == file_list
            mock_fetch.assert_called_once()

            # Verify cache was written
            cache_file = (
                models_dir
                / "caches"
                / model_id.normalize()
                / f"{model_id.normalize()}--main--file_list.json"
            )
            assert await aios.path.exists(cache_file)

            async with aiofiles.open(cache_file, "r") as f:
                cached_data = TypeAdapter(list[FileListEntry]).validate_json(
                    await f.read()
                )
            assert cached_data == file_list

    async def test_fallback_to_cache_when_fetch_fails(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that cached data is used when fetch fails."""
        models_dir = tmp_path / "models"
        cache_dir = models_dir / "caches" / model_id.normalize()
        await aios.makedirs(cache_dir, exist_ok=True)

        # Create cache file
        cached_file_list = [
            FileListEntry(type="file", path="model.safetensors", size=1000),
        ]
        cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(
                TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
            )

        with (
            patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
            patch(
                "exo.download.download_utils.fetch_file_list_with_retry",
                new_callable=AsyncMock,
                side_effect=Exception("Network error"),
            ),
        ):
            result = await fetch_file_list_with_cache(model_id, "main")

            assert result == cached_file_list

    async def test_error_propagates_when_no_cache(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that errors propagate when fetch fails and no cache exists."""
        models_dir = tmp_path / "models"

        with (
            patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
            patch(
                "exo.download.download_utils.fetch_file_list_with_retry",
                new_callable=AsyncMock,
                side_effect=Exception("Network error"),
            ),
            pytest.raises(Exception, match="Network error"),
        ):
            await fetch_file_list_with_cache(model_id, "main")


class TestModelDeletion:
    """Tests for model deletion including cache cleanup."""

    async def test_delete_model_clears_cache(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test that deleting a model also deletes its cache."""
        models_dir = tmp_path / "models"
        model_dir = models_dir / model_id.normalize()
        cache_dir = models_dir / "caches" / model_id.normalize()

        # Create model and cache directories
        await aios.makedirs(model_dir, exist_ok=True)
        await aios.makedirs(cache_dir, exist_ok=True)

        # Add some files
        async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
            await f.write("model data")
        async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
            await f.write("[]")

        with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
            result = await delete_model(model_id)

        assert result is True
        assert not await aios.path.exists(model_dir)
        assert not await aios.path.exists(cache_dir)

    async def test_delete_model_only_cache_exists(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test deleting when only cache exists (model already deleted)."""
        models_dir = tmp_path / "models"
        cache_dir = models_dir / "caches" / model_id.normalize()

        # Only create cache directory
        await aios.makedirs(cache_dir, exist_ok=True)
        async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
            await f.write("[]")

        with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
            result = await delete_model(model_id)

        # Returns False because model dir didn't exist
        assert result is False
        # But cache should still be cleaned up
        assert not await aios.path.exists(cache_dir)

    async def test_delete_nonexistent_model(
        self, model_id: ModelId, tmp_path: Path
    ) -> None:
        """Test deleting a model that doesn't exist."""
        models_dir = tmp_path / "models"
        await aios.makedirs(models_dir, exist_ok=True)

        with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
            result = await delete_model(model_id)

        assert result is False


class TestProgressResetOnRedownload:
    """Tests for progress tracking when files are re-downloaded."""

    async def test_progress_resets_correctly_on_redownload(
        self, model_id: ModelId
    ) -> None:
        """Test that progress tracking resets when a file is re-downloaded.

        When a file is deleted and re-downloaded (due to size mismatch),
        the progress tracking should reset rather than calculating negative
        downloaded_this_session values.
        """
        # Simulate file_progress dict as it exists in download_shard
        file_progress: dict[str, RepoFileDownloadProgress] = {}

        # Initialize with old file progress (simulating existing large file)
        old_file_size = 1_500_000_000  # 1.5 GB
        file_progress["model.safetensors"] = RepoFileDownloadProgress(
            repo_id=model_id,
            repo_revision="main",
            file_path="model.safetensors",
            downloaded=Memory.from_bytes(old_file_size),
            downloaded_this_session=Memory.from_bytes(0),
            total=Memory.from_bytes(old_file_size),
            speed=0,
            eta=timedelta(0),
            status="not_started",
            start_time=time.time() - 10,  # Started 10 seconds ago
        )

        # Simulate the logic from on_progress_wrapper after re-download starts
        # This is the exact logic from the fixed on_progress_wrapper
        curr_bytes = 100_000  # 100 KB - new download just started
        previous_progress = file_progress.get("model.safetensors")

        # Detect re-download: curr_bytes < previous downloaded
        is_redownload = (
            previous_progress is not None
            and curr_bytes < previous_progress.downloaded.in_bytes
        )

        if is_redownload or previous_progress is None:
            # Fresh download or re-download: reset tracking
            start_time = time.time()
            downloaded_this_session = curr_bytes
        else:
            # Continuing download: accumulate
            start_time = previous_progress.start_time
            downloaded_this_session = (
                previous_progress.downloaded_this_session.in_bytes
                + (curr_bytes - previous_progress.downloaded.in_bytes)
            )

        # Key assertions
        assert is_redownload is True, "Should detect re-download scenario"
        assert downloaded_this_session == curr_bytes, (
            "downloaded_this_session should equal curr_bytes on re-download"
        )
        assert downloaded_this_session > 0, (
            "downloaded_this_session should be positive, not negative"
        )

        # Calculate speed (should be positive)
        elapsed = time.time() - start_time
        speed = downloaded_this_session / elapsed if elapsed > 0 else 0
        assert speed >= 0, "Speed should be non-negative"

    async def test_progress_accumulates_on_continuing_download(
        self, model_id: ModelId
    ) -> None:
        """Test that progress accumulates correctly for continuing downloads.

        When a download continues from where it left off (resume),
        the progress should accumulate correctly.
        """
        file_progress: dict[str, RepoFileDownloadProgress] = {}

        # Initialize with partial download progress
        initial_downloaded = 500_000  # 500 KB already downloaded
        start_time = time.time() - 5  # Started 5 seconds ago
        file_progress["model.safetensors"] = RepoFileDownloadProgress(
            repo_id=model_id,
            repo_revision="main",
            file_path="model.safetensors",
            downloaded=Memory.from_bytes(initial_downloaded),
            downloaded_this_session=Memory.from_bytes(initial_downloaded),
            total=Memory.from_bytes(1_000_000),
            speed=100_000,
            eta=timedelta(seconds=5),
            status="in_progress",
            start_time=start_time,
        )

        # Progress callback with more bytes downloaded
        curr_bytes = 600_000  # 600 KB - continuing download
        previous_progress = file_progress.get("model.safetensors")

        # This is NOT a re-download (curr_bytes > previous downloaded)
        is_redownload = (
            previous_progress is not None
            and curr_bytes < previous_progress.downloaded.in_bytes
        )

        if is_redownload or previous_progress is None:
            downloaded_this_session = curr_bytes
            used_start_time = time.time()
        else:
            used_start_time = previous_progress.start_time
            downloaded_this_session = (
                previous_progress.downloaded_this_session.in_bytes
                + (curr_bytes - previous_progress.downloaded.in_bytes)
            )

        # Key assertions
        assert is_redownload is False, (
            "Should NOT detect re-download for continuing download"
        )
        assert used_start_time == start_time, "Should preserve original start_time"
        expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
        assert downloaded_this_session == expected_session, (
            f"Should accumulate: {downloaded_this_session} == {expected_session}"
        )
        assert downloaded_this_session == 600_000, (
            "downloaded_this_session should equal total downloaded so far"
        )
Model cards module:

@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }
 
-_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
+_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     "flux1-schnell": ModelCard(
-        model_id=ModelId("exolabs/FLUX.1-schnell"),
+        model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
         n_layers=57,
         hidden_size=1,

@@ -428,7 +428,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,
+                safetensors_index_filename=None,  # Single file
             ),
             ComponentInfo(
                 component_name="text_encoder_2",

@@ -442,7 +442,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23782357120),
-                n_layers=57,
+                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),

@@ -457,7 +457,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "flux1-dev": ModelCard(
-        model_id=ModelId("exolabs/FLUX.1-dev"),
+        model_id=ModelId("black-forest-labs/FLUX.1-dev"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
         n_layers=57,
         hidden_size=1,

@@ -470,7 +470,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,
+                safetensors_index_filename=None,  # Single file
             ),
             ComponentInfo(
                 component_name="text_encoder_2",

@@ -484,7 +484,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23802816640),
-                n_layers=57,
+                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),

@@ -499,7 +499,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "flux1-krea-dev": ModelCard(
-        model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
+        model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
         storage_size=Memory.from_bytes(23802816640 + 9524621312),  # Same as dev
         n_layers=57,
         hidden_size=1,

@@ -541,9 +541,9 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "qwen-image": ModelCard(
-        model_id=ModelId("exolabs/Qwen-Image"),
+        model_id=ModelId("Qwen/Qwen-Image"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,
+        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.TextToImage],

@@ -551,10 +551,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_bytes(16584333312),
+                storage_size=Memory.from_kb(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,
+                safetensors_index_filename=None,  # Single file
             ),
             ComponentInfo(
                 component_name="transformer",

@@ -575,9 +575,9 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "qwen-image-edit-2509": ModelCard(
-        model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
+        model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,
+        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.ImageToImage],

@@ -585,10 +585,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_bytes(16584333312),
+                storage_size=Memory.from_kb(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,
+                safetensors_index_filename=None,  # Single file
             ),
             ComponentInfo(
                 component_name="transformer",

@@ -610,92 +610,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
         ),
     ),
 }
 
-
-def _generate_image_model_quant_variants(
-    base_name: str,
-    base_card: ModelCard,
-) -> dict[str, ModelCard]:
-    """Create quantized variants of an image model card.
-
-    Only the transformer component is quantized; text encoders stay at bf16.
-    Sizes are calculated exactly from the base card's component sizes.
-    """
-    if base_card.components is None:
-        raise ValueError(f"Image model {base_name} must have components defined")
-
-    # quantizations = [8, 6, 5, 4, 3]
-    quantizations = [8, 4]
-
-    num_transformer_bytes = next(
-        c.storage_size.in_bytes
-        for c in base_card.components
-        if c.component_name == "transformer"
-    )
-
-    transformer_bytes = Memory.from_bytes(num_transformer_bytes)
-
-    remaining_bytes = Memory.from_bytes(
-        sum(
-            c.storage_size.in_bytes
-            for c in base_card.components
-            if c.component_name != "transformer"
-        )
-    )
-
-    def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
-        assert base_card.components is not None
-        return [
-            ComponentInfo(
-                component_name=c.component_name,
-                component_path=c.component_path,
-                storage_size=new_size
-                if c.component_name == "transformer"
-                else c.storage_size,
-                n_layers=c.n_layers,
-                can_shard=c.can_shard,
-                safetensors_index_filename=c.safetensors_index_filename,
-            )
-            for c in base_card.components
-        ]
-
-    variants = {
-        base_name: ModelCard(
-            model_id=base_card.model_id,
-            storage_size=transformer_bytes + remaining_bytes,
-            n_layers=base_card.n_layers,
-            hidden_size=base_card.hidden_size,
-            supports_tensor=base_card.supports_tensor,
-            tasks=base_card.tasks,
-            components=with_transformer_size(transformer_bytes),
-        )
-    }
-
-    for quant in quantizations:
-        quant_transformer_bytes = Memory.from_bytes(
-            (num_transformer_bytes * quant) // 16
-        )
-        total_bytes = remaining_bytes + quant_transformer_bytes
-
-        model_id = ModelId(base_card.model_id + f"-{quant}bit")
-
-        variants[f"{base_name}-{quant}bit"] = ModelCard(
-            model_id=model_id,
-            storage_size=total_bytes,
-            n_layers=base_card.n_layers,
-            hidden_size=base_card.hidden_size,
-            supports_tensor=base_card.supports_tensor,
-            tasks=base_card.tasks,
-            components=with_transformer_size(quant_transformer_bytes),
-        )
-
-    return variants
-
-
-_image_model_cards: dict[str, ModelCard] = {}
-for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
-    _image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
-_IMAGE_MODEL_CARDS = _image_model_cards
-
 if EXO_ENABLE_IMAGE_MODELS:
     MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
 
Deleted file (12 lines; shared types for MLX-related functionality):

"""Shared types for MLX-related functionality."""

from collections.abc import Sequence

from mlx_lm.models.cache import (
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
MLX parallelism module (pipeline/tensor sharding):

@@ -19,8 +19,6 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
 from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
 from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
 from mlx_lm.models.glm4_moe import MoE
-from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
-from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.llama import Model as LlamaModel

@@ -147,10 +145,6 @@ class PipelineLastLayer(CustomMlxLayer):
         if cache is not None:
             cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
 
-        output = mx.distributed.all_gather(output, group=self.group)[
-            -output.shape[0] :
-        ]  # type :ignore
-
         return output
 
 

@@ -258,6 +252,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
         if cache is not None:
             cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore
 
+        logits = mx.distributed.all_gather(logits, group=group)[
+            -logits.shape[0] :
+        ]  # type :ignore
+
         return logits
 
     cls.__call__ = patched_call

@@ -336,7 +334,15 @@ def tensor_auto_parallel(
         group=group,
     )
 
+    if hasattr(model, "shard") and not isinstance(model, GptOssModel):
+        try:
+            model.shard(group)  # type: ignore
+            return patch_tensor_model(model)
+        except (AttributeError, TypeError, NameError):
+            pass
+
     if isinstance(model, (LlamaModel, Ministral3Model)):
+        logger.warning("shouldn't be hit - upstream sharding exists")
         tensor_parallel_sharding_strategy = LlamaShardingStrategy(
             group,
             all_to_sharded_linear,

@@ -345,6 +351,7 @@ def tensor_auto_parallel(
             sharded_to_all_linear_in_place,
         )
     elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
+        logger.warning("shouldn't be hit - upstream sharding exists")
         tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
             group,
             all_to_sharded_linear,

@@ -360,14 +367,6 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
-    elif isinstance(model, GLM4MoeLiteModel):
-        tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
-            group,
-            all_to_sharded_linear,
-            sharded_to_all_linear,
-            all_to_sharded_linear_in_place,
-            sharded_to_all_linear_in_place,
-        )
     elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
         tensor_parallel_sharding_strategy = QwenShardingStrategy(
             group,

@@ -442,7 +441,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
             layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
             layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
-            mx.eval(layer)
         return model
 
 

@@ -517,8 +516,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp = ShardedDeepseekV3MoE(layer.mlp)  # type: ignore
             layer.mlp.sharding_group = self.group
 
-            mx.eval(layer)
-
         return model
 
 

@@ -536,84 +533,6 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
         return y
 
 
-class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(
-        self,
-        model: nn.Module,
-        timeout_seconds: float,
-        on_timeout: TimeoutCallback | None,
-    ) -> nn.Module:
-        model = cast(GLM4MoeLiteModel, model)
-        for layer in model.layers:  # type: ignore
-            layer = cast(Glm4MoeLiteDecoderLayer, layer)
-            eval_with_timeout(
-                layer.parameters(),
-                timeout_seconds / len(model.layers),  # type: ignore
-                on_timeout,
-            )
-            if layer.self_attn.q_lora_rank is None:  # type: ignore
-                layer.self_attn.q_proj = self.all_to_sharded_linear(
-                    layer.self_attn.q_proj
-                )
-            else:
-                layer.self_attn.q_b_proj = self.all_to_sharded_linear(
-                    layer.self_attn.q_b_proj
-                )
-
-            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            layer.self_attn.num_heads //= self.N
-
-            # Logic from upstream mlx
-            num_heads = layer.self_attn.num_heads
-            sh = self.group.rank() * num_heads
-            eh = sh + num_heads
-
-            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
-                return w[sh:eh]
-
-            layer.self_attn.embed_q.apply(shard_heads)
-            layer.self_attn.unembed_out.apply(shard_heads)
-
-            if isinstance(layer.mlp, Glm4MoeLiteMLP):
-                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
-                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
-                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
-
-            else:
-                if getattr(layer.mlp, "shared_experts", None) is not None:
-                    self.all_to_sharded_linear_in_place(
-                        layer.mlp.shared_experts.gate_proj
-                    )
-                    self.sharded_to_all_linear_in_place(
-                        layer.mlp.shared_experts.down_proj
-                    )
-                    self.all_to_sharded_linear_in_place(
-                        layer.mlp.shared_experts.up_proj
-                    )
-                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
-                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
-                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
-                layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp)  # type: ignore
-                layer.mlp.sharding_group = self.group  # type: ignore
-            mx.eval(layer)
-
-        return model
-
-
-class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
-        super().__init__(layer)
-        self.sharding_group: mx.distributed.Group | None = None
-
-    def __call__(self, x: mx.array) -> mx.array:
-        if self.sharding_group is not None:
-            x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
-        if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
-
-
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
     def shard_model(
|
def shard_model(
|
||||||
self,
|
self,
|
||||||
@@ -622,7 +541,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
|||||||
on_timeout: TimeoutCallback | None,
|
on_timeout: TimeoutCallback | None,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
model = cast(MiniMaxModel, model)
|
model = cast(MiniMaxModel, model)
|
||||||
rank = self.group.rank()
|
|
||||||
for layer in model.layers:
|
for layer in model.layers:
|
||||||
eval_with_timeout(
|
eval_with_timeout(
|
||||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||||
@@ -632,16 +550,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
|||||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||||
|
|
||||||
# Shard qk_norm weights if present (must match sharded head count)
|
|
||||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
|
||||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
|
||||||
self.N, axis=-1
|
|
||||||
)[rank]
|
|
||||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
|
||||||
self.N, axis=-1
|
|
||||||
)[rank]
|
|
||||||
|
|
||||||
layer.self_attn.num_attention_heads //= self.N
|
layer.self_attn.num_attention_heads //= self.N
|
||||||
layer.self_attn.num_key_value_heads //= self.N
|
layer.self_attn.num_key_value_heads //= self.N
|
||||||
|
|
||||||
@@ -658,7 +566,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
|||||||
)
|
)
|
||||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||||
mx.eval(layer)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -699,7 +607,6 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
|||||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||||
|
|
||||||
mx.eval(layer)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -754,7 +661,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
|||||||
|
|
||||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||||
mx.eval(layer)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
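Note: the sharding strategies in the hunks above all split attention heads across the process group the same way the removed shard_heads helper does. The snippet below is a standalone sketch, not part of the diff; the group size, rank, head count, and weight layout are invented purely for illustration.

import mlx.core as mx

# Hypothetical numbers for illustration only.
N = 2            # number of ranks in the sharding group
rank = 0         # this process's rank
num_heads = 8    # attention heads before sharding
head_dim = 64

# A projection weight laid out as (heads, head_dim, model_dim) just for the sketch.
w = mx.random.normal((num_heads, head_dim, 512))

# Each rank keeps a contiguous block of heads, mirroring shard_heads's w[sh:eh] slice.
heads_per_rank = num_heads // N
sh = rank * heads_per_rank
eh = sh + heads_per_rank
w_local = w[sh:eh]

print(w_local.shape)  # (4, 64, 512) on every rank, each holding different heads

The in-place variants in the diff follow the same idea but slice an existing nn.Linear's weight rather than building a new module.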
@@ -1,81 +1,39 @@
-import os
+# type: ignore
+# TODO: Fix this file, including types!
 from copy import deepcopy
-from typing import Any, cast
+from typing import Callable

 import mlx.core as mx
-from mlx_lm.models.cache import (
-    KVCache,
-    QuantizedKVCache,
-    RotatingKVCache,
-    trim_prompt_cache,
-)
-from mlx_lm.models.gpt_oss import Model as GptOssModel
+from mlx_lm import stream_generate
+from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
 from mlx_lm.tokenizer_utils import TokenizerWrapper

-from exo.shared.types.mlx import KVCacheType
 from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
-from exo.worker.runner.bootstrap import logger
-
-# Fraction of device memory above which LRU eviction kicks in
-_DEFAULT_MEMORY_THRESHOLD = 0.85
-_MEMORY_THRESHOLD = float(
-    os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
-)
+from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
+from exo.worker.engines.mlx.utils_mlx import make_kv_cache


 class KVPrefixCache:
-    def __init__(self, tokenizer: TokenizerWrapper):
+    def __init__(self):
+        # Only one prefix cache per runner.
         self.prompts: list[mx.array] = []  # mx array of tokens (ints)
-        self.caches: list[KVCacheType] = []
-        self._last_used: list[int] = []  # monotonic counter of last access per entry
-        self._access_counter: int = 0
-        self._tokenizer: TokenizerWrapper = tokenizer
-
-    def clear(self):
-        """Clear all cached prompts and caches."""
-        self.prompts.clear()
-        self.caches.clear()
-        self._last_used.clear()
-
-    def add_kv_cache(self, prompt: str, cache: KVCacheType):
-        """Add a new cache entry. Evicts LRU entries if memory is high."""
-        self._evict_if_needed()
-        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
+        self.caches: list[list[_BaseCache]] = []
+
+    def add_kv_cache(
+        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
+    ):
+        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
         self.prompts.append(tokenized_prompt)
         self.caches.append(deepcopy(cache))
-        self._access_counter += 1
-        self._last_used.append(self._access_counter)
-        logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
-
-    def update_kv_cache(
-        self,
-        index: int,
-        prompt: str,
-        cache: KVCacheType,
-    ):
-        """Update an existing cache entry in-place."""
-        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
-        self.prompts[index] = tokenized_prompt
-        self.caches[index] = deepcopy(cache)
-        self._access_counter += 1
-        self._last_used[index] = self._access_counter
-        logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")

     def get_kv_cache(
         self,
         model: Model,
+        tokenizer: TokenizerWrapper,
+        sampler: Callable[[mx.array], mx.array],
         prompt: str,
-    ) -> tuple[KVCacheType, mx.array, int | None]:
-        """Get KV cache for prompt, returning remaining tokens to prefill.
-
-        Returns:
-            Tuple of (cache, remaining_tokens, matched_index) where:
-            - cache: KV cache to use for generation
-            - remaining_tokens: tokens that still need prefilling
-            - matched_index: index of the matched entry (None if no match)
-        """
-        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
+    ) -> list[_BaseCache]:
+        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
         max_length = len(tokenized_prompt)

         best_snapshot_index, best_snapshot_length = None, 0
@@ -84,127 +42,63 @@ class KVPrefixCache:
             length = _get_prefix_length(tokenized_prompt, cached_prompt)

             if length == max_length:
-                # Exact match - cached prompt starts with our entire prompt
-                # Trim cache to prompt length - 1, return last token for stream_generate
-                prompt_cache = deepcopy(self.caches[i])
-                cached_length = _cache_length(self.caches[i])
-                tokens_to_trim = cached_length - (max_length - 1)
-                if tokens_to_trim > 0:
-                    trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
-                self._access_counter += 1
-                self._last_used[i] = self._access_counter
-                logger.info(f"KV cache exact match: {max_length} tokens (instant)")
-                return prompt_cache, tokenized_prompt[-1:], i
+                return self.caches[i]

             if length > best_snapshot_length:
                 best_snapshot_index, best_snapshot_length = i, length

         if best_snapshot_index is not None:
-            new_tokens = max_length - best_snapshot_length
-            logger.info(
-                f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
-                f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
-            )
-
             prompt_cache = deepcopy(self.caches[best_snapshot_index])
-            # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
-            cached_length = _cache_length(self.caches[best_snapshot_index])
-            tokens_to_trim = cached_length - best_snapshot_length
-            if tokens_to_trim > 0:
-                trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
-
-            self._access_counter += 1
-            self._last_used[best_snapshot_index] = self._access_counter
-            remaining_tokens = tokenized_prompt[best_snapshot_length:]
-            return prompt_cache, remaining_tokens, best_snapshot_index
-
+            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
+            tokenized_prompt = tokenized_prompt[best_snapshot_index:]
         else:
-            prompt_cache = make_kv_cache(model)
-            if len(self.prompts) == 0:
-                logger.info(f"KV cache empty, need to prefill {max_length} tokens")
-            else:
-                logger.info(
-                    f"KV cache no prefix match, need to prefill {max_length} tokens"
-                )
-
-            return prompt_cache, tokenized_prompt, None
-
-    def _evict_if_needed(self):
-        """Evict least recently used entries while memory pressure is high."""
-        if len(self.caches) == 0:
-            return
-
-        active: int = mx.metal.get_active_memory()
-        limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
-        if active < limit * _MEMORY_THRESHOLD:
-            return
-
-        # Evict LRU entries until below threshold or only one entry left
-        while len(self.caches) > 0:
-            lru_index = self._last_used.index(min(self._last_used))
-            evicted_tokens = len(self.prompts[lru_index])
-            self.prompts.pop(lru_index)
-            self.caches.pop(lru_index)
-            self._last_used.pop(lru_index)
-            logger.info(
-                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
+            prompt_cache = make_kv_cache(
+                model,
+                # max_kv_size=MAX_KV_SIZE,
+                # keep=KEEP_KV_SIZE
             )

-            active = mx.metal.get_active_memory()
-            if active < limit * _MEMORY_THRESHOLD:
-                break
+        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)

+        return prompt_cache

-def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
-    """Encode a prompt string to token array.
-
-    For chat-templated prompts (which have their own structure markers like
-    <|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
-    that would corrupt the prompt structure.
-    """
-    # Chat templates define their own structure - don't add BOS/EOS
-    tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
-    return mx.array(tokenized_prompt)
-
-
-def _cache_length(cache: KVCacheType) -> int:
-    """Get the number of tokens in a KV cache."""
-    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
-    return max(c.offset for c in cache)  # type: ignore
+    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
+        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+            tokenizer.bos_token
+        )
+        tokenized_prompt = tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+        return mx.array(tokenized_prompt)


 def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
-    """Find the length of the common prefix between two token arrays."""
-    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
+    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
     if n == 0:
         return 0

-    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
+    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
     prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
     return int(mx.sum(prefix_mask).item())


-def make_kv_cache(
-    model: Model, max_kv_size: int | None = None, keep: int = 0
-) -> KVCacheType:
-    assert hasattr(model, "layers")
-
-    # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
-        logger.info("Using MLX LM's make cache")
-        return model.make_cache()  # type: ignore
-
-    if max_kv_size is None:
-        if KV_CACHE_BITS is None:
-            logger.info("Using default KV cache")
-            return [KVCache() for _ in model.layers]
-        else:
-            logger.info("Using quantized KV cache")
-            return [
-                QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
-                for _ in model.layers
-            ]
-    else:
-        logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
-        return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
+def prefill(
+    model: Model,
+    tokenizer: TokenizerWrapper,
+    sampler: Callable[[mx.array], mx.array],
+    prompt: mx.array,
+    cache: list[_BaseCache],
+) -> None:
+    for _ in stream_generate(
+        model=model,
+        tokenizer=tokenizer,
+        prompt=prompt,
+        max_tokens=0,
+        sampler=sampler,
+        prompt_cache=cache,
+        prefill_step_size=2048,
+        kv_group_size=KV_GROUP_SIZE,
+        kv_bits=KV_BITS,
+    ):
+        pass
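As a side note, both versions of _get_prefix_length in the hunk above rely on the same cumprod trick: element-wise equality stays 1 until the first mismatch, and the running product zeroes everything after it, so the sum is the shared prefix length. A self-contained sketch follows; the toy token arrays are invented for the example.

import mlx.core as mx

def common_prefix_length(a: mx.array, b: mx.array) -> int:
    # Compare only the overlapping region of the two token sequences.
    n = min(int(a.shape[0]), int(b.shape[0]))
    if n == 0:
        return 0
    equal = (a[:n] == b[:n]).astype(mx.int32)
    # cumprod stays 1 until the first mismatch, then 0 for the rest,
    # so the sum is exactly the length of the shared prefix.
    prefix_mask = mx.cumprod(equal)
    return int(mx.sum(prefix_mask).item())

print(common_prefix_length(mx.array([1, 2, 3, 9]), mx.array([1, 2, 3, 4, 5])))  # 3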
@@ -4,7 +4,7 @@
 KV_GROUP_SIZE: int | None = 32
 KV_BITS: int | None = None
 ATTENTION_KV_BITS: int | None = 4
-MAX_TOKENS: int = 32168
+MAX_TOKENS: int = 8192
 MAX_KV_SIZE: int | None = 3200
 KEEP_KV_SIZE: int | None = 1600
 QUANTIZE_MODEL_MODE: str | None = "affine"
@@ -1,12 +1,12 @@
-import time
 from typing import Any, Callable, Generator, cast, get_args

 import mlx.core as mx
 from mlx_lm.generate import stream_generate
-from mlx_lm.models.cache import trim_prompt_cache
+from mlx_lm.models.cache import KVCache
 from mlx_lm.sample_utils import make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper

+# from exo.engines.mlx.cache import KVPrefixCache
 from exo.shared.types.api import (
     BenchChatCompletionTaskParams,
     ChatCompletionMessage,
@@ -14,78 +14,35 @@ from exo.shared.types.api import (
     GenerationStats,
 )
 from exo.shared.types.memory import Memory
-from exo.shared.types.mlx import KVCacheType
 from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.runner_response import (
     GenerationResponse,
 )
 from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
 from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
 from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
+    make_kv_cache,
     mx_barrier,
 )
 from exo.worker.runner.bootstrap import logger

 generation_stream = mx.new_stream(mx.default_device())

-_MIN_PREFIX_HIT_TO_UPDATE = 1000
-
-
-def prefill(
-    model: Model,
-    tokenizer: TokenizerWrapper,
-    sampler: Callable[[mx.array], mx.array],
-    prompt_tokens: mx.array,
-    cache: KVCacheType,
-) -> float:
-    """Prefill the KV cache with prompt tokens.
-
-    This runs the model over the prompt tokens to populate the cache,
-    then trims off the extra generated token.
-
-    Returns:
-        tokens_per_sec
-    """
-    num_tokens = len(prompt_tokens)
-    if num_tokens == 0:
-        return 0.0
-
-    logger.debug(f"Prefilling {num_tokens} tokens...")
-    start_time = time.perf_counter()
-
-    def progress_callback(processed: int, total: int) -> None:
-        elapsed = time.time() - start_time
-        tok_per_sec = processed / elapsed if elapsed > 0 else 0
-        logger.debug(
-            f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
-        )
-
-    # Use max_tokens=1 because max_tokens=0 does not work.
-    # We just throw away the generated token - we only care about filling the cache
-    for _ in stream_generate(
-        model=model,
-        tokenizer=tokenizer,
-        prompt=prompt_tokens,
-        max_tokens=1,
-        sampler=sampler,
-        prompt_cache=cache,
-        prefill_step_size=2048,
-        kv_group_size=KV_GROUP_SIZE,
-        kv_bits=KV_BITS,
-        prompt_progress_callback=progress_callback,
-    ):
-        break  # Stop after first iteration - cache is now filled
-    trim_prompt_cache(cast(list[Any], cache), 1)
-
-    elapsed = time.perf_counter() - start_time
-    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
-    logger.debug(
-        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
-        f"({tokens_per_sec:.1f} tok/s)"
-    )
-    return tokens_per_sec
+def maybe_quantize_kv_cache(
+    prompt_cache: list[KVCache | Any],
+    quantized_kv_start: int,
+    kv_group_size: int,
+    kv_bits: int | None,
+) -> None:
+    if kv_bits is None:
+        return
+    for e, c in enumerate(prompt_cache):
+        if (
+            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
+        ):
+            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


 def warmup_inference(
@@ -163,7 +120,6 @@ def mlx_generate(
     tokenizer: TokenizerWrapper,
     task: ChatCompletionTaskParams,
     prompt: str,
-    kv_prefix_cache: KVPrefixCache | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -175,22 +131,7 @@ def mlx_generate(
     if task.seed is not None:
         mx.random.seed(task.seed)

-    # Do not use the prefix cache if we are trying to do benchmarks.
-    if is_bench:
-        kv_prefix_cache = None
-
-    # Use prefix cache if available, otherwise create fresh cache
-    prefix_hit_length = 0
-    matched_index: int | None = None
-    if kv_prefix_cache is None:
-        caches = make_kv_cache(model=model)
-        prompt_tokens = encode_prompt(tokenizer, prompt)
-    else:
-        caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
-            model, prompt
-        )
-        all_prompt_tokens = encode_prompt(tokenizer, prompt)
-        prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
+    caches = make_kv_cache(model=model)

     logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
     if is_bench:
@@ -203,19 +144,11 @@ def mlx_generate(
         top_p=task.top_p if task.top_p is not None else 1.0,
     )

-    # Prefill cache with all tokens except the last one
-    prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
-
-    # stream_generate starts from the last token
-    last_token = prompt_tokens[-1:]
-
     max_tokens = task.max_tokens or MAX_TOKENS
-    generated_text_parts: list[str] = []
-    generation_start_time = time.perf_counter()
     for out in stream_generate(
         model=model,
         tokenizer=tokenizer,
-        prompt=last_token,
+        prompt=prompt,
         max_tokens=max_tokens,
         sampler=sampler,
         logits_processors=logits_processors,
@@ -225,13 +158,12 @@ def mlx_generate(
         kv_group_size=KV_GROUP_SIZE,
         kv_bits=KV_BITS,
     ):
-        generated_text_parts.append(out.text)
         logger.info(out.text)

         stats: GenerationStats | None = None
         if out.finish_reason is not None:
             stats = GenerationStats(
-                prompt_tps=float(prefill_tps or out.prompt_tps),
+                prompt_tps=float(out.prompt_tps),
                 generation_tps=float(out.generation_tps),
                 prompt_tokens=int(out.prompt_tokens),
                 generation_tokens=int(out.generation_tokens),
@@ -253,26 +185,6 @@ def mlx_generate(
         )

         if out.finish_reason is not None:
-            # Log generation stats
-            generation_elapsed = time.perf_counter() - generation_start_time
-            generated_tokens = len(generated_text_parts)
-            generation_tps = (
-                generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
-            )
-            logger.debug(
-                f"Generation complete: prefill {prompt_tokens} tokens @ "
-                f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
-                f"{generation_tps:.1f} tok/s"
-            )
-            if kv_prefix_cache is not None:
-                full_prompt = prompt + "".join(generated_text_parts)
-                if (
-                    matched_index is not None
-                    and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
-                ):
-                    kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
-                else:
-                    kv_prefix_cache.add_kv_cache(full_prompt, caches)
             break

     # TODO: Do we want an mx_barrier?
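For orientation, here is a condensed sketch of the prefill-then-generate flow that the removed helper on the old side of this diff implements: run the model over all but the last prompt token to warm the cache, discard the single token that stream_generate emits, then resume generation from the last prompt token. This is not part of the diff; the model path is a placeholder and the keyword arguments mirror the calls visible in the hunks above.

import mlx.core as mx
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("path/to/any-mlx-model")  # placeholder path
prompt_tokens = mx.array(tokenizer.encode("Hello, world", add_special_tokens=False))

cache = make_prompt_cache(model)
sampler = make_sampler(0.0)

# Prefill: feed all but the last prompt token, then drop the one generated token
# so the cache holds exactly the prompt prefix.
for _ in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt_tokens[:-1],
    max_tokens=1,
    sampler=sampler,
    prompt_cache=cache,
):
    break
trim_prompt_cache(cache, 1)

# Generation restarts from the last prompt token against the warm cache.
for out in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt_tokens[-1:],
    max_tokens=32,
    sampler=sampler,
    prompt_cache=cache,
):
    print(out.text, end="")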
@@ -18,12 +18,15 @@ try:
 except ImportError:
     pass  # transformers < 5.0 or bytes_to_unicode not available

-from mlx_lm.models.cache import KVCache
+from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
 from mlx_lm.models.deepseek_v3 import DeepseekV3Model
+from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper

 from exo.shared.models.model_cards import ModelId
 from exo.worker.engines.mlx.constants import (
+    CACHE_GROUP_SIZE,
+    KV_CACHE_BITS,
     TRUST_REMOTE_CODE,
 )

@@ -402,11 +405,7 @@ def apply_chat_template(
             continue

         message.content = "\n".join(c.text for c in message.content).strip()
-        if (
-            message.content is None
-            and message.thinking is None
-            and message.tool_calls is None
-        ):
+        if message.content is None and message.thinking is None:
             continue

         # Null values are not valid when applying templates in tokenizer
@@ -463,6 +462,31 @@ class NullKVCache(KVCache):
         raise NotImplementedError("We should not be setting a NullKVCache.")


+def make_kv_cache(
+    model: Model, max_kv_size: int | None = None, keep: int = 0
+) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
+    assert hasattr(model, "layers")
+
+    # TODO: Do this for all models
+    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
+        logger.info("Using MLX LM's make cache")
+        return model.make_cache()  # type: ignore
+
+    if max_kv_size is None:
+        if KV_CACHE_BITS is None:
+            logger.info("Using default KV cache")
+            return [KVCache() for _ in model.layers]
+        else:
+            logger.info("Using quantized KV cache")
+            return [
+                QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
+                for _ in model.layers
+            ]
+    else:
+        logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
+        return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
+
+
 def mlx_force_oom(size: int = 40000) -> None:
     """
     Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
@@ -70,7 +70,6 @@ from exo.worker.engines.image import (
     warmup_image_generator,
 )
 from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.cache import KVPrefixCache
 from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
 from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
@@ -104,7 +103,6 @@ def main(
     model: Model | DistributedImageModel | None = None
     tokenizer = None
     group = None
-    kv_prefix_cache: KVPrefixCache | None = None

     current_status: RunnerStatus = RunnerIdle()
     logger.info("runner created")
@@ -163,8 +161,6 @@ def main(
                         logger.info(
                             f"model has_tool_calling={tokenizer.has_tool_calling}"
                         )
-                        kv_prefix_cache = KVPrefixCache(tokenizer)
-
                     elif (
                         ModelTask.TextToImage in shard_metadata.model_card.tasks
                         or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -174,6 +170,7 @@ def main(
                         raise ValueError(
                             f"Unknown model task(s): {shard_metadata.model_card.tasks}"
                         )
+
                     current_status = RunnerLoaded()
                     logger.info("runner loaded")
                 case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -241,7 +238,6 @@ def main(
                         tokenizer=tokenizer,
                         task=task_params,
                         prompt=prompt,
-                        kv_prefix_cache=kv_prefix_cache,
                     )

                     # For other thinking models (GLM, etc.), check if we need to
@@ -1,545 +0,0 @@
|
|||||||
# type: ignore
|
|
||||||
import time
|
|
||||||
from typing import cast
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import pytest
|
|
||||||
from mlx_lm.models.cache import KVCache
|
|
||||||
from mlx_lm.sample_utils import make_sampler
|
|
||||||
|
|
||||||
from exo.shared.types.api import ChatCompletionMessage
|
|
||||||
from exo.shared.types.common import ModelId
|
|
||||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
|
||||||
from exo.worker.engines.mlx import Model
|
|
||||||
from exo.worker.engines.mlx.cache import (
|
|
||||||
KVPrefixCache,
|
|
||||||
_cache_length,
|
|
||||||
_get_prefix_length,
|
|
||||||
encode_prompt,
|
|
||||||
make_kv_cache,
|
|
||||||
)
|
|
||||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
|
||||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
|
||||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
|
||||||
DEFAULT_GPT_OSS_CONFIG,
|
|
||||||
DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_model_exists() -> bool:
|
|
||||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetPrefixLength:
|
|
||||||
def test_identical_arrays(self):
|
|
||||||
a = mx.array([1, 2, 3, 4, 5])
|
|
||||||
b = mx.array([1, 2, 3, 4, 5])
|
|
||||||
assert _get_prefix_length(a, b) == 5
|
|
||||||
|
|
||||||
def test_no_common_prefix(self):
|
|
||||||
a = mx.array([1, 2, 3])
|
|
||||||
b = mx.array([4, 5, 6])
|
|
||||||
assert _get_prefix_length(a, b) == 0
|
|
||||||
|
|
||||||
def test_partial_prefix(self):
|
|
||||||
a = mx.array([1, 2, 3, 4, 5])
|
|
||||||
b = mx.array([1, 2, 3, 7, 8])
|
|
||||||
assert _get_prefix_length(a, b) == 3
|
|
||||||
|
|
||||||
def test_prompt_longer_than_cached(self):
|
|
||||||
a = mx.array([1, 2, 3, 4, 5])
|
|
||||||
b = mx.array([1, 2, 3])
|
|
||||||
assert _get_prefix_length(a, b) == 3
|
|
||||||
|
|
||||||
def test_cached_longer_than_prompt(self):
|
|
||||||
a = mx.array([1, 2, 3])
|
|
||||||
b = mx.array([1, 2, 3, 4, 5])
|
|
||||||
assert _get_prefix_length(a, b) == 3
|
|
||||||
|
|
||||||
def test_single_token_match(self):
|
|
||||||
a = mx.array([1, 2, 3])
|
|
||||||
b = mx.array([1, 5, 6])
|
|
||||||
assert _get_prefix_length(a, b) == 1
|
|
||||||
|
|
||||||
def test_empty_prompt(self):
|
|
||||||
a = mx.array([]).astype(mx.int32)
|
|
||||||
b = mx.array([1, 2, 3])
|
|
||||||
assert _get_prefix_length(a, b) == 0
|
|
||||||
|
|
||||||
def test_empty_cached(self):
|
|
||||||
a = mx.array([1, 2, 3])
|
|
||||||
b = mx.array([]).astype(mx.int32)
|
|
||||||
assert _get_prefix_length(a, b) == 0
|
|
||||||
|
|
||||||
def test_both_empty(self):
|
|
||||||
a = mx.array([]).astype(mx.int32)
|
|
||||||
b = mx.array([]).astype(mx.int32)
|
|
||||||
assert _get_prefix_length(a, b) == 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestKVPrefix:
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_tokenizer(self):
|
|
||||||
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
tokenizer = MagicMock()
|
|
||||||
tokenizer.encode.return_value = [1, 2, 3]
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
def test_starts_empty(self, mock_tokenizer):
|
|
||||||
cache = KVPrefixCache(mock_tokenizer)
|
|
||||||
assert len(cache.prompts) == 0
|
|
||||||
assert len(cache.caches) == 0
|
|
||||||
|
|
||||||
def test_clear_empties_cache(self, mock_tokenizer):
|
|
||||||
cache = KVPrefixCache(mock_tokenizer)
|
|
||||||
cache.prompts.append(mx.array([1, 2, 3]))
|
|
||||||
cache.caches.append([KVCache()])
|
|
||||||
cache.clear()
|
|
||||||
assert len(cache.prompts) == 0
|
|
||||||
assert len(cache.caches) == 0
|
|
||||||
|
|
||||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
|
||||||
cache = KVPrefixCache(mock_tokenizer)
|
|
||||||
cache.clear()
|
|
||||||
assert len(cache.prompts) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def _load_gpt_oss() -> tuple[Model, object]:
|
|
||||||
from mlx_lm.utils import load_model
|
|
||||||
|
|
||||||
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
|
|
||||||
|
|
||||||
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
|
|
||||||
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
|
|
||||||
|
|
||||||
model, _ = load_model(model_path, lazy=False)
|
|
||||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
|
||||||
return cast(Model, model), tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not _check_model_exists(),
|
|
||||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
|
||||||
)
|
|
||||||
class TestKVPrefixCacheWithModel:
|
|
||||||
@pytest.fixture(scope="class")
|
|
||||||
def model_and_tokenizer(self):
|
|
||||||
model, tokenizer = _load_gpt_oss()
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
def test_prefill_populates_cache(self, model_and_tokenizer):
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
cache = make_kv_cache(model)
|
|
||||||
|
|
||||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
|
||||||
|
|
||||||
# Cache should now hold the prompt tokens
|
|
||||||
assert _cache_length(cache) == len(tokens)
|
|
||||||
|
|
||||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Test exact")],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
cache = make_kv_cache(model)
|
|
||||||
|
|
||||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
|
||||||
|
|
||||||
assert len(kv_prefix_cache.prompts) == 1
|
|
||||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
|
||||||
assert stored_length > 0
|
|
||||||
|
|
||||||
# Retrieve with same prompt: exact match
|
|
||||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
|
||||||
model, prompt
|
|
||||||
)
|
|
||||||
assert matched_index == 0
|
|
||||||
|
|
||||||
# Exact match returns only last token
|
|
||||||
assert len(remaining_tokens) == 1
|
|
||||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
|
||||||
|
|
||||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
|
||||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
short_task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Hi")],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
short_prompt = apply_chat_template(tokenizer, short_task)
|
|
||||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
|
||||||
cache = make_kv_cache(model)
|
|
||||||
|
|
||||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
|
||||||
|
|
||||||
# Query with longer prompt that shares the chat template prefix
|
|
||||||
long_task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[
|
|
||||||
ChatCompletionMessage(role="user", content="Hi there, how are you?")
|
|
||||||
],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
long_prompt = apply_chat_template(tokenizer, long_task)
|
|
||||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
|
||||||
|
|
||||||
# The prompts share a prefix (chat template preamble + "Hi")
|
|
||||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
|
||||||
assert expected_prefix > 0, (
|
|
||||||
"Prompts should share a prefix from the chat template"
|
|
||||||
)
|
|
||||||
|
|
||||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
|
||||||
model, long_prompt
|
|
||||||
)
|
|
||||||
assert matched_index == 0
|
|
||||||
|
|
||||||
# remaining_tokens should be the suffix after the shared prefix
|
|
||||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
|
||||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
|
||||||
|
|
||||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
|
||||||
self, model_and_tokenizer
|
|
||||||
):
|
|
||||||
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
cache = make_kv_cache(model)
|
|
||||||
|
|
||||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
|
||||||
|
|
||||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
|
||||||
|
|
||||||
# Get cache and mutate it (simulating what generation does)
|
|
||||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
|
||||||
assert matched_index == 0
|
|
||||||
|
|
||||||
# Simulate generation: feed many additional tokens through the cache
|
|
||||||
head_dim = result_cache[0].keys.shape[-1]
|
|
||||||
num_heads = result_cache[0].keys.shape[1]
|
|
||||||
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
|
|
||||||
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
|
|
||||||
for layer_cache in result_cache:
|
|
||||||
layer_cache.update_and_fetch(extra_keys, extra_values)
|
|
||||||
mx.eval([c.keys for c in result_cache])
|
|
||||||
|
|
||||||
# Stored cache must be unchanged
|
|
||||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
|
||||||
|
|
||||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
|
||||||
self, model_and_tokenizer
|
|
||||||
):
|
|
||||||
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
|
|
||||||
max_tokens=1,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
cache = make_kv_cache(model)
|
|
||||||
|
|
||||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
|
||||||
|
|
||||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
|
||||||
|
|
||||||
head_dim = result_cache[0].keys.shape[-1]
|
|
||||||
num_heads = result_cache[0].keys.shape[1]
|
|
||||||
extra = mx.random.normal((1, num_heads, 30, head_dim))
|
|
||||||
for layer_cache in result_cache:
|
|
||||||
layer_cache.update_and_fetch(extra, extra)
|
|
||||||
mx.eval([c.keys for c in result_cache])
|
|
||||||
|
|
||||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
|
||||||
f"Failed on loop {i}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
|
|
||||||
"""mlx_generate should save the cache after generation completes."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Hello")],
|
|
||||||
max_tokens=5,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
|
|
||||||
# Consume the entire generator so the cache-saving code after yield runs
|
|
||||||
generated_tokens = 0
|
|
||||||
for _response in mlx_generate(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
task=task,
|
|
||||||
prompt=prompt,
|
|
||||||
kv_prefix_cache=kv_prefix_cache,
|
|
||||||
):
|
|
||||||
generated_tokens += 1
|
|
||||||
|
|
||||||
assert len(kv_prefix_cache.prompts) == 1
|
|
||||||
assert len(kv_prefix_cache.caches) == 1
|
|
||||||
# Cache should contain prompt + generated tokens
|
|
||||||
expected_length = len(prompt_tokens) + generated_tokens
|
|
||||||
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
|
|
||||||
|
|
||||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
|
||||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
task = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
|
|
||||||
max_tokens=5,
|
|
||||||
)
|
|
||||||
prompt = apply_chat_template(tokenizer, task)
|
|
||||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
|
||||||
|
|
||||||
# First generation populates cache
|
|
||||||
for _response in mlx_generate(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
task=task,
|
|
||||||
prompt=prompt,
|
|
||||||
kv_prefix_cache=kv_prefix_cache,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert len(kv_prefix_cache.prompts) == 1
|
|
||||||
|
|
||||||
# Second call should find a prefix match (the stored cache contains
|
|
||||||
# prompt + generated tokens, which shares the prompt prefix)
|
|
||||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
|
||||||
model, prompt
|
|
||||||
)
|
|
||||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
|
||||||
# so this is a prefix match where our prompt is fully contained
|
|
||||||
assert matched_index == 0
|
|
||||||
# Exact match: remaining_tokens is just the last token
|
|
||||||
assert len(remaining_tokens) == 1
|
|
||||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
|
||||||
|
|
||||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
|
||||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
|
||||||
model, tokenizer = model_and_tokenizer
|
|
||||||
|
|
||||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
|
||||||
|
|
||||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
|
||||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
|
||||||
base_tokens = tokenizer.encode(base_text)
|
|
||||||
repeats = (1200 // len(base_tokens)) + 2
|
|
||||||
long_content = base_text * repeats
|
|
||||||
|
|
||||||
task1 = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[ChatCompletionMessage(role="user", content=long_content)],
|
|
||||||
max_tokens=5,
|
|
||||||
)
|
|
||||||
prompt1 = apply_chat_template(tokenizer, task1)
|
|
||||||
prompt1_tokens = encode_prompt(tokenizer, prompt1)
|
|
||||||
assert len(prompt1_tokens) > 1000, (
|
|
||||||
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
|
|
||||||
)
|
|
||||||
|
|
||||||
# First generation populates the cache (must prefill all tokens)
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
for _response in mlx_generate(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
task=task1,
|
|
||||||
prompt=prompt1,
|
|
||||||
kv_prefix_cache=kv_prefix_cache,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
first_gen_time = time.perf_counter() - t0
|
|
||||||
|
|
||||||
assert len(kv_prefix_cache.prompts) == 1
|
|
||||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
|
||||||
|
|
||||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
|
||||||
task2 = ChatCompletionTaskParams(
|
|
||||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
|
||||||
messages=[
|
|
||||||
ChatCompletionMessage(role="user", content=long_content),
|
|
||||||
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
|
|
||||||
ChatCompletionMessage(role="user", content="Tell me more."),
|
|
||||||
],
|
|
||||||
max_tokens=5,
|
|
||||||
)
|
|
||||||
prompt2 = apply_chat_template(tokenizer, task2)
|
|
||||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
|
||||||
|
|
||||||
# Verify the prompts share a long prefix
|
|
||||||
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
|
|
||||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
|
||||||
|
|
||||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
for _response in mlx_generate(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
task=task2,
|
|
||||||
prompt=prompt2,
|
|
||||||
kv_prefix_cache=kv_prefix_cache,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
second_gen_time = time.perf_counter() - t0
|
|
||||||
|
|
||||||
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
|
|
||||||
assert second_gen_time < first_gen_time * 0.5, (
|
|
||||||
f"Expected prefix cache speedup: "
|
|
||||||
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
|
||||||
assert len(kv_prefix_cache.prompts) == 1
|
|
||||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
|
||||||
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
|
||||||
assert updated_cache_length > first_cache_length
|
|
||||||
|
|
||||||
    def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
        """After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Immutable test")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)

        # First generation populates cache
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        first_cache_length = _cache_length(kv_prefix_cache.caches[0])

        # Second generation gets the cache and mutates it during generation
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        # The first stored cache must not have been mutated by the second generation
        assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length

    def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
        """Under memory pressure, adding a new cache entry evicts the least recently used one."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)

        # Add three cache entries with different prompts
        prompts = ["First entry", "Second entry", "Third entry"]
        for i, content in enumerate(prompts):
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content=content)],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)
            # Stagger _last_used so LRU order is deterministic
            kv_prefix_cache._last_used[i] = float(i)

        assert len(kv_prefix_cache.prompts) == 3

        # Access the third entry to make it most recently used
        kv_prefix_cache._last_used[2] = 100.0
        # Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next

        # Simulate memory pressure: active memory exceeds threshold
        fake_limit = 1000
        fake_active = int(fake_limit * 0.90)  # Above _MEMORY_THRESHOLD (0.85)

        with (
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
                return_value=fake_active,
            ),
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.device_info",
                return_value={"max_recommended_working_set_size": fake_limit},
            ),
        ):
            # Trigger eviction by adding a new entry
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content="New entry")],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)

        # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
        # Since fake_active stays above threshold after each eviction (we don't change it),
        # all old entries get evicted, leaving only the newly added one
        assert len(kv_prefix_cache.prompts) == 1
        # The surviving entry should be the newly added one
        new_tokens = encode_prompt(tokenizer, prompt)
        assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
            new_tokens
        )
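
Aside: the eviction test above patches mx.metal.get_active_memory and mx.metal.device_info and relies on a 0.85 threshold plus per-entry _last_used timestamps. The sketch below shows one way such a pressure check and LRU loop could fit together; the helpers _under_pressure and evict_lru_until_under_limit are invented names, not the actual exo.worker.engines.mlx.cache API.

# Hypothetical sketch, assuming parallel prompts/caches/last_used lists and
# the same Metal queries the test patches; not exo's actual eviction code.
import mlx.core as mx

_MEMORY_THRESHOLD = 0.85  # the fraction referenced by the test's comment


def _under_pressure() -> bool:
    limit = mx.metal.device_info()["max_recommended_working_set_size"]
    return mx.metal.get_active_memory() > _MEMORY_THRESHOLD * limit


def evict_lru_until_under_limit(prompts: list, caches: list,
                                last_used: list[float]) -> None:
    # Drop the least recently used entry while memory stays above the
    # threshold. With a fixed fake_active (as in the test) pressure never
    # subsides, so every pre-existing entry is evicted before the new one
    # is appended - exactly the end state the assertions check.
    while prompts and _under_pressure():
        lru_index = last_used.index(min(last_used))
        prompts.pop(lru_index)
        caches.pop(lru_index)
        last_used.pop(lru_index)

If add_kv_cache called a helper like this before appending, the test's final state (a single surviving entry whose token prefix matches the freshly encoded prompt) would follow directly.
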
16
uv.lock
generated
@@ -415,7 +415,7 @@ requires-dist = [
     { name = "mflux", specifier = "==0.15.4" },
     { name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
     { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
-    { name = "mlx-lm", specifier = "==0.30.5" },
+    { name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
     { name = "openai-harmony", specifier = ">=0.0.8" },
     { name = "pillow", specifier = ">=11.0,<12.0" },
     { name = "psutil", specifier = ">=7.0.0" },
@@ -1072,8 +1072,8 @@ wheels = [

 [[package]]
 name = "mlx-lm"
-version = "0.30.5"
-source = { registry = "https://pypi.org/simple" }
+version = "0.30.4"
+source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
 dependencies = [
     { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1083,10 +1083,6 @@ dependencies = [
     { name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
-]

 [[package]]
 name = "mlx-metal"
@@ -2285,7 +2281,7 @@ wheels = [

 [[package]]
 name = "transformers"
-version = "5.0.0rc3"
+version = "5.0.0rc2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2300,9 +2296,9 @@ dependencies = [
     { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
+    { url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" },
 ]

 [[package]]