mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-27 07:20:14 -05:00
Compare commits
17 Commits
rust-explo
...
v1.0.66
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ce0d4602d | ||
|
|
4f24e33d30 | ||
|
|
bd4f0bf048 | ||
|
|
cd8c01b7c8 | ||
|
|
a9ee2204ef | ||
|
|
59e991ce15 | ||
|
|
ffba340e70 | ||
|
|
054b296a51 | ||
|
|
9968abe816 | ||
|
|
0e30b0830f | ||
|
|
281aaeb013 | ||
|
|
10fdc439a5 | ||
|
|
78a8c06d57 | ||
|
|
4c0c6dcae9 | ||
|
|
d885600a4c | ||
|
|
55b67e2be2 | ||
|
|
30cfad9b68 |
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -514,20 +514,6 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
|
||||
|
||||
[[package]]
|
||||
name = "cluster_membership"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"futures-lite",
|
||||
"futures-timer",
|
||||
"libp2p",
|
||||
"log",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
@@ -1012,7 +998,6 @@ dependencies = [
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"cluster_membership",
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
@@ -1045,12 +1030,6 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
version = "0.13.1"
|
||||
@@ -1159,10 +1138,7 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/util",
|
||||
"rust/cluster_membership",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -26,7 +25,6 @@ opt-level = 3
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
cluster_membership = { path = "rust/cluster_membership" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
@@ -64,7 +62,6 @@ frunk-enum-core = "0.3"
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-lite = "2.6.1"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 30
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"mlx-lm==0.30.5",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
[package]
|
||||
name = "cluster_membership"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
# util
|
||||
anyhow.workspace = true
|
||||
log.workspace = true
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures-timer = { workspace = true }
|
||||
futures-lite = "2.6.1"
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
async-trait = "0.1.89"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
@@ -1,30 +0,0 @@
|
||||
use cluster_membership::Peer;
|
||||
use libp2p::identity::ed25519::SecretKey;
|
||||
use tokio::io::{self, AsyncBufReadExt};
|
||||
use tracing_subscriber::{EnvFilter, filter::LevelFilter};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (mut peer, send, mut recv) =
|
||||
Peer::new(SecretKey::generate(), "hello".to_string()).expect("peer should always build");
|
||||
|
||||
let ch = peer.subscribe("chatroom".to_string());
|
||||
let jh = tokio::spawn(async move { peer.run().await });
|
||||
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {send.send((ch.clone(), line.into_bytes())).await.expect("example");}
|
||||
Some(r) = recv.recv() => match r {
|
||||
Ok((_, id, line)) => println!("{:?}:{:?}", id, String::from_utf8_lossy(&line)),
|
||||
Err(e) => eprintln!("{e:?}"),
|
||||
},
|
||||
else => break
|
||||
}
|
||||
}
|
||||
jh.await.expect("task failure");
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
use libp2p::{
|
||||
Multiaddr, PeerId, Swarm, SwarmBuilder,
|
||||
futures::StreamExt,
|
||||
gossipsub::{self, PublishError, Sha256Topic, TopicHash},
|
||||
identify,
|
||||
identity::{Keypair, ed25519},
|
||||
mdns,
|
||||
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::{select, sync::mpsc};
|
||||
|
||||
const DEFAULT_BUFFER_SIZE: usize = 10;
|
||||
const MDNS_IGNORE_DURATION_SECS: u64 = 30;
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
identity: ed25519::SecretKey,
|
||||
namespace: String,
|
||||
) -> anyhow::Result<(
|
||||
Self,
|
||||
mpsc::Sender<(TopicHash, Vec<u8>)>,
|
||||
mpsc::Receiver<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
|
||||
)> {
|
||||
let mut id_bytes = identity.as_ref().to_vec();
|
||||
|
||||
let mut swarm =
|
||||
SwarmBuilder::with_existing_identity(Keypair::ed25519_from_bytes(&mut id_bytes)?)
|
||||
.with_tokio()
|
||||
.with_quic()
|
||||
// TODO(evan): .with_bandwidth_metrics();
|
||||
.with_behaviour(|kp| Behaviour::new(kp, namespace.clone()))?
|
||||
.build();
|
||||
|
||||
swarm.listen_on("/ip6/::/udp/0/quic-v1".parse()?)?;
|
||||
swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?;
|
||||
let (to_swarm, from_client) = mpsc::channel(DEFAULT_BUFFER_SIZE);
|
||||
let (to_client, from_swarm) = mpsc::channel(DEFAULT_BUFFER_SIZE);
|
||||
Ok((
|
||||
Self {
|
||||
swarm,
|
||||
namespace,
|
||||
denied: HashMap::new(),
|
||||
from_client,
|
||||
to_client,
|
||||
},
|
||||
to_swarm,
|
||||
from_swarm,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn subscribe(&mut self, topic: String) -> TopicHash {
|
||||
let topic = Sha256Topic::new(topic);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("topic filtered");
|
||||
topic.hash()
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) {
|
||||
loop {
|
||||
select! {
|
||||
ev = self.swarm.select_next_some() => {
|
||||
let Ok(()) = self.handle_swarm_event(ev).await else {
|
||||
return
|
||||
};
|
||||
},
|
||||
Some(msg) = self.from_client.recv() => {
|
||||
if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish(msg.0, msg.1) {
|
||||
let Ok(()) = self.to_client.send(Err(e)).await else {
|
||||
return
|
||||
};
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_swarm_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
|
||||
let SwarmEvent::Behaviour(event) = event else {
|
||||
if let SwarmEvent::NewListenAddr {
|
||||
listener_id: _,
|
||||
address,
|
||||
} = event
|
||||
{
|
||||
log::info!("new listen address {address}")
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
match event {
|
||||
BehaviourEvent::Mdns(mdns_event) => match mdns_event {
|
||||
mdns::Event::Discovered(vec) => {
|
||||
// Dial everyone
|
||||
let mut addrs = HashMap::<PeerId, Vec<Multiaddr>>::new();
|
||||
vec.into_iter()
|
||||
.filter(|(peer_id, _)| {
|
||||
self.denied.get(peer_id).is_none_or(|t| {
|
||||
t.elapsed() > Duration::from_secs(MDNS_IGNORE_DURATION_SECS)
|
||||
})
|
||||
})
|
||||
.for_each(|(peer_id, addr)| addrs.entry(peer_id).or_default().push(addr));
|
||||
addrs.into_iter().for_each(|(peer_id, addrs)| {
|
||||
let _ = self
|
||||
.swarm
|
||||
.dial(DialOpts::peer_id(peer_id).addresses(addrs).build());
|
||||
});
|
||||
}
|
||||
mdns::Event::Expired(vec) => {
|
||||
vec.iter().for_each(|(peer_id, _)| {
|
||||
log::debug!("{peer_id} no longer reachable on mDNS");
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_explicit_peer(peer_id);
|
||||
});
|
||||
}
|
||||
},
|
||||
BehaviourEvent::Identify(identify::Event::Received {
|
||||
connection_id: _,
|
||||
peer_id,
|
||||
info,
|
||||
}) => {
|
||||
if info
|
||||
.protocols
|
||||
.iter()
|
||||
.any(|p| p.as_ref().contains(&self.namespace))
|
||||
{
|
||||
self.passed_namespace(peer_id);
|
||||
} else {
|
||||
self.failed_namespace(peer_id);
|
||||
}
|
||||
}
|
||||
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: _,
|
||||
message_id: _,
|
||||
message:
|
||||
gossipsub::Message {
|
||||
topic,
|
||||
data,
|
||||
source: Some(source_peer),
|
||||
..
|
||||
},
|
||||
}) => {
|
||||
let Ok(()) = self.to_client.send(Ok((topic, source_peer, data))).await else {
|
||||
return Err(());
|
||||
};
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn passed_namespace(&mut self, peer: PeerId) {
|
||||
log::info!("new peer {peer:?}");
|
||||
self.denied.remove(&peer);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_blacklisted_peer(&peer);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.add_explicit_peer(&peer);
|
||||
}
|
||||
|
||||
fn failed_namespace(&mut self, peer: PeerId) {
|
||||
log::debug!("{peer} failed handshake");
|
||||
self.denied.insert(peer, Instant::now());
|
||||
self.swarm.behaviour_mut().gossipsub.blacklist_peer(&peer);
|
||||
// we don't care if disconnect fails
|
||||
let _ = self.swarm.disconnect_peer_id(peer);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
pub swarm: Swarm<Behaviour>,
|
||||
denied: HashMap<PeerId, Instant>,
|
||||
namespace: String,
|
||||
to_client: mpsc::Sender<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
|
||||
from_client: mpsc::Receiver<(TopicHash, Vec<u8>)>,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn foo() {
|
||||
fn bar<T: Send>(t: T) {}
|
||||
let p: Peer = unimplemented!();
|
||||
bar(p);
|
||||
}
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
identify: identify::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
fn new(kp: &Keypair, namespace: String) -> Self {
|
||||
let mdns = mdns::tokio::Behaviour::new(Default::default(), kp.public().to_peer_id())
|
||||
.expect("implementation is infallible");
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(kp.clone()),
|
||||
gossipsub::ConfigBuilder::default()
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.protocol_id_prefix(format!("/exo/gossip/{namespace}/v1"))
|
||||
.build()
|
||||
.expect("fixed gossipsub config should always build"),
|
||||
)
|
||||
.expect("fixed gossipsub init should always build");
|
||||
|
||||
let identify = identify::Behaviour::new(
|
||||
identify::Config::new_with_signed_peer_record(format!("/exo/identity/v1"), kp)
|
||||
.with_push_listen_addr_updates(true),
|
||||
);
|
||||
|
||||
Behaviour {
|
||||
mdns,
|
||||
gossipsub,
|
||||
identify,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ doc = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
cluster_membership.workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
|
||||
@@ -6,41 +6,3 @@
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use cluster_membership::Peer;
|
||||
use libp2p::identity::ed25519::Keypair;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct PyKeypair(Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyKeypair {
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate())
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
pub struct PyPeer(Mutex<Peer>);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyPeer {
|
||||
#[staticmethod]
|
||||
fn init(kp: PyKeypair, namespace: String) -> PyResult<Self> {
|
||||
Ok(PyPeer(Mutex::new(
|
||||
Peer::new(kp.0.secret(), namespace)
|
||||
.map_err(|e| e.pyerr())?
|
||||
.0,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,11 +121,20 @@ async def ensure_models_dir() -> Path:
|
||||
|
||||
|
||||
async def delete_model(model_id: ModelId) -> bool:
|
||||
model_dir = await ensure_models_dir() / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
return True
|
||||
models_dir = await ensure_models_dir()
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
deleted = False
|
||||
if await aios.path.exists(model_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
deleted = True
|
||||
|
||||
# Also clear cache
|
||||
if await aios.path.exists(cache_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
async def seed_models(seed_dir: str | Path):
|
||||
@@ -151,16 +160,28 @@ async def fetch_file_list_with_cache(
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
|
||||
return file_list
|
||||
|
||||
# Always try fresh first
|
||||
try:
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
# Update cache with fresh data
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
|
||||
)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
# Fetch failed - try cache fallback
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
# No cache available, propagate the error
|
||||
raise
|
||||
|
||||
|
||||
async def fetch_file_list_with_retry(
|
||||
@@ -332,8 +353,28 @@ async def _download_file(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
) -> Path:
|
||||
if await aios.path.exists(target_dir / path):
|
||||
return target_dir / path
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
try:
|
||||
remote_size, _ = await file_meta(model_id, revision, path)
|
||||
if local_size != remote_size:
|
||||
logger.info(
|
||||
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
|
||||
)
|
||||
await aios.remove(target_path)
|
||||
else:
|
||||
return target_path
|
||||
except Exception as e:
|
||||
# Offline or network error - trust local file
|
||||
logger.debug(
|
||||
f"Could not verify {path} against remote (offline?): {e}, using local file"
|
||||
)
|
||||
return target_path
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -542,17 +583,26 @@ async def download_shard(
|
||||
async def on_progress_wrapper(
|
||||
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
|
||||
) -> None:
|
||||
start_time = (
|
||||
file_progress[file.path].start_time
|
||||
if file.path in file_progress
|
||||
else time.time()
|
||||
)
|
||||
downloaded_this_session = (
|
||||
file_progress[file.path].downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
|
||||
if file.path in file_progress
|
||||
else curr_bytes
|
||||
previous_progress = file_progress.get(file.path)
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
speed = (
|
||||
downloaded_this_session / (time.time() - start_time)
|
||||
if time.time() - start_time > 0
|
||||
|
||||
0
src/exo/download/tests/__init__.py
Normal file
0
src/exo/download/tests/__init__.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""Tests for download verification and cache behavior."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from exo.download.download_utils import (
|
||||
delete_model,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
"""Set up a temporary models directory for testing."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestFileVerification:
|
||||
"""Tests for file size verification in _download_file."""
|
||||
|
||||
async def test_redownload_when_file_size_changes_upstream(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with mismatched sizes are re-downloaded."""
|
||||
# Import inside test to allow patching
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file with wrong size
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content") # 13 bytes
|
||||
|
||||
remote_size = 1000 # Different from local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
# Set up mock HTTP response for re-download
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
|
||||
side_effect=[b"x" * remote_size, b""]
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_session
|
||||
)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
# Mock calc_hash to return the expected hash
|
||||
with patch(
|
||||
"exo.download.download_utils.calc_hash",
|
||||
new_callable=AsyncMock,
|
||||
return_value=remote_hash,
|
||||
):
|
||||
await _download_file(model_id, "main", "test.safetensors", target_dir)
|
||||
|
||||
# file_meta should be called twice: once for verification, once for download
|
||||
assert mock_file_meta.call_count == 2
|
||||
|
||||
async def test_skip_download_when_file_size_matches(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with matching sizes are not re-downloaded."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
local_content = b"local content"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(local_content)
|
||||
|
||||
remote_size = len(local_content) # Same as local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return immediately without downloading
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_called_once()
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
async def test_offline_fallback_uses_local_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that local files are used when network is unavailable."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return local file without attempting download
|
||||
assert result == local_file
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFileListCache:
|
||||
"""Tests for file list caching behavior."""
|
||||
|
||||
async def test_fetch_fresh_and_update_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that fresh data is fetched and cache is updated."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=100),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
return_value=file_list,
|
||||
) as mock_fetch,
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == file_list
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
# Verify cache was written
|
||||
cache_file = (
|
||||
models_dir
|
||||
/ "caches"
|
||||
/ model_id.normalize()
|
||||
/ f"{model_id.normalize()}--main--file_list.json"
|
||||
)
|
||||
assert await aios.path.exists(cache_file)
|
||||
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
|
||||
await f.read()
|
||||
)
|
||||
assert cached_data == file_list
|
||||
|
||||
async def test_fallback_to_cache_when_fetch_fails(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that cached data is used when fetch fails."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create cache file
|
||||
cached_file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
|
||||
)
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == cached_file_list
|
||||
|
||||
async def test_error_propagates_when_no_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that errors propagate when fetch fails and no cache exists."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
pytest.raises(Exception, match="Network error"),
|
||||
):
|
||||
await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
|
||||
class TestModelDeletion:
|
||||
"""Tests for model deletion including cache cleanup."""
|
||||
|
||||
async def test_delete_model_clears_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that deleting a model also deletes its cache."""
|
||||
models_dir = tmp_path / "models"
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Create model and cache directories
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Add some files
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
|
||||
await f.write("model data")
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is True
|
||||
assert not await aios.path.exists(model_dir)
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_model_only_cache_exists(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting when only cache exists (model already deleted)."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Only create cache directory
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
# Returns False because model dir didn't exist
|
||||
assert result is False
|
||||
# But cache should still be cleaned up
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_nonexistent_model(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting a model that doesn't exist."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestProgressResetOnRedownload:
|
||||
"""Tests for progress tracking when files are re-downloaded."""
|
||||
|
||||
async def test_progress_resets_correctly_on_redownload(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress tracking resets when a file is re-downloaded.
|
||||
|
||||
When a file is deleted and re-downloaded (due to size mismatch),
|
||||
the progress tracking should reset rather than calculating negative
|
||||
downloaded_this_session values.
|
||||
"""
|
||||
# Simulate file_progress dict as it exists in download_shard
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with old file progress (simulating existing large file)
|
||||
old_file_size = 1_500_000_000 # 1.5 GB
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(old_file_size),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(old_file_size),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="not_started",
|
||||
start_time=time.time() - 10, # Started 10 seconds ago
|
||||
)
|
||||
|
||||
# Simulate the logic from on_progress_wrapper after re-download starts
|
||||
# This is the exact logic from the fixed on_progress_wrapper
|
||||
curr_bytes = 100_000 # 100 KB - new download just started
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is True, "Should detect re-download scenario"
|
||||
assert downloaded_this_session == curr_bytes, (
|
||||
"downloaded_this_session should equal curr_bytes on re-download"
|
||||
)
|
||||
assert downloaded_this_session > 0, (
|
||||
"downloaded_this_session should be positive, not negative"
|
||||
)
|
||||
|
||||
# Calculate speed (should be positive)
|
||||
elapsed = time.time() - start_time
|
||||
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
|
||||
assert speed >= 0, "Speed should be non-negative"
|
||||
|
||||
async def test_progress_accumulates_on_continuing_download(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress accumulates correctly for continuing downloads.
|
||||
|
||||
When a download continues from where it left off (resume),
|
||||
the progress should accumulate correctly.
|
||||
"""
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with partial download progress
|
||||
initial_downloaded = 500_000 # 500 KB already downloaded
|
||||
start_time = time.time() - 5 # Started 5 seconds ago
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(initial_downloaded),
|
||||
downloaded_this_session=Memory.from_bytes(initial_downloaded),
|
||||
total=Memory.from_bytes(1_000_000),
|
||||
speed=100_000,
|
||||
eta=timedelta(seconds=5),
|
||||
status="in_progress",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# Progress callback with more bytes downloaded
|
||||
curr_bytes = 600_000 # 600 KB - continuing download
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# This is NOT a re-download (curr_bytes > previous downloaded)
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
downloaded_this_session = curr_bytes
|
||||
used_start_time = time.time()
|
||||
else:
|
||||
used_start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is False, (
|
||||
"Should NOT detect re-download for continuing download"
|
||||
)
|
||||
assert used_start_time == start_time, "Should preserve original start_time"
|
||||
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
|
||||
assert downloaded_this_session == expected_session, (
|
||||
f"Should accumulate: {downloaded_this_session} == {expected_session}"
|
||||
)
|
||||
assert downloaded_this_session == 600_000, (
|
||||
"downloaded_this_session should equal total downloaded so far"
|
||||
)
|
||||
@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -428,7 +428,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -442,7 +442,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -457,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -470,7 +470,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -484,7 +484,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -499,7 +499,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -541,9 +541,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
@@ -551,10 +551,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -575,9 +575,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
@@ -585,10 +585,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -610,6 +610,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
@@ -145,6 +147,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -252,10 +258,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
@@ -334,15 +336,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -351,7 +345,6 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -367,6 +360,14 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GLM4MoeLiteModel):
|
||||
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
@@ -441,7 +442,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -516,6 +517,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -533,6 +536,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
timeout_seconds / len(model.layers), # type: ignore
|
||||
on_timeout,
|
||||
)
|
||||
if layer.self_attn.q_lora_rank is None: # type: ignore
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
if isinstance(layer.mlp, Glm4MoeLiteMLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -566,7 +647,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -607,6 +688,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -661,7 +743,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -170,10 +170,10 @@ def mlx_distributed_init(
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -405,7 +405,11 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
message.content = "\n".join(c.text for c in message.content).strip()
|
||||
if message.content is None and message.thinking is None:
|
||||
if (
|
||||
message.content is None
|
||||
and message.thinking is None
|
||||
and message.tool_calls is None
|
||||
):
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
|
||||
52
uv.lock
generated
52
uv.lock
generated
@@ -376,8 +376,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -412,10 +412,10 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
|
||||
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.5" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
@@ -994,8 +994,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1022,18 +1022,12 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
|
||||
]
|
||||
@@ -1046,6 +1040,14 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.4.dev20260121+fbe306f9"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.3"
|
||||
@@ -1072,26 +1074,20 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.4"
|
||||
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2281,7 +2277,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "5.0.0rc2"
|
||||
version = "5.0.0rc3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -2296,9 +2292,9 @@ dependencies = [
|
||||
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user