Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-16 01:51:03 -05:00)

Compare commits: 6 commits on main by JakeHillio
| Author | SHA1 | Date |
|---|---|---|
| | 563e94ab6c | |
| | 41c27832d9 | |
| | 9c04350e55 | |
| | 1ca732da23 | |
| | ba708c2ccd | |
| | c9d95859e8 | |
1605  Cargo.lock  generated
File diff suppressed because it is too large
@@ -2,6 +2,7 @@
 resolver = "3"
 members = [
     "rust/networking",
+    "rust/downloads",
     "rust/exo_pyo3_bindings",
     "rust/system_custodian",
     "rust/util",
@@ -25,6 +26,7 @@ opt-level = 3
 [workspace.dependencies]
 ## Crate members as common dependencies
 networking = { path = "rust/networking" }
+downloads = { path = "rust/downloads" }
 system_custodian = { path = "rust/system_custodian" }
 util = { path = "rust/util" }
@@ -199,7 +199,13 @@
   const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
     ?? (downloadPayload as Record<string, unknown>).downloadProgress
     ?? {};
-  const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
+  // For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
+  const totalBytes = getBytes(
+    (downloadPayload as Record<string, unknown>).total_bytes
+    ?? (downloadPayload as Record<string, unknown>).totalBytes
+    ?? (rawProgress as Record<string, unknown>).total_bytes
+    ?? (rawProgress as Record<string, unknown>).totalBytes
+  );
   const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
   const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
   const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
   <div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
   <div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
 </div>
-<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
-  <span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
+<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
+  <div>
+    <span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
+  </div>
+  <div class="text-exo-light-gray normal-case tracking-normal">
+    {formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
+  </div>
 </div>
</div>
@@ -385,7 +396,7 @@
 </div>

 <div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
-  <span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
+  <span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
   {#if model.status !== 'completed'}
     <span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
   {/if}
40  rust/downloads/Cargo.toml  Normal file
@@ -0,0 +1,40 @@
[package]
name = "downloads"
version = { workspace = true }
edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "downloads"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
# macro dependencies
derive_more = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-util = { workspace = true }

# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
itertools = { workspace = true }

# tracing/logging
log = { workspace = true }

# BitTorrent library
librqbit = { git = "https://github.com/JakeHillion/rqbit", rev = "c4e2ecf81d03bd8acd96a0803d06a70b34d5da19" }

# Embed torrent files
include_dir = "0.7"

# Serialization
serde = { version = "1.0", features = ["derive"] }
162  rust/downloads/src/bencode.rs  Normal file
@@ -0,0 +1,162 @@
//! Bencode encoding for BitTorrent tracker responses
//!
//! Implements the subset of bencoding needed for tracker announce responses.

use std::collections::BTreeMap;

/// Parameters from a tracker announce request
#[derive(Debug, Clone)]
pub struct AnnounceParams {
    /// 20-byte info hash of the torrent
    pub info_hash: [u8; 20],
    /// 20-byte peer ID of the client
    pub peer_id: [u8; 20],
    /// Port the client is listening on
    pub port: u16,
    /// Total bytes uploaded
    pub uploaded: u64,
    /// Total bytes downloaded
    pub downloaded: u64,
    /// Bytes remaining to download
    pub left: u64,
    /// Whether to return compact peer list (6 bytes per peer)
    pub compact: bool,
    /// Optional event (started, stopped, completed)
    pub event: Option<AnnounceEvent>,
}

/// Announce event types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnnounceEvent {
    Started,
    Stopped,
    Completed,
}

/// A bencoded value
#[derive(Debug, Clone)]
pub enum BencodeValue {
    Integer(i64),
    Bytes(Vec<u8>),
    List(Vec<BencodeValue>),
    Dict(BTreeMap<Vec<u8>, BencodeValue>),
}

impl BencodeValue {
    /// Create a string value from a &str
    #[inline]
    pub fn string(s: &str) -> Self {
        Self::Bytes(s.as_bytes().to_vec())
    }

    /// Create an integer value
    #[inline]
    pub fn integer(i: i64) -> Self {
        Self::Integer(i)
    }

    /// Create an empty list
    #[inline]
    pub fn list() -> Self {
        Self::List(Vec::new())
    }

    /// Create an empty dict
    #[inline]
    pub fn dict() -> Self {
        Self::Dict(BTreeMap::new())
    }

    /// Add an item to a list (builder pattern)
    #[inline]
    pub fn push(mut self, value: BencodeValue) -> Self {
        if let Self::List(ref mut list) = self {
            list.push(value);
        }
        self
    }

    /// Insert a key-value pair into a dict (builder pattern)
    #[inline]
    pub fn insert(mut self, key: &str, value: BencodeValue) -> Self {
        if let Self::Dict(ref mut dict) = self {
            dict.insert(key.as_bytes().to_vec(), value);
        }
        self
    }

    /// Encode to bencoded bytes
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        self.encode_into(&mut buf);
        buf
    }

    /// Encode into an existing buffer
    pub fn encode_into(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Integer(i) => {
                buf.push(b'i');
                buf.extend_from_slice(i.to_string().as_bytes());
                buf.push(b'e');
            }
            Self::Bytes(bytes) => {
                buf.extend_from_slice(bytes.len().to_string().as_bytes());
                buf.push(b':');
                buf.extend_from_slice(bytes);
            }
            Self::List(list) => {
                buf.push(b'l');
                for item in list {
                    item.encode_into(buf);
                }
                buf.push(b'e');
            }
            Self::Dict(dict) => {
                buf.push(b'd');
                // BTreeMap keeps keys sorted
                for (key, value) in dict {
                    buf.extend_from_slice(key.len().to_string().as_bytes());
                    buf.push(b':');
                    buf.extend_from_slice(key);
                    value.encode_into(buf);
                }
                buf.push(b'e');
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_integer() {
        assert_eq!(BencodeValue::integer(42).encode(), b"i42e");
        assert_eq!(BencodeValue::integer(-1).encode(), b"i-1e");
        assert_eq!(BencodeValue::integer(0).encode(), b"i0e");
    }

    #[test]
    fn test_encode_string() {
        assert_eq!(BencodeValue::string("spam").encode(), b"4:spam");
        assert_eq!(BencodeValue::string("").encode(), b"0:");
    }

    #[test]
    fn test_encode_list() {
        let list = BencodeValue::list()
            .push(BencodeValue::string("spam"))
            .push(BencodeValue::integer(42));
        assert_eq!(list.encode(), b"l4:spami42ee");
    }

    #[test]
    fn test_encode_dict() {
        let dict = BencodeValue::dict()
            .insert("bar", BencodeValue::string("spam"))
            .insert("foo", BencodeValue::integer(42));
        assert_eq!(dict.encode(), b"d3:bar4:spam3:fooi42ee");
    }
}
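A quick illustration of how the builder methods compose (a sketch against the `BencodeValue` API above; the byte literal is the expected wire form of a minimal announce dict):

```rust
use downloads::bencode::BencodeValue;

fn main() {
    // Dict keys come out BTreeMap-sorted, so "interval" precedes "peers".
    let response = BencodeValue::dict()
        .insert("interval", BencodeValue::integer(1800))
        .insert("peers", BencodeValue::Bytes(Vec::new()));
    // d <8:interval i1800e> <5:peers 0:> e
    assert_eq!(response.encode(), b"d8:intervali1800e5:peers0:e");
}
```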
108  rust/downloads/src/embedded.rs  Normal file
@@ -0,0 +1,108 @@
//! Embedded torrent file access
//!
//! Provides access to .torrent files embedded in the binary at compile time.
//! Each model/revision can have multiple torrent variants (e.g., "small", "large").

use include_dir::{Dir, include_dir};

/// Embedded torrent files directory
static TORRENTS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/torrents");

/// Get all embedded torrent variants for a model_id and revision
///
/// # Arguments
/// * `model_id` - Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
/// * `revision` - Git commit hash
///
/// # Returns
/// Vec of (variant_name, torrent_data) tuples, e.g., [("small", data), ("large", data)]
/// Returns empty Vec if no torrents found for this model/revision.
#[inline]
pub fn get_embedded_torrents(model_id: &str, revision: &str) -> Vec<(String, Vec<u8>)> {
    let dir_path = format!("{model_id}");

    let Some(model_dir) = TORRENTS.get_dir(&dir_path) else {
        return Vec::new();
    };

    let mut results = Vec::new();
    let prefix = format!("{revision}.");
    let suffix = ".torrent";

    for file in model_dir.files() {
        let Some(name) = file.path().file_name().and_then(|n| n.to_str()) else {
            continue;
        };

        // Match files like "{revision}.small.torrent" or "{revision}.large.torrent"
        if name.starts_with(&prefix) && name.ends_with(suffix) {
            // Extract variant: "{revision}.{variant}.torrent" -> "{variant}"
            let middle = &name[prefix.len()..name.len() - suffix.len()];

            // Skip plain "{revision}.torrent" files (wrong format)
            if middle.is_empty() {
                continue;
            }

            results.push((middle.to_string(), file.contents().to_vec()));
        }
    }

    // Sort by variant name for consistent ordering
    results.sort_by(|a, b| a.0.cmp(&b.0));
    results
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_embedded_torrents() {
        // Test with the Qwen3 torrent we have
        let result = get_embedded_torrents(
            "mlx-community/Qwen3-30B-A3B-4bit",
            "d388dead1515f5e085ef7a0431dd8fadf0886c57",
        );

        assert!(!result.is_empty(), "Expected to find embedded torrents");

        // Should have both small and large variants
        let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
        assert!(
            variants.contains(&"small"),
            "Expected 'small' variant, got: {variants:?}"
        );
        assert!(
            variants.contains(&"large"),
            "Expected 'large' variant, got: {variants:?}"
        );

        // Verify data is not empty
        for (variant, data) in &result {
            assert!(!data.is_empty(), "Torrent data for '{variant}' should not be empty");
        }
    }

    #[test]
    fn test_missing_torrent() {
        let result = get_embedded_torrents("nonexistent/model", "abc123");
        assert!(result.is_empty(), "Expected empty Vec for missing torrent");
    }

    #[test]
    fn test_variant_ordering() {
        let result = get_embedded_torrents(
            "mlx-community/Qwen3-30B-A3B-4bit",
            "d388dead1515f5e085ef7a0431dd8fadf0886c57",
        );

        if result.len() >= 2 {
            // Verify alphabetical ordering
            let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
            let mut sorted = variants.clone();
            sorted.sort();
            assert_eq!(variants, sorted, "Variants should be sorted alphabetically");
        }
    }
}
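For orientation, the directory layout this lookup expects, inferred from the filename matching above rather than stated in the diff, plus a minimal call:

```rust
// Assumed on-disk layout under rust/downloads/torrents/ (inferred from the
// "{revision}.{variant}.torrent" matching above):
//
//   torrents/mlx-community/Qwen3-30B-A3B-4bit/
//     d388dead1515f5e085ef7a0431dd8fadf0886c57.small.torrent
//     d388dead1515f5e085ef7a0431dd8fadf0886c57.large.torrent

fn main() {
    let torrents = downloads::get_embedded_torrents(
        "mlx-community/Qwen3-30B-A3B-4bit",
        "d388dead1515f5e085ef7a0431dd8fadf0886c57",
    );
    // Variants come back sorted: [("large", ...), ("small", ...)]
    for (variant, data) in &torrents {
        println!("{variant}: {} bytes", data.len());
    }
}
```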
22  rust/downloads/src/lib.rs  Normal file
@@ -0,0 +1,22 @@
//! BitTorrent-based download system for model shards using rqbit
//!
//! This crate provides:
//! - Torrent session management via rqbit
//! - Embedded torrent file access
//! - Private tracker announce handling
//! - Selective file download based on shard layer ranges

#![allow(clippy::missing_inline_in_public_items)]

pub mod bencode;
pub mod embedded;
pub mod progress;
pub mod session;
pub mod torrent_files;
pub mod tracker;

pub use bencode::AnnounceParams;
pub use embedded::get_embedded_torrents;
pub use session::{DownloadProgress, TorrentSession};
pub use torrent_files::{get_torrent_file_list, TorrentFileInfo};
pub use tracker::{handle_announce, PeerInfo, TopologyData};
77  rust/downloads/src/progress.rs  Normal file
@@ -0,0 +1,77 @@
//! Download progress tracking
//!
//! Types for tracking and reporting download progress to Python

use std::collections::HashMap;

/// Progress update for a torrent download
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Total bytes to download
    pub total_bytes: u64,

    /// Bytes downloaded so far
    pub downloaded_bytes: u64,

    /// Number of pieces completed
    pub pieces_completed: usize,

    /// Total number of pieces
    pub total_pieces: usize,

    /// Number of peers connected
    pub peers_connected: usize,

    /// Download speed in bytes/second
    pub speed_bytes_per_sec: f64,

    /// Estimated time remaining in seconds
    pub eta_seconds: Option<f64>,

    /// Per-file progress
    pub files: HashMap<String, FileProgress>,
}

#[derive(Debug, Clone)]
pub struct FileProgress {
    /// Total file size
    pub total_bytes: u64,

    /// Bytes downloaded for this file
    pub downloaded_bytes: u64,

    /// Whether the file is complete
    pub complete: bool,
}

impl DownloadProgress {
    #[inline]
    pub fn new(total_bytes: u64, total_pieces: usize) -> Self {
        Self {
            total_bytes,
            downloaded_bytes: 0,
            pieces_completed: 0,
            total_pieces,
            peers_connected: 0,
            speed_bytes_per_sec: 0.0,
            eta_seconds: None,
            files: HashMap::new(),
        }
    }

    #[inline]
    pub fn progress_fraction(&self) -> f64 {
        if self.total_bytes == 0 {
            0.0
        } else {
            #[allow(clippy::cast_precision_loss)]
            let fraction = self.downloaded_bytes as f64 / self.total_bytes as f64;
            fraction
        }
    }

    #[inline]
    pub fn is_complete(&self) -> bool {
        self.pieces_completed >= self.total_pieces
    }
}
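Note that `lib.rs` re-exports `session::DownloadProgress` at the crate root, so this module's identically named type has to be referenced by its full path. A minimal sketch of the helpers:

```rust
use downloads::progress::DownloadProgress;

fn main() {
    let mut p = DownloadProgress::new(1_000_000, 100);
    p.downloaded_bytes = 250_000;
    p.pieces_completed = 25;

    // 250_000 / 1_000_000 = 0.25, exactly representable as f64
    assert!((p.progress_fraction() - 0.25).abs() < f64::EPSILON);
    assert!(!p.is_complete()); // only 25 of 100 pieces done
}
```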
166  rust/downloads/src/session.rs  Normal file
@@ -0,0 +1,166 @@
//! Torrent session management using rqbit
//!
//! Provides a wrapper around rqbit's Session for managing torrent downloads
//! with persistent seeding and selective file downloads.

use anyhow::{Context, Result};
use librqbit::{AddTorrent, AddTorrentOptions, AddTorrentResponse, Api, ManagedTorrent, Session, SessionOptions, SessionPersistenceConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Download progress information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadProgress {
    pub downloaded_bytes: u64,
    pub total_bytes: u64,
    pub download_speed: f64,
    pub upload_speed: f64,
    pub peers_connected: usize,
    pub is_finished: bool,
}

/// Torrent session handle for managing multiple torrents
pub struct TorrentSession {
    session: Arc<Session>,
    api: Arc<Api>,
    session_dir: PathBuf,
    torrents: Arc<RwLock<HashMap<String, Arc<ManagedTorrent>>>>,
}

impl TorrentSession {
    /// Create a new torrent session
    ///
    /// # Arguments
    /// * `session_dir` - Directory to store session state and downloaded files
    pub async fn new(session_dir: PathBuf) -> Result<Self> {
        std::fs::create_dir_all(&session_dir).context("Failed to create session directory")?;

        let opts = SessionOptions {
            disable_dht: false,
            disable_dht_persistence: false,
            dht_config: None,
            persistence: Some(SessionPersistenceConfig::Json { folder: None }),
            fastresume: true,
            ..Default::default()
        };

        let session = Session::new_with_opts(session_dir.clone(), opts)
            .await
            .context("Failed to create rqbit session")?;

        let api = Api::new(Arc::clone(&session), None);

        Ok(Self {
            session,
            api: Arc::new(api),
            session_dir,
            torrents: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Add a torrent from raw bytes
    ///
    /// # Arguments
    /// * `torrent_data` - Raw .torrent file contents
    /// * `save_path` - Where to save the downloaded files
    /// * `file_indices` - Optional list of file indices to download (None = all files)
    ///
    /// # Returns
    /// Info hash as hex string
    pub async fn add_torrent(
        &self,
        torrent_data: Vec<u8>,
        save_path: PathBuf,
        file_indices: Option<Vec<usize>>,
    ) -> Result<String> {
        let opts = AddTorrentOptions {
            overwrite: false,
            only_files_regex: None,
            only_files: file_indices,
            output_folder: Some(save_path.to_string_lossy().to_string()),
            ..Default::default()
        };

        let add_torrent = AddTorrent::from_bytes(torrent_data);

        let response = self
            .session
            .add_torrent(add_torrent, Some(opts))
            .await
            .context("Failed to add torrent")?;

        let handle = match response {
            AddTorrentResponse::Added(_, handle) => handle,
            AddTorrentResponse::AlreadyManaged(_, handle) => handle,
            AddTorrentResponse::ListOnly(_) => anyhow::bail!("Torrent was list-only, not added"),
        };

        let info_hash = handle.info_hash().as_string();

        self.torrents
            .write()
            .await
            .insert(info_hash.clone(), handle);

        Ok(info_hash)
    }

    /// Get download progress for a torrent
    pub async fn get_progress(&self, info_hash: &str) -> Result<DownloadProgress> {
        let torrents = self.torrents.read().await;
        let handle = torrents.get(info_hash).context("Torrent not found")?;

        let stats = handle.stats();

        Ok(DownloadProgress {
            downloaded_bytes: stats.progress_bytes,
            total_bytes: stats.total_bytes,
            download_speed: stats.live.as_ref().map_or(0.0, |l| l.download_speed.mbps * 1024.0 * 1024.0),
            upload_speed: stats.live.as_ref().map_or(0.0, |l| l.upload_speed.mbps * 1024.0 * 1024.0),
            peers_connected: stats.live.as_ref().map_or(0, |l| l.snapshot.peer_stats.live as usize),
            is_finished: stats.finished,
        })
    }

    /// Wait until torrent download is completed
    pub async fn wait_until_completed(&self, info_hash: &str) -> Result<()> {
        let torrents = self.torrents.read().await;
        let handle = torrents.get(info_hash).context("Torrent not found")?;

        handle
            .wait_until_completed()
            .await
            .context("Failed to wait for completion")?;

        Ok(())
    }

    /// Enable seeding for a completed torrent
    ///
    /// Note: rqbit seeds by default after completion, this is a no-op
    /// but kept for API compatibility
    pub async fn enable_seeding(&self, _info_hash: &str) -> Result<()> {
        // rqbit automatically seeds after download completion
        // This is kept for API compatibility
        Ok(())
    }

    /// Remove a torrent from the session
    pub async fn remove_torrent(&self, info_hash: &str) -> Result<()> {
        let mut torrents = self.torrents.write().await;

        if let Some(handle) = torrents.remove(info_hash) {
            drop(handle);
        }

        Ok(())
    }

    /// Get list of all torrent info hashes in the session
    pub async fn list_torrents(&self) -> Vec<String> {
        self.torrents.read().await.keys().cloned().collect()
    }
}
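A usage sketch tying the session API together (assumes `tokio` and `anyhow` as in the crate's dependencies; the torrent path and directories here are hypothetical):

```rust
use std::path::PathBuf;
use downloads::TorrentSession;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let session = TorrentSession::new(PathBuf::from("/tmp/exo-session")).await?;

    // "model.small.torrent" is a hypothetical stand-in for an embedded torrent.
    let torrent_data = std::fs::read("model.small.torrent")?;
    let info_hash = session
        .add_torrent(torrent_data, PathBuf::from("/tmp/exo-models"), None)
        .await?;

    let p = session.get_progress(&info_hash).await?;
    println!("{}/{} bytes from {} peers", p.downloaded_bytes, p.total_bytes, p.peers_connected);

    // Blocks until all selected pieces are verified; seeding continues afterwards.
    session.wait_until_completed(&info_hash).await?;
    Ok(())
}
```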
100  rust/downloads/src/torrent_files.rs  Normal file
@@ -0,0 +1,100 @@
//! Torrent file list parsing
//!
//! Provides functionality to extract file information from torrent metadata
//! without adding the torrent to a session.

use anyhow::{Context, Result};
use librqbit::torrent_from_bytes;
use serde::{Deserialize, Serialize};

/// Information about a file in a torrent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorrentFileInfo {
    /// File index (0-based)
    pub index: usize,
    /// File path relative to torrent root
    pub path: String,
    /// File size in bytes
    pub size: u64,
}

/// Get the list of files in a torrent from its raw bytes
///
/// # Arguments
/// * `torrent_data` - Raw .torrent file contents
///
/// # Returns
/// List of file information (index, path, size)
pub fn get_torrent_file_list(torrent_data: &[u8]) -> Result<Vec<TorrentFileInfo>> {
    let torrent_meta = torrent_from_bytes(torrent_data).context("Failed to parse torrent")?;

    // Access the data inside WithRawBytes wrapper
    let info = &torrent_meta.info.data;

    let mut files = Vec::new();

    // Handle both single-file and multi-file torrents
    if let Some(ref file_list) = info.files {
        // Multi-file torrent
        for (index, file) in file_list.iter().enumerate() {
            let path = file
                .path
                .iter()
                .map(|buf| String::from_utf8_lossy(buf.0).to_string())
                .collect::<Vec<_>>()
                .join("/");

            files.push(TorrentFileInfo {
                index,
                path,
                size: file.length,
            });
        }
    } else {
        // Single-file torrent
        let name = match &info.name {
            Some(n) => String::from_utf8_lossy(n.0).to_string(),
            None => String::new(),
        };
        files.push(TorrentFileInfo {
            index: 0,
            path: name,
            size: info.length.unwrap_or(0),
        });
    }

    Ok(files)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::get_embedded_torrents;

    #[test]
    fn test_get_torrent_file_list() {
        // Use an embedded torrent for testing
        let torrents = get_embedded_torrents(
            "mlx-community/Qwen3-30B-A3B-4bit",
            "d388dead1515f5e085ef7a0431dd8fadf0886c57",
        );

        assert!(!torrents.is_empty(), "Expected to find embedded torrents");

        for (variant, data) in torrents {
            let files = get_torrent_file_list(&data).expect("Failed to parse torrent");
            assert!(!files.is_empty(), "Expected files in {variant} variant");

            // Verify file info makes sense
            for file in &files {
                assert!(!file.path.is_empty(), "File path should not be empty");
                assert!(file.size > 0, "File size should be positive");
            }

            println!("Variant '{variant}' has {} files", files.len());
            for file in files.iter().take(5) {
                println!("  [{}] {} ({} bytes)", file.index, file.path, file.size);
            }
        }
    }
}
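This file list is what enables selective downloads: filtered indices can be handed to `TorrentSession::add_torrent` as `file_indices`. A sketch; the `.safetensors` filter is illustrative, not logic from this diff:

```rust
// Pick out only the weight files from a torrent, by index.
fn safetensors_indices(torrent_data: &[u8]) -> anyhow::Result<Vec<usize>> {
    Ok(downloads::get_torrent_file_list(torrent_data)?
        .into_iter()
        .filter(|f| f.path.ends_with(".safetensors"))
        .map(|f| f.index)
        .collect())
}
```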
185  rust/downloads/src/tracker.rs  Normal file
@@ -0,0 +1,185 @@
//! Fake tracker implementation for Exo topology-based peer discovery
//!
//! Instead of contacting real BitTorrent trackers, this module generates
//! tracker announce responses using Exo's cluster topology data.

use std::net::Ipv4Addr;

use anyhow::Result;

use crate::bencode::{AnnounceParams, BencodeValue};

/// Information about a peer in the Exo topology
#[derive(Debug, Clone)]
pub struct PeerInfo {
    /// Unique node identifier in the Exo cluster
    pub node_id: String,
    /// IPv4 address of the peer
    pub ip: Ipv4Addr,
    /// BitTorrent listening port
    pub port: u16,
    /// Whether this peer has the complete torrent
    pub has_complete: bool,
    /// Priority for peer selection (higher = prefer)
    pub priority: i32,
}

/// Topology data containing available peers
#[derive(Debug, Clone)]
pub struct TopologyData {
    /// List of peers in the topology
    pub peers: Vec<PeerInfo>,
}

/// Default announce interval in seconds
const DEFAULT_INTERVAL: i64 = 1800;

/// Handle a tracker announce request using Exo topology data
///
/// Returns a bencoded tracker response containing peers from the topology.
///
/// # Arguments
/// * `params` - Announce request parameters
/// * `topology` - Current Exo cluster topology
///
/// # Returns
/// Bencoded announce response as bytes
pub fn handle_announce(params: &AnnounceParams, topology: &TopologyData) -> Result<Vec<u8>> {
    // Sort peers by priority (descending) for better peer selection
    let mut peers: Vec<_> = topology.peers.iter().collect();
    peers.sort_by(|a, b| b.priority.cmp(&a.priority));

    let response = if params.compact {
        // Compact format: 6 bytes per peer (4 IP + 2 port)
        let mut peer_data = Vec::with_capacity(peers.len() * 6);
        for peer in &peers {
            peer_data.extend_from_slice(&peer.ip.octets());
            peer_data.extend_from_slice(&peer.port.to_be_bytes());
        }

        BencodeValue::dict()
            .insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
            .insert("peers", BencodeValue::Bytes(peer_data))
    } else {
        // Non-compact format: list of dicts
        let mut peer_list = BencodeValue::list();
        for peer in &peers {
            let peer_dict = BencodeValue::dict()
                .insert("ip", BencodeValue::string(&peer.ip.to_string()))
                .insert("port", BencodeValue::integer(i64::from(peer.port)))
                .insert("peer id", BencodeValue::Bytes(vec![0u8; 20])); // Placeholder peer ID
            peer_list = peer_list.push(peer_dict);
        }

        BencodeValue::dict()
            .insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
            .insert("peers", peer_list)
    };

    Ok(response.encode())
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_params(compact: bool) -> AnnounceParams {
        AnnounceParams {
            info_hash: [0u8; 20],
            peer_id: [0u8; 20],
            port: 6881,
            uploaded: 0,
            downloaded: 0,
            left: 1000,
            compact,
            event: None,
        }
    }

    fn make_test_topology() -> TopologyData {
        TopologyData {
            peers: vec![
                PeerInfo {
                    node_id: "node1".to_string(),
                    ip: Ipv4Addr::new(192, 168, 1, 1),
                    port: 6881,
                    has_complete: true,
                    priority: 10,
                },
                PeerInfo {
                    node_id: "node2".to_string(),
                    ip: Ipv4Addr::new(192, 168, 1, 2),
                    port: 6882,
                    has_complete: false,
                    priority: 5,
                },
            ],
        }
    }

    #[test]
    fn test_compact_response() {
        let params = make_test_params(true);
        let topology = make_test_topology();

        let response = handle_announce(&params, &topology).unwrap();

        // Should contain "interval" and "peers" keys
        assert!(response.starts_with(b"d"));
        assert!(response.ends_with(b"e"));

        // Verify we have 12 bytes of peer data (2 peers * 6 bytes)
        // The compact peers field should be "12:<12 bytes>"
        let response_str = String::from_utf8_lossy(&response);
        assert!(response_str.contains("8:interval"));
        assert!(response_str.contains("5:peers"));
    }

    #[test]
    fn test_non_compact_response() {
        let params = make_test_params(false);
        let topology = make_test_topology();

        let response = handle_announce(&params, &topology).unwrap();

        // Should contain peers as a list
        let response_str = String::from_utf8_lossy(&response);
        assert!(response_str.contains("8:interval"));
        assert!(response_str.contains("5:peers"));
        assert!(response_str.contains("2:ip"));
        assert!(response_str.contains("4:port"));
    }

    #[test]
    fn test_peer_priority_ordering() {
        let params = make_test_params(true);
        let topology = make_test_topology();

        let response = handle_announce(&params, &topology).unwrap();

        // In compact format, first peer should be node1 (priority 10)
        // which is 192.168.1.1:6881
        // Look for the peer data after "5:peers12:"
        let peers_marker = b"5:peers12:";
        let pos = response
            .windows(peers_marker.len())
            .position(|w| w == peers_marker)
            .unwrap();
        let peer_data = &response[pos + peers_marker.len()..pos + peers_marker.len() + 6];

        // First peer should be 192.168.1.1 (node1 with higher priority)
        assert_eq!(&peer_data[0..4], &[192, 168, 1, 1]);
    }

    #[test]
    fn test_empty_topology() {
        let params = make_test_params(true);
        let topology = TopologyData { peers: vec![] };

        let response = handle_announce(&params, &topology).unwrap();

        // Should still be valid bencoded response with empty peers
        assert!(response.starts_with(b"d"));
        assert!(response.ends_with(b"e"));
    }
}
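A worked example of the compact encoding used above: each peer is exactly six bytes, the four IPv4 octets followed by the port in big-endian order (6881 = 0x1AE1):

```rust
use std::net::Ipv4Addr;

fn main() {
    let ip = Ipv4Addr::new(192, 168, 1, 1);
    let port: u16 = 6881;

    let mut compact = Vec::with_capacity(6);
    compact.extend_from_slice(&ip.octets());        // [192, 168, 1, 1]
    compact.extend_from_slice(&port.to_be_bytes()); // [0x1A, 0xE1]

    assert_eq!(compact, [192, 168, 1, 1, 0x1A, 0xE1]);
}
```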
Binary file not shown. (×2)
File diff suppressed because one or more lines are too long
Binary file not shown. (×5)
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi884e4:pathl9:README.mdeed6:lengthi1249e4:pathl19:chat_template.jinjaeed6:lengthi1848e4:pathl11:config.jsoneed6:lengthi10652e4:pathl25:configuration_deepseek.pyeed6:lengthi52e4:pathl22:generation_config.jsoneed6:lengthi221164e4:pathl28:model.safetensors.index.jsoneed6:lengthi75769e4:pathl20:modeling_deepseek.pyeed6:lengthi760e4:pathl23:special_tokens_map.jsoneed6:lengthi11330e4:pathl20:tokenization_kimi.pyeed6:lengthi2738e4:pathl21:tokenizer_config.jsoneee4:name40:91fb4f9fd1de100104925196d62b8ee06fd2ad6012:piece lengthi262144e6:pieces40:[40 bytes of binary piece hashes]e8:url-list63:https://huggingface.co/mlx-community/Kimi-K2-Instruct-4bit/raw/e
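(Reading these dumps: bencode is length-prefixed, so `8:announce` is the 8-byte key `announce`, `i262144e` the integer 262144, and `l…e`/`d…e` delimit lists and dicts. The `pieces` value is raw binary, 20 bytes of SHA-1 per piece, shown here as a bracketed placeholder.)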
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi864e4:pathl9:README.mdeed6:lengthi3442e4:pathl19:chat_template.jinjaeed6:lengthi3445e4:pathl11:config.jsoneed6:lengthi10652e4:pathl25:configuration_deepseek.pyeed6:lengthi53e4:pathl22:generation_config.jsoneed6:lengthi129766e4:pathl28:model.safetensors.index.jsoneed6:lengthi75769e4:pathl20:modeling_deepseek.pyeed6:lengthi760e4:pathl23:special_tokens_map.jsoneed6:lengthi12597e4:pathl20:tokenization_kimi.pyeed6:lengthi4047e4:pathl21:tokenizer_config.jsoneee4:name40:035a0cdd221ae0dca6b03120e20704a251a7bc9b12:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list58:https://huggingface.co/mlx-community/Kimi-K2-Thinking/raw/e
Binary file not shown. (×3)
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16485e4:pathl9:README.mdeed6:lengthi1123e4:pathl11:config.jsoneed6:lengthi158327e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55425e4:pathl21:tokenizer_config.jsoneee4:name40:de2dfaf56839b7d0e834157d2401dee02726874d12:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list69:https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-4bit/raw/e
Binary file not shown.
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16485e4:pathl9:README.mdeed6:lengthi1123e4:pathl11:config.jsoneed6:lengthi158327e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55425e4:pathl21:tokenizer_config.jsoneee4:name40:c5bfd839cd4cda0e5a39a97e00218d9c56e468af12:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list69:https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-8bit/raw/e
Binary file not shown. (×15)
@@ -0,0 +1,2 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi1033e4:pathl9:README.mdeed6:lengthi707e4:pathl17:added_tokens.jsoneed6:lengthi6722e4:pathl19:chat_template.jinjaeed6:lengthi1222e4:pathl11:config.jsoneed6:lengthi180e4:pathl22:generation_config.jsoneed6:lengthi1671853e4:pathl10:merges.txteed6:lengthi154390e4:pathl28:model.safetensors.index.jsoneed6:lengthi28881e4:pathl24:qwen3_xml_tool_parser.pyeed6:lengthi613e4:pathl23:special_tokens_map.jsoneed6:lengthi5405e4:pathl21:tokenizer_config.jsoneed6:lengthi2776833e4:pathl10:vocab.jsoneee4:name40:ca8dbf41071f579fbe3260f20bbe1ab896f7903112:piece lengthi262144e6:pieces360:[360 bytes of binary piece hashes; remainder of the entry garbled]
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi75789955e4:pathl17:model.safetensorseee4:name40:f56bc6adfb74c794203dc8ca94e0bccfe2bcd6cc12:piece lengthi16777216e6:pieces100:[100 bytes of binary piece hashes]e8:url-list62:https://huggingface.co/mlx-community/SmolLM-135M-4bit/resolve/e
Binary file not shown. (×2)
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi845e4:pathl9:README.mdeed6:lengthi16738e4:pathl19:chat_template.jinjaeed6:lengthi50145e4:pathl11:config.jsoneed6:lengthi177e4:pathl22:generation_config.jsoneed6:lengthi100431e4:pathl28:model.safetensors.index.jsoneed6:lengthi440e4:pathl23:special_tokens_map.jsoneed6:lengthi4200e4:pathl21:tokenizer_config.jsoneee4:name40:81e5ac3ad0af6efb1298a8e8c7a10ed2990c137b12:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list63:https://huggingface.co/mlx-community/gpt-oss-120b-MXFP4-Q8/raw/e
Binary file not shown.
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi838e4:pathl9:README.mdeed6:lengthi33998e4:pathl11:config.jsoneed6:lengthi177e4:pathl22:generation_config.jsoneed6:lengthi67046e4:pathl28:model.safetensors.index.jsoneed6:lengthi440e4:pathl23:special_tokens_map.jsoneed6:lengthi21694e4:pathl21:tokenizer_config.jsoneee4:name40:f356f2747216d7e98fee755df25987459fc1908912:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list62:https://huggingface.co/mlx-community/gpt-oss-20b-MXFP4-Q4/raw/e
Binary file not shown.
@@ -0,0 +1,6 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi956e4:pathl9:README.mdeed6:lengthi207e4:pathl17:added_tokens.jsoneed6:lengthi848e4:pathl11:config.jsoneed6:lengthi441810e4:pathl10:merges.txteed6:lengthi25864e4:pathl28:model.safetensors.index.jsoneed6:lengthi801e4:pathl23:special_tokens_map.jsoneed6:lengthi3476578e4:pathl14:tokenizer.jsoneed6:lengthi9935e4:pathl21:tokenizer_config.jsoneed6:lengthi776995e4:pathl10:vocab.jsoneee4:name40:39b35eaa97282c34db81f61a983b4b83344e10f112:piece lengthi262144e6:pieces380:[380 bytes of binary piece hashes; remainder of the entry garbled]
Binary file not shown. (×3)
@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16447e4:pathl9:README.mdeed6:lengthi970e4:pathl11:config.jsoneed6:lengthi62518e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55421e4:pathl21:tokenizer_config.jsoneee4:name40:8103891b028a8933068e47751bc2acc10bb59aa212:piece lengthi262144e6:pieces20:[20 bytes of binary piece hashes]e8:url-list69:https://huggingface.co/mlx-community/llama-3.3-70b-instruct-fp16/raw/e
@@ -23,6 +23,7 @@ workspace = true

 [dependencies]
 networking = { workspace = true }
+downloads = { workspace = true }

 # interop
 pyo3 = { version = "0.27.1", features = [
334  rust/exo_pyo3_bindings/src/downloads.rs  Normal file
@@ -0,0 +1,334 @@
//! Downloads module - BitTorrent downloads PyO3 bindings

use crate::ext::*;
use downloads::bencode::AnnounceParams;
use downloads::tracker::{PeerInfo, TopologyData, handle_announce as rust_handle_announce};
use downloads::{DownloadProgress, TorrentSession};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use std::net::Ipv4Addr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Handle a tracker announce request
///
/// Args:
///     params: Dictionary with announce parameters (info_hash, peer_id, port, etc.)
///     peers: List of peer dictionaries (node_id, ip, port, has_complete, priority)
///
/// Returns:
///     Bencoded announce response as bytes
#[pyfunction]
fn handle_tracker_announce(
    py: Python<'_>,
    params: &Bound<'_, PyDict>,
    peers: &Bound<'_, pyo3::types::PyList>,
) -> PyResult<Py<PyBytes>> {
    // Parse announce params
    let info_hash = {
        let info_hash_item = params
            .get_item("info_hash")?
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing info_hash"))?;
        let info_hash_bytes: &[u8] = info_hash_item.extract()?;

        if info_hash_bytes.len() != 20 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "info_hash must be 20 bytes",
            ));
        }

        let mut info_hash = [0u8; 20];
        info_hash.copy_from_slice(info_hash_bytes);
        info_hash
    };

    let peer_id = {
        let peer_id_item = params
            .get_item("peer_id")?
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing peer_id"))?;
        let peer_id_bytes: &[u8] = peer_id_item.extract()?;

        if peer_id_bytes.len() != 20 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "peer_id must be 20 bytes",
            ));
        }

        let mut peer_id = [0u8; 20];
        peer_id.copy_from_slice(peer_id_bytes);
        peer_id
    };

    let port: u16 = params
        .get_item("port")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
        .extract()?;

    let uploaded: u64 = params
        .get_item("uploaded")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing uploaded"))?
        .extract()?;

    let downloaded: u64 = params
        .get_item("downloaded")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing downloaded"))?
        .extract()?;

    let left: u64 = params
        .get_item("left")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing left"))?
        .extract()?;

    let compact: bool = params
        .get_item("compact")?
        .map(|v| v.extract().unwrap_or(true))
        .unwrap_or(true);

    let announce_params = AnnounceParams {
        info_hash,
        peer_id,
        port,
        uploaded,
        downloaded,
        left,
        compact,
        event: None, // TODO: parse event if needed
    };

    // Parse peer list
    let peer_infos: Result<Vec<PeerInfo>, PyErr> = peers
        .iter()
        .map(|peer_item| {
            let peer_dict: &Bound<'_, PyDict> = peer_item.downcast()?;
            let node_id: String = peer_dict
                .get_item("node_id")?
                .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing node_id"))?
                .extract()?;

            let ip_str: String = peer_dict
                .get_item("ip")?
                .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing ip"))?
                .extract()?;

            let ip: Ipv4Addr = ip_str
                .parse()
                .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid IP address"))?;

            let port: u16 = peer_dict
                .get_item("port")?
                .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
                .extract()?;

            let has_complete: bool = peer_dict
                .get_item("has_complete")?
                .map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(false))
                .unwrap_or(false);

            let priority: i32 = peer_dict
                .get_item("priority")?
                .map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(0))
                .unwrap_or(0);

            Ok(PeerInfo {
                node_id,
                ip,
                port,
                has_complete,
                priority,
            })
        })
        .collect();

    let peer_infos = peer_infos?;

    let topology = TopologyData { peers: peer_infos };

    // Call Rust tracker handler
    let response_bytes = rust_handle_announce(&announce_params, &topology).pyerr()?;

    // Return as Python bytes
    Ok(PyBytes::new(py, &response_bytes).unbind())
}

/// Get all embedded torrent variants for a model
///
/// Args:
///     model_id: Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
///     revision: Git commit hash
///
/// Returns:
///     List of (variant_name, torrent_data) tuples, e.g., [("small", bytes), ("large", bytes)]
///     Returns empty list if no torrents found.
#[pyfunction]
fn get_embedded_torrents(
    py: Python<'_>,
    model_id: String,
    revision: String,
) -> PyResult<Vec<(String, Py<PyBytes>)>> {
    let torrents = downloads::get_embedded_torrents(&model_id, &revision);
    Ok(torrents
        .into_iter()
        .map(|(variant, data)| (variant, PyBytes::new(py, &data).unbind()))
        .collect())
}

/// Get file list from torrent data
///
/// Args:
///     torrent_data: Raw .torrent file contents
///
/// Returns:
///     List of (index, path, size_bytes) tuples for each file in the torrent
#[pyfunction]
fn get_torrent_file_list(torrent_data: Vec<u8>) -> PyResult<Vec<(usize, String, u64)>> {
    let files = downloads::get_torrent_file_list(&torrent_data).pyerr()?;
    Ok(files
        .into_iter()
        .map(|f| (f.index, f.path, f.size))
        .collect())
}

/// Python wrapper for TorrentSession
#[pyclass]
struct TorrentSessionHandle {
    session: Arc<Mutex<TorrentSession>>,
}

#[pymethods]
impl TorrentSessionHandle {
    /// Create a new torrent session
    ///
    /// Args:
    ///     session_dir: Directory to store session state and downloads
    #[new]
    fn new(session_dir: String) -> PyResult<Self> {
        let session_path = PathBuf::from(session_dir);

        let session = tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { TorrentSession::new(session_path).await })
            .pyerr()?;

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
        })
    }

    /// Add a torrent from bytes
    ///
    /// Args:
    ///     torrent_data: Raw .torrent file contents
    ///     save_path: Where to save downloaded files
    ///     file_indices: Optional list of file indices to download
    ///
    /// Returns:
    ///     Info hash as hex string
    fn add_torrent(
        &self,
        _py: Python<'_>,
        torrent_data: Vec<u8>,
        save_path: String,
        file_indices: Option<Vec<usize>>,
    ) -> PyResult<String> {
        let session = Arc::clone(&self.session);
        let save_path = PathBuf::from(save_path);

        tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async {
                session
                    .lock()
                    .await
                    .add_torrent(torrent_data, save_path, file_indices)
                    .await
            })
            .pyerr()
    }

    /// Get download progress for a torrent
    ///
    /// Args:
    ///     info_hash: Torrent info hash
    ///
    /// Returns:
    ///     Dictionary with progress information
    fn get_progress(&self, py: Python<'_>, info_hash: String) -> PyResult<Py<PyDict>> {
        let session = Arc::clone(&self.session);

        let progress: DownloadProgress = tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { session.lock().await.get_progress(&info_hash).await })
            .pyerr()?;

        let dict = PyDict::new(py);
        dict.set_item("downloaded_bytes", progress.downloaded_bytes)?;
        dict.set_item("total_bytes", progress.total_bytes)?;
        dict.set_item("download_speed", progress.download_speed)?;
        dict.set_item("upload_speed", progress.upload_speed)?;
        dict.set_item("peers_connected", progress.peers_connected)?;
        dict.set_item("is_finished", progress.is_finished)?;

        Ok(dict.unbind())
    }

    /// Wait until torrent download is completed
    ///
    /// Args:
    ///     info_hash: Torrent info hash
    fn wait_until_completed(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
        let session = Arc::clone(&self.session);

        tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { session.lock().await.wait_until_completed(&info_hash).await })
            .pyerr()
    }

    /// Enable seeding for a torrent
    ///
    /// Args:
    ///     info_hash: Torrent info hash
    fn enable_seeding(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
        let session = Arc::clone(&self.session);

        tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { session.lock().await.enable_seeding(&info_hash).await })
            .pyerr()
    }

    /// Remove a torrent from the session
    ///
    /// Args:
    ///     info_hash: Torrent info hash
    fn remove_torrent(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
        let session = Arc::clone(&self.session);

        tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { session.lock().await.remove_torrent(&info_hash).await })
            .pyerr()
    }

    /// List all torrents in the session
    ///
    /// Returns:
    ///     List of info hashes
    fn list_torrents(&self, _py: Python<'_>) -> PyResult<Vec<String>> {
        let session = Arc::clone(&self.session);

        tokio::runtime::Runtime::new()
            .pyerr()?
            .block_on(async { Ok(session.lock().await.list_torrents().await) })
    }
}

/// Downloads submodule
pub(crate) fn downloads_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(handle_tracker_announce, m)?)?;
    m.add_function(wrap_pyfunction!(get_embedded_torrents, m)?)?;
    m.add_function(wrap_pyfunction!(get_torrent_file_list, m)?)?;
    m.add_class::<TorrentSessionHandle>()?;
    Ok(())
}
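Each method above constructs a fresh `tokio::runtime::Runtime` and blocks on the async call while holding the session mutex. A condensed sketch of that bridging shape (illustrative only; `Session` here is a hypothetical stand-in, not the real type):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

// Hypothetical stand-in for TorrentSession, to show the bridging shape only.
struct Session;
impl Session {
    async fn list(&self) -> Vec<String> { Vec::new() }
}

// One runtime per call, blocking on the async body, as the bindings do.
fn blocking_list(session: &Arc<Mutex<Session>>) -> anyhow::Result<Vec<String>> {
    let session = Arc::clone(session);
    let rt = tokio::runtime::Runtime::new()?;
    Ok(rt.block_on(async { session.lock().await.list().await }))
}
```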
@@ -17,10 +17,12 @@

 extern crate core;
 mod allow_threading;
+pub(crate) mod downloads;
 mod examples;
 pub(crate) mod networking;
 pub(crate) mod pylibp2p;

+use crate::downloads::downloads_submodule;
 use crate::networking::networking_submodule;
 use crate::pylibp2p::ident::ident_submodule;
 use crate::pylibp2p::multiaddr::multiaddr_submodule;
@@ -207,6 +209,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     ident_submodule(m)?;
     multiaddr_submodule(m)?;
     networking_submodule(m)?;
+    downloads_submodule(m)?;

     // top-level constructs
     // TODO: ...
58  scripts/mktorrent.sh  Executable file
@@ -0,0 +1,58 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p mktorrent -p python3Packages.huggingface-hub -p git -p git-lfs
set -euo pipefail
set -x

MODEL="$1"

# Step 1: Clone/fetch the repo and get the hash of HEAD
mkdir -p "$MODEL"
if test -d "$MODEL/git"; then
    # Assert that the origin is correct
    git -C "$MODEL/git" fetch
else
    git clone "https://huggingface.co/$MODEL" "$MODEL/git"
fi

HASH=$(git -C "$MODEL/git" rev-parse origin/main)
LARGE_FILES=$(git -C "$MODEL/git" lfs ls-files --all --name-only)

SMALL_DIR="$MODEL/$HASH-small"
LARGE_DIR="$MODEL/$HASH-large"
mkdir -p "$SMALL_DIR" "$LARGE_DIR"

# Step 2: Prepare files. Two torrents: one for large files and one for metadata.
git -C "$MODEL/git" archive "$HASH" | tar -x -C "$SMALL_DIR"
echo "$LARGE_FILES" | xargs -I{} rm "$SMALL_DIR/{}"

echo "$LARGE_FILES" | xargs hf download "$MODEL" --revision "$HASH" --local-dir "$LARGE_DIR" --cache-dir "$(realpath .cache)" --include
if test -d "$LARGE_DIR/.cache"; then
    echo ".cache created against our wishes, deleting it..."
    rm -r "$LARGE_DIR/.cache"
fi

# Step 3: Create both torrents
mkdir -p "torrents/$MODEL/"
SMALL_TORRENT_PATH="torrents/$MODEL/${HASH}.small.torrent"
LARGE_TORRENT_PATH="torrents/$MODEL/${HASH}.large.torrent"

mktorrent "$SMALL_DIR/" --output="$SMALL_TORRENT_PATH" \
    -n "$HASH" \
    --web-seed="https://huggingface.co/$MODEL/raw/" \
    --no-date \
    --announce="udp://tracker.opentrackr.org:1337/announce"
    # --private

mktorrent "$LARGE_DIR/" --output="$LARGE_TORRENT_PATH" \
    -n "$HASH" \
    --web-seed="https://huggingface.co/$MODEL/resolve/" \
    --piece-length=24 \
    --no-date \
    --announce="udp://tracker.opentrackr.org:1337/announce"
    # --private

echo "Successfully created torrent files in:"
echo "$SMALL_TORRENT_PATH"
echo "$LARGE_TORRENT_PATH"
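Usage, inferred from the script's `$1` argument: `scripts/mktorrent.sh mlx-community/Qwen3-30B-A3B-4bit`. It leaves `<revision>.small.torrent` and `<revision>.large.torrent` under `torrents/<model>/`, matching the `{revision}.{variant}.torrent` filename scheme that `get_embedded_torrents` parses.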
@@ -35,6 +35,7 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
enable_torrents: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -66,7 +67,8 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
exo_shard_downloader(enable_torrents=args.enable_torrents),
|
||||
initial_connection_messages=[],
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
@@ -74,7 +76,6 @@ class Node:
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
@@ -98,7 +99,7 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
return cls(router, worker, election, er_recv, master, api, node_id, args.enable_torrents)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
@@ -175,7 +176,7 @@ class Node:
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
exo_shard_downloader(enable_torrents=self.enable_torrents),
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
@@ -215,6 +216,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
enable_torrents: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -256,6 +258,12 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-torrents",
|
||||
action="store_true",
|
||||
dest="enable_torrents",
|
||||
help="Enable BitTorrent-based downloads (experimental)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
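The cls(**vars(args)) hand-off is the validation step: argparse produces an untyped Namespace, and the pydantic model re-checks every field. A minimal standalone sketch of the same pattern (the Flags model here is illustrative, not the real Args class):

    import argparse

    from pydantic import BaseModel

    class Flags(BaseModel):
        enable_torrents: bool = False

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-torrents", action="store_true", dest="enable_torrents")
    # Namespace -> dict -> validated model, mirroring Args.parse()
    flags = Flags(**vars(parser.parse_args(["--enable-torrents"])))
    assert flags.enable_torrents is True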
@@ -5,9 +5,9 @@ from typing import cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
@@ -178,6 +178,7 @@ class API:
        self.app.post("/bench/chat/completions")(self.bench_chat_completions)
        self.app.get("/state")(lambda: self.state)
        self.app.get("/events")(lambda: self._event_log)
        self.app.get("/_internal/announce")(self.tracker_announce)

    async def place_instance(self, payload: PlaceInstanceParams):
        command = PlaceInstance(
@@ -622,6 +623,65 @@ class API:
            ]
        )

    async def tracker_announce(self, request: Request) -> Response:
        """BitTorrent tracker announce endpoint for the private tracker."""
        try:
            from exo_pyo3_bindings import handle_tracker_announce  # type: ignore
        except ImportError as e:
            raise HTTPException(
                status_code=501,
                detail="Torrent support not available (exo_pyo3_bindings not installed)",
            ) from e

        # Parse announce parameters from the query string
        query_params = dict(request.query_params)

        # Extract required parameters
        try:
            info_hash_hex = query_params.get("info_hash", "")
            peer_id_hex = query_params.get("peer_id", "")

            # Decode the hex-encoded values to raw bytes
            info_hash = bytes.fromhex(info_hash_hex) if info_hash_hex else b""
            peer_id = bytes.fromhex(peer_id_hex) if peer_id_hex else b""

            if len(info_hash) != 20 or len(peer_id) != 20:
                raise ValueError("info_hash and peer_id must be 20 bytes")

            params = {
                "info_hash": info_hash,
                "peer_id": peer_id,
                "port": int(query_params.get("port", "6881")),
                "uploaded": int(query_params.get("uploaded", "0")),
                "downloaded": int(query_params.get("downloaded", "0")),
                "left": int(query_params.get("left", "0")),
                "compact": query_params.get("compact", "1") == "1",
            }
        except (ValueError, KeyError) as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid announce parameters: {e}",
            ) from e

        # Build peer list from topology
        # TODO: Implement _build_peer_list_from_topology() to extract peers from self.state.topology
        peers = []  # For now, return an empty peer list

        # Call the Rust tracker handler
        try:
            response_bytes: bytes = handle_tracker_announce(params, peers)  # type: ignore
            return Response(
                content=response_bytes,
                media_type="text/plain",
            )
        except Exception as e:
            logger.error(f"Tracker announce error: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Tracker announce failed: {e}",
            ) from e
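Note that this tracker expects hex-encoded info_hash and peer_id values rather than the raw URL-encoded bytes of a standard BEP 3 announce, so a client has to encode accordingly. A minimal sketch of a conforming announce, assuming a node's API is listening on the default api_port 52415 (the hashes below are placeholders):

    import os
    import urllib.request

    info_hash = os.urandom(20).hex()  # placeholder; normally the SHA-1 of the torrent's info dict
    peer_id = os.urandom(20).hex()    # placeholder peer id
    url = (
        "http://localhost:52415/_internal/announce"
        f"?info_hash={info_hash}&peer_id={peer_id}"
        "&port=6881&uploaded=0&downloaded=0&left=0&compact=1"
    )
    with urllib.request.urlopen(url) as resp:
        print(resp.read())  # bencoded tracker response produced by the Rust handler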

    async def run(self):
        cfg = Config()
        cfg.bind = f"0.0.0.0:{self.port}"

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
    event = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
        total_bytes=Memory(),
    )

    new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
    event1 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
        total_bytes=Memory(),
    )
    event2 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard2,
        total_bytes=Memory(),
    )
    state = State(downloads={NodeId("node-1"): [event1]})

@@ -16,3 +16,4 @@ class ModelMetadata(CamelCaseModel):
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    revision: str | None = None  # Git commit hash for torrent lookup

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):


class DownloadCompleted(BaseDownloadProgress):
    pass
    total_bytes: Memory


class DownloadFailed(BaseDownloadProgress):
@@ -12,10 +12,17 @@ from exo.worker.download.download_utils import RepoDownloadProgress, download_sh
from exo.worker.download.shard_downloader import ShardDownloader


def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
    return SingletonShardDownloader(
        CachedShardDownloader(ResumableShardDownloader(max_parallel_downloads))
    )
def exo_shard_downloader(
    max_parallel_downloads: int = 8, enable_torrents: bool = False
) -> ShardDownloader:
    if enable_torrents:
        from exo.worker.download.torrent_downloader import TorrentShardDownloader

        base = TorrentShardDownloader(max_parallel_downloads)
    else:
        base = ResumableShardDownloader(max_parallel_downloads)

    return SingletonShardDownloader(CachedShardDownloader(base))
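Both branches end up wrapped in SingletonShardDownloader(CachedShardDownloader(...)), so callers only flip the flag and keep the same dedup/caching behaviour. A usage sketch:

    # HTTP-only (default) vs. torrent-backed downloader; the wrapper chain is identical.
    http_downloader = exo_shard_downloader()
    torrent_downloader = exo_shard_downloader(enable_torrents=True)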


async def build_base_shard(model_id: str) -> ShardMetadata:


src/exo/worker/download/torrent_downloader.py (new file, 366 lines)
@@ -0,0 +1,366 @@
"""Torrent-based shard downloader using the BitTorrent protocol with a private tracker.

This module implements downloading model shards via BitTorrent with:
- Private tracker integration (/_internal/announce endpoint)
- WebSeed (BEP 19) fallback to HTTPS
- Selective file download based on shard layer ranges
- Persistent seeding of downloaded content
"""

import asyncio
from collections.abc import AsyncIterator
from datetime import timedelta
from fnmatch import fnmatch
from pathlib import Path
from typing import Callable

from loguru import logger

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.download_utils import RepoDownloadProgress, resolve_allow_patterns
from exo.worker.download.shard_downloader import ShardDownloader


class TorrentShardDownloader(ShardDownloader):
    """Download model shards using BitTorrent with a private tracker."""

    def __init__(self, max_parallel_downloads: int = 8):
        self.max_parallel_downloads = max_parallel_downloads
        self._progress_callbacks: list[
            Callable[[ShardMetadata, RepoDownloadProgress], None]
        ] = []

        # Initialize TorrentSessionHandle
        try:
            from exo_pyo3_bindings import TorrentSessionHandle  # type: ignore

            session_dir = EXO_MODELS_DIR / "v2" / ".torrent_session"
            session_dir.mkdir(parents=True, exist_ok=True)
            self.session = TorrentSessionHandle(str(session_dir))
        except ImportError:
            logger.error("exo_pyo3_bindings not available, torrent downloads disabled")
            self.session = None

    def on_progress(
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        """Register a progress callback."""
        self._progress_callbacks.append(callback)

    async def ensure_shard(
        self, shard: ShardMetadata, config_only: bool = False
    ) -> Path:
        """Download a model shard using BitTorrent.

        Downloads all torrent variants for the model and applies file filtering
        based on the shard's layer range.

        Args:
            shard: Shard metadata including model ID and layer range
            config_only: If True, only download config files (not implemented for torrents yet)

        Returns:
            Path to the downloaded shard directory

        Raises:
            RuntimeError: If the torrent file is not found or the session is not initialized
        """
        if self.session is None:
            raise RuntimeError(
                "TorrentSessionHandle not initialized. "
                "exo_pyo3_bindings module not available."
            )

        model_id = str(shard.model_meta.model_id)

        # Resolve the "main" branch to a commit hash
        revision = await self._resolve_revision(model_id, "main")

        # Load all embedded torrent variants
        variants = self._load_embedded_torrents(model_id, revision)
        if not variants:
            raise RuntimeError(
                f"No torrents found for {model_id}@{revision}. "
                f"Expected at: rust/downloads/torrents/{model_id}/{revision}.*.torrent. "
                f"Please add torrent files or use HTTP downloads (omit --enable-torrents)."
            )

        # Build v2 path: ~/.exo/models/v2/{model_id}/{revision}
        save_path = EXO_MODELS_DIR / "v2" / model_id / revision
        save_path.mkdir(parents=True, exist_ok=True)

        # Get allow patterns for file filtering (same as HTTP downloads)
        allow_patterns = await resolve_allow_patterns(shard)
        logger.debug(f"Torrent download allow patterns: {allow_patterns}")

        # Track aggregate progress across all variants
        total_downloaded = 0
        total_size = 0

        # Download each variant
        for variant_name, torrent_data in variants:
            # Get the file list from the torrent
            file_list = self._get_torrent_file_list(torrent_data)
            if not file_list:
                logger.warning(f"No files in variant '{variant_name}', skipping")
                continue

            # Filter files using the allow patterns
            file_indices = self._filter_files(file_list, allow_patterns)
            if not file_indices:
                logger.debug(f"No matching files in variant '{variant_name}' after filtering")
                continue

            logger.info(
                f"Adding torrent variant '{variant_name}' for {model_id}@{revision} "
                f"({len(file_indices)} of {len(file_list)} files)"
            )

            try:
                info_hash = self.session.add_torrent(
                    torrent_data, str(save_path), file_indices
                )
            except Exception as e:
                # The torrent may already be added
                logger.debug(f"Could not add variant '{variant_name}': {e}")
                continue

            # Wait for the download with progress reporting
            logger.info(f"Starting download for variant '{variant_name}' ({info_hash})")
            while True:
                progress_dict = self.session.get_progress(info_hash)

                # Update aggregate progress
                variant_downloaded = progress_dict["downloaded_bytes"]
                variant_total = progress_dict["total_bytes"]
                download_speed = progress_dict["download_speed"]  # bytes/sec from rqbit

                # Calculate ETA based on remaining bytes and current speed
                current_downloaded = total_downloaded + variant_downloaded
                current_total = total_size + variant_total
                remaining_bytes = current_total - current_downloaded
                eta = (
                    timedelta(seconds=remaining_bytes / download_speed)
                    if download_speed > 0
                    else timedelta(0)
                )

                # Report combined progress
                progress = self._make_progress(
                    shard=shard,
                    model_id=model_id,
                    revision=revision,
                    downloaded=current_downloaded,
                    total=current_total,
                    status="in_progress",
                    speed=download_speed,
                    eta=eta,
                )
                self._report_progress(shard, progress)

                if progress_dict["is_finished"]:
                    logger.info(f"Variant '{variant_name}' download completed")
                    total_downloaded += variant_total
                    total_size += variant_total
                    break

                await asyncio.sleep(0.5)

            # Enable persistent seeding for this variant
            try:
                self.session.enable_seeding(info_hash)
            except Exception as e:
                logger.warning(f"Could not enable seeding for '{variant_name}': {e}")

        return save_path

    async def get_shard_download_status(
        self,
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        """Get download status for all shards."""
        if self.session is None:
            return

        # List all torrents in the session
        info_hashes = self.session.list_torrents()

        # We don't have a straightforward way to map an info_hash back to a shard/path.
        # This would require tracking in the session or metadata.
        # For now, this method doesn't yield anything meaningful for torrents.
        for info_hash in info_hashes:
            try:
                _progress_dict = self.session.get_progress(info_hash)
                # Can't create proper progress without shard context;
                # skip yielding until we have proper tracking.
                logger.debug(f"Torrent {info_hash} status check - not implemented")
            except Exception as e:
                logger.error(f"Error getting status for {info_hash}: {e}")
        # Yield nothing - torrents don't support this method well yet
        return
        yield  # Make this an async generator

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
    ) -> RepoDownloadProgress:
        """Get download status for a specific shard."""
        model_id = str(shard.model_meta.model_id)

        if self.session is None:
            return self._make_progress(
                shard=shard,
                model_id=model_id,
                revision="unknown",
                downloaded=0,
                total=0,
                status="not_started",
            )

        # We would need to track an info_hash -> shard mapping.
        # For now, return empty progress.
        return self._make_progress(
            shard=shard,
            model_id=model_id,
            revision="unknown",
            downloaded=0,
            total=0,
            status="not_started",
        )

    async def _resolve_revision(self, model_id: str, branch: str = "main") -> str:
        """Resolve a branch name to a commit hash using the HuggingFace API.

        Args:
            model_id: Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
            branch: Branch name (default: "main")

        Returns:
            Git commit hash (SHA)
        """
        try:
            from huggingface_hub import model_info

            info = model_info(model_id, revision=branch)
            return str(info.sha)
        except Exception as e:
            logger.error(f"Failed to resolve revision for {model_id}@{branch}: {e}")
            # Fall back to the branch name as the revision if the API call fails
            logger.warning(f"Using branch name '{branch}' as revision")
            return branch

    def _load_embedded_torrents(
        self, model_id: str, revision: str
    ) -> list[tuple[str, bytes]]:
        """Load all embedded torrent variants from the Rust binary.

        Args:
            model_id: Model identifier
            revision: Git commit hash

        Returns:
            List of (variant_name, torrent_bytes) tuples, e.g., [("small", ...), ("large", ...)]
        """
        try:
            from exo_pyo3_bindings import get_embedded_torrents  # type: ignore

            result: list[tuple[str, bytes]] = get_embedded_torrents(model_id, revision)
            return result
        except ImportError:
            logger.warning("exo_pyo3_bindings not available, cannot load torrents")
            return []
        except Exception as e:
            logger.error(f"Error loading torrents for {model_id}@{revision}: {e}")
            return []

    def _get_torrent_file_list(
        self, torrent_data: bytes
    ) -> list[tuple[int, str, int]]:
        """Get the file list from torrent data.

        Args:
            torrent_data: Raw torrent file bytes

        Returns:
            List of (index, path, size_bytes) tuples
        """
        try:
            from exo_pyo3_bindings import get_torrent_file_list  # type: ignore

            return get_torrent_file_list(torrent_data)
        except ImportError:
            logger.warning("exo_pyo3_bindings not available")
            return []
        except Exception as e:
            logger.error(f"Error parsing torrent file list: {e}")
            return []

    def _filter_files(
        self, file_list: list[tuple[int, str, int]], allow_patterns: list[str]
    ) -> list[int]:
        """Filter the file list using glob patterns.

        Args:
            file_list: List of (index, path, size) tuples
            allow_patterns: Glob patterns to match (e.g., ["*.json", "*.safetensors"])

        Returns:
            List of file indices that match at least one pattern
        """
        matching_indices: list[int] = []
        for index, path, _size in file_list:
            for pattern in allow_patterns:
                if fnmatch(path, pattern):
                    matching_indices.append(index)
                    break
        return matching_indices

    def _make_progress(
        self,
        shard: ShardMetadata,
        model_id: str,
        revision: str,
        downloaded: int,
        total: int,
        status: str,
        speed: float = 0.0,
        eta: timedelta | None = None,
    ) -> RepoDownloadProgress:
        """Create a RepoDownloadProgress object with the proper fields.

        Args:
            shard: Shard metadata
            model_id: Model identifier
            revision: Git revision
            downloaded: Downloaded bytes
            total: Total bytes
            status: Download status ("not_started", "in_progress", "complete")
            speed: Download speed in bytes/second
            eta: Estimated time remaining

        Returns:
            RepoDownloadProgress object
        """
        return RepoDownloadProgress(
            repo_id=model_id,
            repo_revision=revision,
            shard=shard,
            completed_files=1 if status == "complete" else 0,
            total_files=1,
            downloaded_bytes=Memory.from_bytes(downloaded),
            downloaded_bytes_this_session=Memory.from_bytes(downloaded),
            total_bytes=Memory.from_bytes(total),
            overall_speed=speed,
            overall_eta=eta if eta is not None else timedelta(0),
            status=status,  # type: ignore
        )

    def _report_progress(
        self, shard: ShardMetadata, progress: RepoDownloadProgress
    ) -> None:
        """Report progress to all registered callbacks."""
        for callback in self._progress_callbacks:
            try:
                callback(shard, progress)
            except Exception as e:
                logger.error(f"Error in progress callback: {e}")
@@ -217,7 +217,9 @@ class Worker:
            )
            if initial_progress.status == "complete":
                progress = DownloadCompleted(
                    shard_metadata=shard, node_id=self.node_id
                    shard_metadata=shard,
                    node_id=self.node_id,
                    total_bytes=initial_progress.total_bytes,
                )
                self.download_status[shard.model_meta.model_id] = progress
                await self.event_sender.send(
@@ -364,7 +366,11 @@ class Worker:
            nonlocal self
            nonlocal last_progress_time
            if progress.status == "complete":
                status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
                status = DownloadCompleted(
                    shard_metadata=shard,
                    node_id=self.node_id,
                    total_bytes=progress.total_bytes,
                )
                self.download_status[shard.model_meta.model_id] = status
                # Footgun!
                self.event_sender.send_nowait(
@@ -457,7 +463,9 @@ class Worker:
        ) in self.shard_downloader.get_shard_download_status():
            if progress.status == "complete":
                status = DownloadCompleted(
                    node_id=self.node_id, shard_metadata=progress.shard
                    node_id=self.node_id,
                    shard_metadata=progress.shard,
                    total_bytes=progress.total_bytes,
                )
            elif progress.status in ["in_progress", "not_started"]:
                if progress.downloaded_bytes_this_session.in_bytes == 0:
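The progress callback above runs synchronously inside the downloader, so the worker cannot await a send there. A minimal sketch of the same pattern with anyio streams (the buffer size and setup here are illustrative, not exo's actual configuration):

    import anyio

    async def main():
        send, recv = anyio.create_memory_object_stream[int](max_buffer_size=8)

        def on_progress(n: int) -> None:
            # Sync context: send_nowait is the only option, and it raises
            # anyio.WouldBlock when the buffer is full - a plausible reading
            # of the "Footgun!" comment above.
            send.send_nowait(n)

        on_progress(42)
        print(await recv.receive())

    anyio.run(main)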
@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

    # Local node has already marked its shard as downloaded (not actually used by _load_model)
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
        )
    }

    # Global view has completed downloads for both nodes
    global_download_status = {
        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
        NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
        ],
    }

    result = plan_mod.plan(
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():

    # Local status claims the shard is downloaded already
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
        )
    }

    # Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

    # Only NODE_A's shard is recorded as downloaded globally
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
        )
    }
    global_download_status = {
        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [],  # NODE_B has no downloads completed yet
    }

@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    assert result is None

    global_download_status = {
        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [
            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
        ],  # NODE_B's download is now recorded as completed
    }
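In these tests, total_bytes=Memory() is just a placeholder so the now-required field validates. A sketch of the presumed semantics (assuming a bare Memory() defaults to zero bytes, with from_bytes and in_bytes as used elsewhere in this diff):

    from exo.shared.types.memory import Memory

    assert Memory().in_bytes == 0  # assumed default: an empty measurement
    assert Memory.from_bytes(16 * 2**20).in_bytes == 16 * 2**20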