Compare commits


6 Commits

Author SHA1 Message Date
Jake Hillion
563e94ab6c feat(downloads): add ETA and speed tracking for torrent downloads
Use download_speed from rqbit (already exposed via PyO3 bindings) to
calculate and report real-time ETA for torrent downloads. Previously
these values were hardcoded to zero.
2026-01-12 15:12:56 +00:00
Jake Hillion
41c27832d9 feat(downloads): integrate torrent downloads with file filtering
- Add get_torrent_file_list() to parse torrent metadata and return file list
- Expose get_torrent_file_list via PyO3 bindings
- Update TorrentShardDownloader to download all torrent variants
- Apply same file filtering as HTTP downloads using resolve_allow_patterns()
- Fix RepoDownloadProgress constructor usage

When --enable-torrents is set, all embedded torrent variants (e.g., "small",
"large") are now downloaded with proper file filtering based on shard layers.
2026-01-12 15:12:56 +00:00
Jake Hillion
9c04350e55 feat(downloads): embed multiple torrent variants per model
- Add 52 torrent files (small and large variants) for all supported models
- Update get_embedded_torrents() to return Vec<(variant, data)> for all variants
- Remove old single-torrent API in favor of multi-variant API
- Update Python bindings to return list of (variant, bytes) tuples
2026-01-12 15:12:56 +00:00
Jake Hillion
1ca732da23 feat(downloads): switch to rqbit fork and implement fake tracker
- Update librqbit dependency to JakeHillion/rqbit fork (v9.0.0-beta.1)
- Update session.rs for new librqbit API:
  - SessionPersistenceConfig instead of bool
  - AddTorrentResponse enum handling
  - Updated TorrentStats field access
- Implement bencode module with AnnounceParams and BencodeValue encoding
- Implement fake tracker that generates bencoded announce responses
  from Exo topology data (PeerInfo, TopologyData, handle_announce)
- Fix PyO3 bindings for new librqbit types (remove allow_threads)
2026-01-12 15:12:56 +00:00
Jake Hillion
ba708c2ccd torrents! 2026-01-12 15:12:56 +00:00
Jake Hillion
c9d95859e8 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but gave no indication of how much disk space each model, or all of a
node's models combined, was using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 12:08:46 +00:00
76 changed files with 24611 additions and 49 deletions

Cargo.lock (generated, 1605 changed lines): file diff suppressed because it is too large

Cargo.toml (workspace root)

@@ -2,6 +2,7 @@
resolver = "3"
members = [
"rust/networking",
"rust/downloads",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
@@ -25,6 +26,7 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
downloads = { path = "rust/downloads" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }

(downloads dashboard Svelte component)

@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
-const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
+// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
+const totalBytes = getBytes(
+  (downloadPayload as Record<string, unknown>).total_bytes
+  ?? (downloadPayload as Record<string, unknown>).totalBytes
+  ?? (rawProgress as Record<string, unknown>).total_bytes
+  ?? (rawProgress as Record<string, unknown>).totalBytes
+);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
-<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
+<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

rust/downloads/Cargo.toml (new file, 40 lines)

@@ -0,0 +1,40 @@
[package]
name = "downloads"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "downloads"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# macro dependencies
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-util = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
itertools = { workspace = true }
# tracing/logging
log = { workspace = true }
# BitTorrent library
librqbit = { git = "https://github.com/JakeHillion/rqbit", rev = "c4e2ecf81d03bd8acd96a0803d06a70b34d5da19" }
# Embed torrent files
include_dir = "0.7"
# Serialization
serde = { version = "1.0", features = ["derive"] }

rust/downloads/src/bencode.rs (new file, 162 lines)

@@ -0,0 +1,162 @@
//! Bencode encoding for BitTorrent tracker responses
//!
//! Implements the subset of bencoding needed for tracker announce responses.
use std::collections::BTreeMap;
/// Parameters from a tracker announce request
#[derive(Debug, Clone)]
pub struct AnnounceParams {
/// 20-byte info hash of the torrent
pub info_hash: [u8; 20],
/// 20-byte peer ID of the client
pub peer_id: [u8; 20],
/// Port the client is listening on
pub port: u16,
/// Total bytes uploaded
pub uploaded: u64,
/// Total bytes downloaded
pub downloaded: u64,
/// Bytes remaining to download
pub left: u64,
/// Whether to return compact peer list (6 bytes per peer)
pub compact: bool,
/// Optional event (started, stopped, completed)
pub event: Option<AnnounceEvent>,
}
/// Announce event types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnnounceEvent {
Started,
Stopped,
Completed,
}
/// A bencoded value
#[derive(Debug, Clone)]
pub enum BencodeValue {
Integer(i64),
Bytes(Vec<u8>),
List(Vec<BencodeValue>),
Dict(BTreeMap<Vec<u8>, BencodeValue>),
}
impl BencodeValue {
/// Create a string value from a &str
#[inline]
pub fn string(s: &str) -> Self {
Self::Bytes(s.as_bytes().to_vec())
}
/// Create an integer value
#[inline]
pub fn integer(i: i64) -> Self {
Self::Integer(i)
}
/// Create an empty list
#[inline]
pub fn list() -> Self {
Self::List(Vec::new())
}
/// Create an empty dict
#[inline]
pub fn dict() -> Self {
Self::Dict(BTreeMap::new())
}
/// Add an item to a list (builder pattern)
#[inline]
pub fn push(mut self, value: BencodeValue) -> Self {
if let Self::List(ref mut list) = self {
list.push(value);
}
self
}
/// Insert a key-value pair into a dict (builder pattern)
#[inline]
pub fn insert(mut self, key: &str, value: BencodeValue) -> Self {
if let Self::Dict(ref mut dict) = self {
dict.insert(key.as_bytes().to_vec(), value);
}
self
}
/// Encode to bencoded bytes
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::new();
self.encode_into(&mut buf);
buf
}
/// Encode into an existing buffer
pub fn encode_into(&self, buf: &mut Vec<u8>) {
match self {
Self::Integer(i) => {
buf.push(b'i');
buf.extend_from_slice(i.to_string().as_bytes());
buf.push(b'e');
}
Self::Bytes(bytes) => {
buf.extend_from_slice(bytes.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(bytes);
}
Self::List(list) => {
buf.push(b'l');
for item in list {
item.encode_into(buf);
}
buf.push(b'e');
}
Self::Dict(dict) => {
buf.push(b'd');
// BTreeMap keeps keys sorted
for (key, value) in dict {
buf.extend_from_slice(key.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(key);
value.encode_into(buf);
}
buf.push(b'e');
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encode_integer() {
assert_eq!(BencodeValue::integer(42).encode(), b"i42e");
assert_eq!(BencodeValue::integer(-1).encode(), b"i-1e");
assert_eq!(BencodeValue::integer(0).encode(), b"i0e");
}
#[test]
fn test_encode_string() {
assert_eq!(BencodeValue::string("spam").encode(), b"4:spam");
assert_eq!(BencodeValue::string("").encode(), b"0:");
}
#[test]
fn test_encode_list() {
let list = BencodeValue::list()
.push(BencodeValue::string("spam"))
.push(BencodeValue::integer(42));
assert_eq!(list.encode(), b"l4:spami42ee");
}
#[test]
fn test_encode_dict() {
let dict = BencodeValue::dict()
.insert("bar", BencodeValue::string("spam"))
.insert("foo", BencodeValue::integer(42));
assert_eq!(dict.encode(), b"d3:bar4:spam3:fooi42ee");
}
}
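For reference, the wire format this module emits can be sketched in a few lines of Python (illustrative only; the encoder mirrors the Rust rules above, and the example values are assumed):

def bencode(value) -> bytes:
    """Minimal bencoder mirroring BencodeValue::encode_into above."""
    if isinstance(value, int):
        return b"i%de" % value
    if isinstance(value, str):
        value = value.encode()
    if isinstance(value, bytes):
        return b"%d:%s" % (len(value), value)
    if isinstance(value, list):
        return b"l" + b"".join(bencode(v) for v in value) + b"e"
    if isinstance(value, dict):  # keys sorted, as BTreeMap guarantees in Rust
        items = sorted((k.encode() if isinstance(k, str) else k, v) for k, v in value.items())
        return b"d" + b"".join(bencode(k) + bencode(v) for k, v in items) + b"e"
    raise TypeError(f"cannot bencode {type(value)!r}")

# A compact announce response for one peer at 192.168.1.1:6881:
peer_blob = bytes([192, 168, 1, 1]) + (6881).to_bytes(2, "big")
assert bencode({"interval": 1800, "peers": peer_blob}) == (
    b"d8:intervali1800e5:peers6:" + peer_blob + b"e"
)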

rust/downloads/src/embedded.rs (new file, 108 lines)

@@ -0,0 +1,108 @@
//! Embedded torrent file access
//!
//! Provides access to .torrent files embedded in the binary at compile time.
//! Each model/revision can have multiple torrent variants (e.g., "small", "large").
use include_dir::{Dir, include_dir};
/// Embedded torrent files directory
static TORRENTS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/torrents");
/// Get all embedded torrent variants for a model_id and revision
///
/// # Arguments
/// * `model_id` - Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
/// * `revision` - Git commit hash
///
/// # Returns
/// Vec of (variant_name, torrent_data) tuples, e.g., [("small", data), ("large", data)]
/// Returns empty Vec if no torrents found for this model/revision.
#[inline]
pub fn get_embedded_torrents(model_id: &str, revision: &str) -> Vec<(String, Vec<u8>)> {
let dir_path = format!("{model_id}");
let Some(model_dir) = TORRENTS.get_dir(&dir_path) else {
return Vec::new();
};
let mut results = Vec::new();
let prefix = format!("{revision}.");
let suffix = ".torrent";
for file in model_dir.files() {
let Some(name) = file.path().file_name().and_then(|n| n.to_str()) else {
continue;
};
// Match files like "{revision}.small.torrent" or "{revision}.large.torrent"
if name.starts_with(&prefix) && name.ends_with(suffix) {
// Extract variant: "{revision}.{variant}.torrent" -> "{variant}"
let middle = &name[prefix.len()..name.len() - suffix.len()];
// Skip plain "{revision}.torrent" files (wrong format)
if middle.is_empty() {
continue;
}
results.push((middle.to_string(), file.contents().to_vec()));
}
}
// Sort by variant name for consistent ordering
results.sort_by(|a, b| a.0.cmp(&b.0));
results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_embedded_torrents() {
// Test with the Qwen3 torrent we have
let result = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
assert!(!result.is_empty(), "Expected to find embedded torrents");
// Should have both small and large variants
let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
assert!(
variants.contains(&"small"),
"Expected 'small' variant, got: {variants:?}"
);
assert!(
variants.contains(&"large"),
"Expected 'large' variant, got: {variants:?}"
);
// Verify data is not empty
for (variant, data) in &result {
assert!(!data.is_empty(), "Torrent data for '{variant}' should not be empty");
}
}
#[test]
fn test_missing_torrent() {
let result = get_embedded_torrents("nonexistent/model", "abc123");
assert!(result.is_empty(), "Expected empty Vec for missing torrent");
}
#[test]
fn test_variant_ordering() {
let result = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
if result.len() >= 2 {
// Verify alphabetical ordering
let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
let mut sorted = variants.clone();
sorted.sort();
assert_eq!(variants, sorted, "Variants should be sorted alphabetically");
}
}
}
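On the Python side the same lookup is exposed through the PyO3 bindings (rust/exo_pyo3_bindings/src/downloads.rs below); a usage sketch, assuming the bindings wheel is built and importable:

from exo_pyo3_bindings import get_embedded_torrents

variants = get_embedded_torrents(
    "mlx-community/Qwen3-30B-A3B-4bit",
    "d388dead1515f5e085ef7a0431dd8fadf0886c57",
)
# Variants come back sorted by name, e.g. [("large", ...), ("small", ...)]
for name, data in variants:
    print(f"{name}: {len(data)} bytes of torrent metadata")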

rust/downloads/src/lib.rs (new file, 22 lines)

@@ -0,0 +1,22 @@
//! BitTorrent-based download system for model shards using rqbit
//!
//! This crate provides:
//! - Torrent session management via rqbit
//! - Embedded torrent file access
//! - Private tracker announce handling
//! - Selective file download based on shard layer ranges
#![allow(clippy::missing_inline_in_public_items)]
pub mod bencode;
pub mod embedded;
pub mod progress;
pub mod session;
pub mod torrent_files;
pub mod tracker;
pub use bencode::AnnounceParams;
pub use embedded::get_embedded_torrents;
pub use session::{DownloadProgress, TorrentSession};
pub use torrent_files::{get_torrent_file_list, TorrentFileInfo};
pub use tracker::{handle_announce, PeerInfo, TopologyData};

rust/downloads/src/progress.rs (new file, 77 lines)

@@ -0,0 +1,77 @@
//! Download progress tracking
//!
//! Types for tracking and reporting download progress to Python
use std::collections::HashMap;
/// Progress update for a torrent download
#[derive(Debug, Clone)]
pub struct DownloadProgress {
/// Total bytes to download
pub total_bytes: u64,
/// Bytes downloaded so far
pub downloaded_bytes: u64,
/// Number of pieces completed
pub pieces_completed: usize,
/// Total number of pieces
pub total_pieces: usize,
/// Number of peers connected
pub peers_connected: usize,
/// Download speed in bytes/second
pub speed_bytes_per_sec: f64,
/// Estimated time remaining in seconds
pub eta_seconds: Option<f64>,
/// Per-file progress
pub files: HashMap<String, FileProgress>,
}
#[derive(Debug, Clone)]
pub struct FileProgress {
/// Total file size
pub total_bytes: u64,
/// Bytes downloaded for this file
pub downloaded_bytes: u64,
/// Whether the file is complete
pub complete: bool,
}
impl DownloadProgress {
#[inline]
pub fn new(total_bytes: u64, total_pieces: usize) -> Self {
Self {
total_bytes,
downloaded_bytes: 0,
pieces_completed: 0,
total_pieces,
peers_connected: 0,
speed_bytes_per_sec: 0.0,
eta_seconds: None,
files: HashMap::new(),
}
}
#[inline]
pub fn progress_fraction(&self) -> f64 {
if self.total_bytes == 0 {
0.0
} else {
#[allow(clippy::cast_precision_loss)]
let fraction = self.downloaded_bytes as f64 / self.total_bytes as f64;
fraction
}
}
#[inline]
pub fn is_complete(&self) -> bool {
self.pieces_completed >= self.total_pieces
}
}
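The ETA added in commit 563e94ab6c falls straight out of these fields: remaining bytes over current speed, with the same zero guards progress_fraction uses. A sketch of the arithmetic with illustrative numbers:

total_bytes, downloaded_bytes = 251_100_000_000, 125_550_000_000
speed_bytes_per_sec = 120e6  # 120 MB/s, assumed for the example
fraction = 0.0 if total_bytes == 0 else downloaded_bytes / total_bytes
remaining = total_bytes - downloaded_bytes
eta_seconds = None if speed_bytes_per_sec <= 0 else remaining / speed_bytes_per_sec
print(f"{fraction:.0%} done, ETA {eta_seconds:.0f}s")  # 50% done, ETA 1046s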

rust/downloads/src/session.rs (new file, 166 lines)

@@ -0,0 +1,166 @@
//! Torrent session management using rqbit
//!
//! Provides a wrapper around rqbit's Session for managing torrent downloads
//! with persistent seeding and selective file downloads.
use anyhow::{Context, Result};
use librqbit::{AddTorrent, AddTorrentOptions, AddTorrentResponse, Api, ManagedTorrent, Session, SessionOptions, SessionPersistenceConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Download progress information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadProgress {
pub downloaded_bytes: u64,
pub total_bytes: u64,
pub download_speed: f64,
pub upload_speed: f64,
pub peers_connected: usize,
pub is_finished: bool,
}
/// Torrent session handle for managing multiple torrents
pub struct TorrentSession {
session: Arc<Session>,
api: Arc<Api>,
session_dir: PathBuf,
torrents: Arc<RwLock<HashMap<String, Arc<ManagedTorrent>>>>,
}
impl TorrentSession {
/// Create a new torrent session
///
/// # Arguments
/// * `session_dir` - Directory to store session state and downloaded files
pub async fn new(session_dir: PathBuf) -> Result<Self> {
std::fs::create_dir_all(&session_dir).context("Failed to create session directory")?;
let opts = SessionOptions {
disable_dht: false,
disable_dht_persistence: false,
dht_config: None,
persistence: Some(SessionPersistenceConfig::Json { folder: None }),
fastresume: true,
..Default::default()
};
let session = Session::new_with_opts(session_dir.clone(), opts)
.await
.context("Failed to create rqbit session")?;
let api = Api::new(Arc::clone(&session), None);
Ok(Self {
session,
api: Arc::new(api),
session_dir,
torrents: Arc::new(RwLock::new(HashMap::new())),
})
}
/// Add a torrent from raw bytes
///
/// # Arguments
/// * `torrent_data` - Raw .torrent file contents
/// * `save_path` - Where to save the downloaded files
/// * `file_indices` - Optional list of file indices to download (None = all files)
///
/// # Returns
/// Info hash as hex string
pub async fn add_torrent(
&self,
torrent_data: Vec<u8>,
save_path: PathBuf,
file_indices: Option<Vec<usize>>,
) -> Result<String> {
let opts = AddTorrentOptions {
overwrite: false,
only_files_regex: None,
only_files: file_indices,
output_folder: Some(save_path.to_string_lossy().to_string()),
..Default::default()
};
let add_torrent = AddTorrent::from_bytes(torrent_data);
let response = self
.session
.add_torrent(add_torrent, Some(opts))
.await
.context("Failed to add torrent")?;
let handle = match response {
AddTorrentResponse::Added(_, handle) => handle,
AddTorrentResponse::AlreadyManaged(_, handle) => handle,
AddTorrentResponse::ListOnly(_) => anyhow::bail!("Torrent was list-only, not added"),
};
let info_hash = handle.info_hash().as_string();
self.torrents
.write()
.await
.insert(info_hash.clone(), handle);
Ok(info_hash)
}
/// Get download progress for a torrent
pub async fn get_progress(&self, info_hash: &str) -> Result<DownloadProgress> {
let torrents = self.torrents.read().await;
let handle = torrents.get(info_hash).context("Torrent not found")?;
let stats = handle.stats();
Ok(DownloadProgress {
downloaded_bytes: stats.progress_bytes,
total_bytes: stats.total_bytes,
download_speed: stats.live.as_ref().map_or(0.0, |l| l.download_speed.mbps * 1024.0 * 1024.0),
upload_speed: stats.live.as_ref().map_or(0.0, |l| l.upload_speed.mbps * 1024.0 * 1024.0),
peers_connected: stats.live.as_ref().map_or(0, |l| l.snapshot.peer_stats.live as usize),
is_finished: stats.finished,
})
}
/// Wait until torrent download is completed
pub async fn wait_until_completed(&self, info_hash: &str) -> Result<()> {
let torrents = self.torrents.read().await;
let handle = torrents.get(info_hash).context("Torrent not found")?;
handle
.wait_until_completed()
.await
.context("Failed to wait for completion")?;
Ok(())
}
/// Enable seeding for a completed torrent
///
/// Note: rqbit seeds by default after completion, this is a no-op
/// but kept for API compatibility
pub async fn enable_seeding(&self, _info_hash: &str) -> Result<()> {
// rqbit automatically seeds after download completion
// This is kept for API compatibility
Ok(())
}
/// Remove a torrent from the session
pub async fn remove_torrent(&self, info_hash: &str) -> Result<()> {
let mut torrents = self.torrents.write().await;
if let Some(handle) = torrents.remove(info_hash) {
drop(handle);
}
Ok(())
}
/// Get list of all torrent info hashes in the session
pub async fn list_torrents(&self) -> Vec<String> {
self.torrents.read().await.keys().cloned().collect()
}
}

rust/downloads/src/torrent_files.rs (new file, 100 lines)

@@ -0,0 +1,100 @@
//! Torrent file list parsing
//!
//! Provides functionality to extract file information from torrent metadata
//! without adding the torrent to a session.
use anyhow::{Context, Result};
use librqbit::torrent_from_bytes;
use serde::{Deserialize, Serialize};
/// Information about a file in a torrent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorrentFileInfo {
/// File index (0-based)
pub index: usize,
/// File path relative to torrent root
pub path: String,
/// File size in bytes
pub size: u64,
}
/// Get the list of files in a torrent from its raw bytes
///
/// # Arguments
/// * `torrent_data` - Raw .torrent file contents
///
/// # Returns
/// List of file information (index, path, size)
pub fn get_torrent_file_list(torrent_data: &[u8]) -> Result<Vec<TorrentFileInfo>> {
let torrent_meta = torrent_from_bytes(torrent_data).context("Failed to parse torrent")?;
// Access the data inside WithRawBytes wrapper
let info = &torrent_meta.info.data;
let mut files = Vec::new();
// Handle both single-file and multi-file torrents
if let Some(ref file_list) = info.files {
// Multi-file torrent
for (index, file) in file_list.iter().enumerate() {
let path = file
.path
.iter()
.map(|buf| String::from_utf8_lossy(buf.0).to_string())
.collect::<Vec<_>>()
.join("/");
files.push(TorrentFileInfo {
index,
path,
size: file.length,
});
}
} else {
// Single-file torrent
let name = match &info.name {
Some(n) => String::from_utf8_lossy(n.0).to_string(),
None => String::new(),
};
files.push(TorrentFileInfo {
index: 0,
path: name,
size: info.length.unwrap_or(0),
});
}
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::get_embedded_torrents;
#[test]
fn test_get_torrent_file_list() {
// Use an embedded torrent for testing
let torrents = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
assert!(!torrents.is_empty(), "Expected to find embedded torrents");
for (variant, data) in torrents {
let files = get_torrent_file_list(&data).expect("Failed to parse torrent");
assert!(!files.is_empty(), "Expected files in {variant} variant");
// Verify file info makes sense
for file in &files {
assert!(!file.path.is_empty(), "File path should not be empty");
assert!(file.size > 0, "File size should be positive");
}
println!("Variant '{variant}' has {} files", files.len());
for file in files.iter().take(5) {
println!(" [{}] {} ({} bytes)", file.index, file.path, file.size);
}
}
}
}
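Downstream, TorrentShardDownloader (exo/worker/download/torrent_downloader.py below) turns this file list into the only_files indices passed to add_torrent. A sketch of that filtering step, assuming the allow patterns are fnmatch-style globs like the HTTP path uses (the file tuples and patterns here are hypothetical):

from fnmatch import fnmatch

files = [  # (index, path, size) as returned by the get_torrent_file_list binding
    (0, "config.json", 848),
    (1, "model-00001-of-00002.safetensors", 5_000_000_000),
    (2, "model-00002-of-00002.safetensors", 5_000_000_000),
]
allow_patterns = ["*.json", "model-00001-of-*.safetensors"]
indices = [i for i, path, _ in files if any(fnmatch(path, p) for p in allow_patterns)]
assert indices == [0, 1]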

rust/downloads/src/tracker.rs (new file, 185 lines)

@@ -0,0 +1,185 @@
//! Fake tracker implementation for Exo topology-based peer discovery
//!
//! Instead of contacting real BitTorrent trackers, this module generates
//! tracker announce responses using Exo's cluster topology data.
use std::net::Ipv4Addr;
use anyhow::Result;
use crate::bencode::{AnnounceParams, BencodeValue};
/// Information about a peer in the Exo topology
#[derive(Debug, Clone)]
pub struct PeerInfo {
/// Unique node identifier in the Exo cluster
pub node_id: String,
/// IPv4 address of the peer
pub ip: Ipv4Addr,
/// BitTorrent listening port
pub port: u16,
/// Whether this peer has the complete torrent
pub has_complete: bool,
/// Priority for peer selection (higher = prefer)
pub priority: i32,
}
/// Topology data containing available peers
#[derive(Debug, Clone)]
pub struct TopologyData {
/// List of peers in the topology
pub peers: Vec<PeerInfo>,
}
/// Default announce interval in seconds
const DEFAULT_INTERVAL: i64 = 1800;
/// Handle a tracker announce request using Exo topology data
///
/// Returns a bencoded tracker response containing peers from the topology.
///
/// # Arguments
/// * `params` - Announce request parameters
/// * `topology` - Current Exo cluster topology
///
/// # Returns
/// Bencoded announce response as bytes
pub fn handle_announce(params: &AnnounceParams, topology: &TopologyData) -> Result<Vec<u8>> {
// Sort peers by priority (descending) for better peer selection
let mut peers: Vec<_> = topology.peers.iter().collect();
peers.sort_by(|a, b| b.priority.cmp(&a.priority));
let response = if params.compact {
// Compact format: 6 bytes per peer (4 IP + 2 port)
let mut peer_data = Vec::with_capacity(peers.len() * 6);
for peer in &peers {
peer_data.extend_from_slice(&peer.ip.octets());
peer_data.extend_from_slice(&peer.port.to_be_bytes());
}
BencodeValue::dict()
.insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
.insert("peers", BencodeValue::Bytes(peer_data))
} else {
// Non-compact format: list of dicts
let mut peer_list = BencodeValue::list();
for peer in &peers {
let peer_dict = BencodeValue::dict()
.insert("ip", BencodeValue::string(&peer.ip.to_string()))
.insert("port", BencodeValue::integer(i64::from(peer.port)))
.insert("peer id", BencodeValue::Bytes(vec![0u8; 20])); // Placeholder peer ID
peer_list = peer_list.push(peer_dict);
}
BencodeValue::dict()
.insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
.insert("peers", peer_list)
};
Ok(response.encode())
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_params(compact: bool) -> AnnounceParams {
AnnounceParams {
info_hash: [0u8; 20],
peer_id: [0u8; 20],
port: 6881,
uploaded: 0,
downloaded: 0,
left: 1000,
compact,
event: None,
}
}
fn make_test_topology() -> TopologyData {
TopologyData {
peers: vec![
PeerInfo {
node_id: "node1".to_string(),
ip: Ipv4Addr::new(192, 168, 1, 1),
port: 6881,
has_complete: true,
priority: 10,
},
PeerInfo {
node_id: "node2".to_string(),
ip: Ipv4Addr::new(192, 168, 1, 2),
port: 6882,
has_complete: false,
priority: 5,
},
],
}
}
#[test]
fn test_compact_response() {
let params = make_test_params(true);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// Should contain "interval" and "peers" keys
assert!(response.starts_with(b"d"));
assert!(response.ends_with(b"e"));
// Verify we have 12 bytes of peer data (2 peers * 6 bytes)
// The compact peers field should be "12:<12 bytes>"
let response_str = String::from_utf8_lossy(&response);
assert!(response_str.contains("8:interval"));
assert!(response_str.contains("5:peers"));
}
#[test]
fn test_non_compact_response() {
let params = make_test_params(false);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// Should contain peers as a list
let response_str = String::from_utf8_lossy(&response);
assert!(response_str.contains("8:interval"));
assert!(response_str.contains("5:peers"));
assert!(response_str.contains("2:ip"));
assert!(response_str.contains("4:port"));
}
#[test]
fn test_peer_priority_ordering() {
let params = make_test_params(true);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// In compact format, first peer should be node1 (priority 10)
// which is 192.168.1.1:6881
// Look for the peer data after "5:peers12:"
let peers_marker = b"5:peers12:";
let pos = response
.windows(peers_marker.len())
.position(|w| w == peers_marker)
.unwrap();
let peer_data = &response[pos + peers_marker.len()..pos + peers_marker.len() + 6];
// First peer should be 192.168.1.1 (node1 with higher priority)
assert_eq!(&peer_data[0..4], &[192, 168, 1, 1]);
}
#[test]
fn test_empty_topology() {
let params = make_test_params(true);
let topology = TopologyData { peers: vec![] };
let response = handle_announce(&params, &topology).unwrap();
// Should still be valid bencoded response with empty peers
assert!(response.starts_with(b"d"));
assert!(response.ends_with(b"e"));
}
}
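From the client's side, the compact format this tracker emits is six bytes per peer: four for the IPv4 address, two for the big-endian port. A decoding sketch using the test topology above:

import ipaddress

# Compact 'peers' blob for 192.168.1.1:6881 and 192.168.1.2:6882.
blob = (bytes([192, 168, 1, 1]) + (6881).to_bytes(2, "big")
        + bytes([192, 168, 1, 2]) + (6882).to_bytes(2, "big"))
peers = [
    (str(ipaddress.IPv4Address(blob[i:i + 4])), int.from_bytes(blob[i + 4:i + 6], "big"))
    for i in range(0, len(blob), 6)
]
assert peers == [("192.168.1.1", 6881), ("192.168.1.2", 6882)]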

Four embedded .torrent files: diffs suppressed because one or more lines are too long.
rust/downloads/torrents/mlx-community/Kimi-K2-Instruct-4bit/91fb4f9fd1de100104925196d62b8ee06fd2ad60.small.torrent (new file): bencoded metadata covering the repo's small files (config.json, model.safetensors.index.json, tokenizer files, etc.), web seed https://huggingface.co/mlx-community/Kimi-K2-Instruct-4bit/raw/; unprintable piece hashes omitted.

One embedded .torrent file: diff suppressed because one or more lines are too long.
rust/downloads/torrents/mlx-community/Kimi-K2-Thinking/035a0cdd221ae0dca6b03120e20704a251a7bc9b.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/Kimi-K2-Thinking/raw/; unprintable piece hashes omitted.

rust/downloads/torrents/mlx-community/Llama-3.3-70B-Instruct-4bit/de2dfaf56839b7d0e834157d2401dee02726874d.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-4bit/raw/; unprintable piece hashes omitted.

rust/downloads/torrents/mlx-community/Llama-3.3-70B-Instruct-8bit/c5bfd839cd4cda0e5a39a97e00218d9c56e468af.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-8bit/raw/; unprintable piece hashes omitted.

rust/downloads/torrents/…/ca8dbf41071f579fbe3260f20bbe1ab896f79031.small.torrent (new file; model directory not recoverable from this view): bencoded metadata including merges.txt, vocab.json, and qwen3_xml_tool_parser.py; web-seed URL and piece hashes garbled, omitted.

One embedded .torrent file: diff suppressed because one or more lines are too long.
rust/downloads/torrents/mlx-community/SmolLM-135M-4bit/f56bc6adfb74c794203dc8ca94e0bccfe2bcd6cc.large.torrent (new file): single-file torrent for model.safetensors (75,789,955 bytes, 16 MiB pieces), web seed https://huggingface.co/mlx-community/SmolLM-135M-4bit/resolve/; unprintable piece hashes omitted.

rust/downloads/torrents/mlx-community/gpt-oss-120b-MXFP4-Q8/81e5ac3ad0af6efb1298a8e8c7a10ed2990c137b.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/gpt-oss-120b-MXFP4-Q8/raw/; unprintable piece hashes omitted.

rust/downloads/torrents/mlx-community/gpt-oss-20b-MXFP4-Q4/f356f2747216d7e98fee755df25987459fc19089.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/gpt-oss-20b-MXFP4-Q4/raw/; unprintable piece hashes omitted.

rust/downloads/torrents/…/39b35eaa97282c34db81f61a983b4b83344e10f1.small.torrent (new file; model directory not recoverable from this view): bencoded metadata including merges.txt, vocab.json, and tokenizer.json; web-seed URL and piece hashes garbled, omitted.

rust/downloads/torrents/mlx-community/llama-3.3-70b-instruct-fp16/8103891b028a8933068e47751bc2acc10bb59aa2.small.torrent (new file): bencoded metadata for the repo's small files, web seed https://huggingface.co/mlx-community/llama-3.3-70b-instruct-fp16/raw/; unprintable piece hashes omitted.

rust/exo_pyo3_bindings/Cargo.toml

@@ -23,6 +23,7 @@ workspace = true
[dependencies]
networking = { workspace = true }
downloads = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [

rust/exo_pyo3_bindings/src/downloads.rs (new file, 334 lines)

@@ -0,0 +1,334 @@
//! Downloads module - BitTorrent downloads PyO3 bindings
use crate::ext::*;
use downloads::bencode::AnnounceParams;
use downloads::tracker::{PeerInfo, TopologyData, handle_announce as rust_handle_announce};
use downloads::{DownloadProgress, TorrentSession};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use std::net::Ipv4Addr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Handle a tracker announce request
///
/// Args:
/// params: Dictionary with announce parameters (info_hash, peer_id, port, etc.)
/// peers: List of peer dictionaries (node_id, ip, port, has_complete, priority)
///
/// Returns:
/// Bencoded announce response as bytes
#[pyfunction]
fn handle_tracker_announce(
py: Python<'_>,
params: &Bound<'_, PyDict>,
peers: &Bound<'_, pyo3::types::PyList>,
) -> PyResult<Py<PyBytes>> {
// Parse announce params
let info_hash = {
let info_hash_item = params
.get_item("info_hash")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing info_hash"))?;
let info_hash_bytes: &[u8] = info_hash_item.extract()?;
if info_hash_bytes.len() != 20 {
return Err(pyo3::exceptions::PyValueError::new_err(
"info_hash must be 20 bytes",
));
}
let mut info_hash = [0u8; 20];
info_hash.copy_from_slice(info_hash_bytes);
info_hash
};
let peer_id = {
let peer_id_item = params
.get_item("peer_id")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing peer_id"))?;
let peer_id_bytes: &[u8] = peer_id_item.extract()?;
if peer_id_bytes.len() != 20 {
return Err(pyo3::exceptions::PyValueError::new_err(
"peer_id must be 20 bytes",
));
}
let mut peer_id = [0u8; 20];
peer_id.copy_from_slice(peer_id_bytes);
peer_id
};
let port: u16 = params
.get_item("port")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
.extract()?;
let uploaded: u64 = params
.get_item("uploaded")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing uploaded"))?
.extract()?;
let downloaded: u64 = params
.get_item("downloaded")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing downloaded"))?
.extract()?;
let left: u64 = params
.get_item("left")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing left"))?
.extract()?;
let compact: bool = params
.get_item("compact")?
.map(|v| v.extract().unwrap_or(true))
.unwrap_or(true);
let announce_params = AnnounceParams {
info_hash,
peer_id,
port,
uploaded,
downloaded,
left,
compact,
event: None, // TODO: parse event if needed
};
// Parse peer list
let peer_infos: Result<Vec<PeerInfo>, PyErr> = peers
.iter()
.map(|peer_item| {
let peer_dict: &Bound<'_, PyDict> = peer_item.downcast()?;
let node_id: String = peer_dict
.get_item("node_id")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing node_id"))?
.extract()?;
let ip_str: String = peer_dict
.get_item("ip")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing ip"))?
.extract()?;
let ip: Ipv4Addr = ip_str
.parse()
.map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid IP address"))?;
let port: u16 = peer_dict
.get_item("port")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
.extract()?;
let has_complete: bool = peer_dict
.get_item("has_complete")?
.map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(false))
.unwrap_or(false);
let priority: i32 = peer_dict
.get_item("priority")?
.map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(0))
.unwrap_or(0);
Ok(PeerInfo {
node_id,
ip,
port,
has_complete,
priority,
})
})
.collect();
let peer_infos = peer_infos?;
let topology = TopologyData { peers: peer_infos };
// Call Rust tracker handler
let response_bytes = rust_handle_announce(&announce_params, &topology).pyerr()?;
// Return as Python bytes
Ok(PyBytes::new(py, &response_bytes).unbind())
}
/// Get all embedded torrent variants for a model
///
/// Args:
/// model_id: Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
/// revision: Git commit hash
///
/// Returns:
/// List of (variant_name, torrent_data) tuples, e.g., [("small", bytes), ("large", bytes)]
/// Returns empty list if no torrents found.
#[pyfunction]
fn get_embedded_torrents(
py: Python<'_>,
model_id: String,
revision: String,
) -> PyResult<Vec<(String, Py<PyBytes>)>> {
let torrents = downloads::get_embedded_torrents(&model_id, &revision);
Ok(torrents
.into_iter()
.map(|(variant, data)| (variant, PyBytes::new(py, &data).unbind()))
.collect())
}
/// Get file list from torrent data
///
/// Args:
/// torrent_data: Raw .torrent file contents
///
/// Returns:
/// List of (index, path, size_bytes) tuples for each file in the torrent
#[pyfunction]
fn get_torrent_file_list(torrent_data: Vec<u8>) -> PyResult<Vec<(usize, String, u64)>> {
let files = downloads::get_torrent_file_list(&torrent_data).pyerr()?;
Ok(files
.into_iter()
.map(|f| (f.index, f.path, f.size))
.collect())
}
/// Python wrapper for TorrentSession
#[pyclass]
struct TorrentSessionHandle {
session: Arc<Mutex<TorrentSession>>,
}
#[pymethods]
impl TorrentSessionHandle {
/// Create a new torrent session
///
/// Args:
/// session_dir: Directory to store session state and downloads
#[new]
fn new(session_dir: String) -> PyResult<Self> {
let session_path = PathBuf::from(session_dir);
let session = tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { TorrentSession::new(session_path).await })
.pyerr()?;
Ok(Self {
session: Arc::new(Mutex::new(session)),
})
}
/// Add a torrent from bytes
///
/// Args:
/// torrent_data: Raw .torrent file contents
/// save_path: Where to save downloaded files
/// file_indices: Optional list of file indices to download
///
/// Returns:
/// Info hash as hex string
fn add_torrent(
&self,
_py: Python<'_>,
torrent_data: Vec<u8>,
save_path: String,
file_indices: Option<Vec<usize>>,
) -> PyResult<String> {
let session = Arc::clone(&self.session);
let save_path = PathBuf::from(save_path);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async {
session
.lock()
.await
.add_torrent(torrent_data, save_path, file_indices)
.await
})
.pyerr()
}
/// Get download progress for a torrent
///
/// Args:
/// info_hash: Torrent info hash
///
/// Returns:
/// Dictionary with progress information
fn get_progress(&self, py: Python<'_>, info_hash: String) -> PyResult<Py<PyDict>> {
let session = Arc::clone(&self.session);
let progress: DownloadProgress = tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.get_progress(&info_hash).await })
.pyerr()?;
let dict = PyDict::new(py);
dict.set_item("downloaded_bytes", progress.downloaded_bytes)?;
dict.set_item("total_bytes", progress.total_bytes)?;
dict.set_item("download_speed", progress.download_speed)?;
dict.set_item("upload_speed", progress.upload_speed)?;
dict.set_item("peers_connected", progress.peers_connected)?;
dict.set_item("is_finished", progress.is_finished)?;
Ok(dict.unbind())
}
/// Wait until torrent download is completed
///
/// Args:
/// info_hash: Torrent info hash
fn wait_until_completed(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.wait_until_completed(&info_hash).await })
.pyerr()
}
/// Enable seeding for a torrent
///
/// Args:
/// info_hash: Torrent info hash
fn enable_seeding(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.enable_seeding(&info_hash).await })
.pyerr()
}
/// Remove a torrent from the session
///
/// Args:
/// info_hash: Torrent info hash
fn remove_torrent(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.remove_torrent(&info_hash).await })
.pyerr()
}
/// List all torrents in the session
///
/// Returns:
/// List of info hashes
fn list_torrents(&self, _py: Python<'_>) -> PyResult<Vec<String>> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { Ok(session.lock().await.list_torrents().await) })
}
}
/// Downloads submodule
pub(crate) fn downloads_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(handle_tracker_announce, m)?)?;
m.add_function(wrap_pyfunction!(get_embedded_torrents, m)?)?;
m.add_function(wrap_pyfunction!(get_torrent_file_list, m)?)?;
m.add_class::<TorrentSessionHandle>()?;
Ok(())
}
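Putting the bindings together, an end-to-end sketch (directories and the poll interval are illustrative; note each call blocks on a freshly created tokio runtime, as the methods above do):

import time
from exo_pyo3_bindings import TorrentSessionHandle, get_embedded_torrents

session = TorrentSessionHandle("/tmp/exo-torrent-session")
variants = get_embedded_torrents(
    "mlx-community/Qwen3-30B-A3B-4bit",
    "d388dead1515f5e085ef7a0431dd8fadf0886c57",
)
for variant, data in variants:
    info_hash = session.add_torrent(data, "/tmp/exo-models", None)  # None = all files
    while not session.get_progress(info_hash)["is_finished"]:
        time.sleep(0.5)
    session.enable_seeding(info_hash)  # currently a no-op; rqbit seeds by default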

rust/exo_pyo3_bindings/src/lib.rs

@@ -17,10 +17,12 @@
extern crate core;
mod allow_threading;
pub(crate) mod downloads;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::downloads::downloads_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
@@ -207,6 +209,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
downloads_submodule(m)?;
// top-level constructs
// TODO: ...

scripts/mktorrent.sh (new executable file, 58 lines)

@@ -0,0 +1,58 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p mktorrent -p python3Packages.huggingface-hub -p git -p git-lfs
set -euo pipefail
set -x
MODEL="$1"
mkdir -p "$MODEL"
# Step 1: Clone/fetch the repo and get the hash of head
mkdir -p "$MODEL"
if test -d "$MODEL/git"; then
# Assert that the origin is correct
git -C "$MODEL/git" fetch
else
git clone "https://huggingface.co/$MODEL" "$MODEL/git"
fi
HASH=$(git -C "$MODEL/git" rev-parse origin/main)
LARGE_FILES=$(git -C "$MODEL/git" lfs ls-files --all --name-only)
SMALL_DIR="$MODEL/$HASH-small"
LARGE_DIR="$MODEL/$HASH-large"
mkdir -p "$SMALL_DIR" "$LARGE_DIR"
# Step 2: Prepare files. Two torrents: one for large files and one for metadata.
git -C "$MODEL/git" archive "$HASH" | tar -x -C "$SMALL_DIR"
echo "$LARGE_FILES" | xargs -I{} rm "$SMALL_DIR/{}"
echo "$LARGE_FILES" | xargs hf download "$MODEL" --revision "$HASH" --local-dir "$LARGE_DIR" --cache-dir "$(realpath .cache)" --include
if test -d "$LARGE_DIR/.cache"; then
echo ".cache created against our wishes, deleting it..."
rm -r "$LARGE_DIR/.cache"
fi
# Step 3: Create both torrents
mkdir -p "torrents/$MODEL/"
SMALL_TORRENT_PATH="torrents/$MODEL/${HASH}.small.torrent"
LARGE_TORRENT_PATH="torrents/$MODEL/${HASH}.large.torrent"
mktorrent "$SMALL_DIR/" --output="$SMALL_TORRENT_PATH" \
-n "$HASH" \
--web-seed="https://huggingface.co/$MODEL/raw/" \
--no-date \
--announce="udp://tracker.opentrackr.org:1337/announce"
# --private
mktorrent "$LARGE_DIR/" --output="$LARGE_TORRENT_PATH" \
-n "$HASH" \
--web-seed="https://huggingface.co/$MODEL/resolve/" \
--piece-length=24 \
--no-date \
--announce="udp://tracker.opentrackr.org:1337/announce"
# --private
echo "Successfully created torrent files in:"
echo "$SMALL_TORRENT_PATH"
echo "$LARGE_TORRENT_PATH"

(Python: node entry point, Node and Args classes)

@@ -35,6 +35,7 @@ class Node:
api: API | None
node_id: NodeId
enable_torrents: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -66,7 +67,8 @@ class Node:
worker = Worker(
node_id,
session_id,
-exo_shard_downloader(),
+exo_shard_downloader(enable_torrents=args.enable_torrents),
initial_connection_messages=[],
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
@@ -74,7 +76,6 @@ class Node:
)
else:
worker = None
# We start every node with a master
master = Master(
node_id,
@@ -98,7 +99,7 @@ class Node:
election_result_sender=er_send,
)
-return cls(router, worker, election, er_recv, master, api, node_id)
+return cls(router, worker, election, er_recv, master, api, node_id, args.enable_torrents)
async def run(self):
async with self._tg as tg:
@@ -175,7 +176,7 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
-exo_shard_downloader(),
+exo_shard_downloader(enable_torrents=self.enable_torrents),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -215,6 +216,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
enable_torrents: bool = False
@classmethod
def parse(cls) -> Self:
@@ -256,6 +258,12 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--enable-torrents",
action="store_true",
dest="enable_torrents",
help="Enable BitTorrent-based downloads (experimental)",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

(Python: API server, class API)

@@ -5,9 +5,9 @@ from typing import cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import Response, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
@@ -178,6 +178,7 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.get("/_internal/announce")(self.tracker_announce)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -622,6 +623,65 @@ class API:
]
)
async def tracker_announce(self, request: Request) -> Response:
"""BitTorrent tracker announce endpoint for private tracker."""
try:
from exo_pyo3_bindings import handle_tracker_announce # type: ignore
except ImportError as e:
raise HTTPException(
status_code=501,
detail="Torrent support not available (exo_pyo3_bindings not installed)",
) from e
# Parse announce parameters from query string
query_params = dict(request.query_params)
# Extract required parameters
try:
info_hash_hex = query_params.get("info_hash", "")
peer_id_hex = query_params.get("peer_id", "")
# URL decode and convert to bytes
info_hash = bytes.fromhex(info_hash_hex) if info_hash_hex else b""
peer_id = bytes.fromhex(peer_id_hex) if peer_id_hex else b""
if len(info_hash) != 20 or len(peer_id) != 20:
raise ValueError("info_hash and peer_id must be 20 bytes")
params = {
"info_hash": info_hash,
"peer_id": peer_id,
"port": int(query_params.get("port", "6881")),
"uploaded": int(query_params.get("uploaded", "0")),
"downloaded": int(query_params.get("downloaded", "0")),
"left": int(query_params.get("left", "0")),
"compact": query_params.get("compact", "1") == "1",
}
except (ValueError, KeyError) as e:
raise HTTPException(
status_code=400,
detail=f"Invalid announce parameters: {e}",
) from e
# Build peer list from topology
# TODO: Implement _build_peer_list_from_topology() to extract peers from self.state.topology
peers = [] # For now, return empty peer list
# Call Rust tracker handler
try:
response_bytes: bytes = handle_tracker_announce(params, peers) # type: ignore
return Response(
content=response_bytes,
media_type="text/plain",
headers={"Content-Type": "text/plain"},
)
except Exception as e:
logger.error(f"Tracker announce error: {e}")
raise HTTPException(
status_code=500,
detail=f"Tracker announce failed: {e}",
) from e
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
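As wired, an announce is a plain GET against the node's API port (52415 by default). A sketch of a request the handler accepts; note it parses info_hash and peer_id as hex strings, and the values here are illustrative:

import urllib.parse

params = {
    "info_hash": "00" * 20,  # 20 bytes, hex-encoded per the handler above
    "peer_id": "11" * 20,
    "port": "6881",
    "uploaded": "0",
    "downloaded": "0",
    "left": "1000",
    "compact": "1",
}
print("http://localhost:52415/_internal/announce?" + urllib.parse.urlencode(params))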

(Python: tests for apply_node_download_progress)

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

(Python: ModelMetadata type)

@@ -16,3 +16,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
revision: str | None = None # Git commit hash for torrent lookup

(Python: exo.shared.types.worker.downloads)

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
-pass
+total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):

(Python: shard downloader factory)

@@ -12,10 +12,17 @@ from exo.worker.download.download_utils import RepoDownloadProgress, download_sh
from exo.worker.download.shard_downloader import ShardDownloader
-def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
-    return SingletonShardDownloader(
-        CachedShardDownloader(ResumableShardDownloader(max_parallel_downloads))
-    )
+def exo_shard_downloader(
+    max_parallel_downloads: int = 8, enable_torrents: bool = False
+) -> ShardDownloader:
+    if enable_torrents:
+        from exo.worker.download.torrent_downloader import TorrentShardDownloader
+        base = TorrentShardDownloader(max_parallel_downloads)
+    else:
+        base = ResumableShardDownloader(max_parallel_downloads)
+    return SingletonShardDownloader(CachedShardDownloader(base))
async def build_base_shard(model_id: str) -> ShardMetadata:

exo/worker/download/torrent_downloader.py (new file, 366 lines)

@@ -0,0 +1,366 @@
"""Torrent-based shard downloader using BitTorrent protocol with private tracker.
This module implements downloading model shards via BitTorrent with:
- Private tracker integration (/_internal/announce endpoint)
- WebSeed (BEP 19) fallback to HTTPS
- Selective file download based on shard layer ranges
- Persistent seeding of downloaded content
"""
import asyncio
from collections.abc import AsyncIterator
from datetime import timedelta
from fnmatch import fnmatch
from pathlib import Path
from typing import Callable
from loguru import logger
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.download_utils import RepoDownloadProgress, resolve_allow_patterns
from exo.worker.download.shard_downloader import ShardDownloader
class TorrentShardDownloader(ShardDownloader):
"""Download model shards using BitTorrent with private tracker."""
def __init__(self, max_parallel_downloads: int = 8):
self.max_parallel_downloads = max_parallel_downloads
self._progress_callbacks: list[
Callable[[ShardMetadata, RepoDownloadProgress], None]
] = []
# Initialize TorrentSessionHandle
try:
from exo_pyo3_bindings import TorrentSessionHandle # type: ignore
session_dir = EXO_MODELS_DIR / "v2" / ".torrent_session"
session_dir.mkdir(parents=True, exist_ok=True)
self.session = TorrentSessionHandle(str(session_dir))
except ImportError:
logger.error("exo_pyo3_bindings not available, torrent downloads disabled")
self.session = None
def on_progress(
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
"""Register a progress callback."""
self._progress_callbacks.append(callback)
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
"""Download a model shard using BitTorrent.
Downloads all torrent variants for the model and applies file filtering
based on the shard's layer range.
Args:
shard: Shard metadata including model ID and layer range
config_only: If True, only download config files (not implemented for torrents yet)
Returns:
Path to the downloaded shard directory
Raises:
RuntimeError: If torrent file is not found or session not initialized
"""
if self.session is None:
raise RuntimeError(
"TorrentSessionHandle not initialized. "
"exo_pyo3_bindings module not available."
)
model_id = str(shard.model_meta.model_id)
# Resolve "main" branch to commit hash
revision = await self._resolve_revision(model_id, "main")
# Load all embedded torrent variants
variants = self._load_embedded_torrents(model_id, revision)
if not variants:
raise RuntimeError(
f"No torrents found for {model_id}@{revision}. "
f"Expected at: rust/downloads/torrents/{model_id}/{revision}.*.torrent. "
f"Please add torrent files or use HTTP downloads (--enable-torrents=False)."
)
# Build v2 path: ~/.exo/models/v2/{model_id}/{revision}
save_path = EXO_MODELS_DIR / "v2" / model_id / revision
save_path.mkdir(parents=True, exist_ok=True)
# Get allow patterns for file filtering (same as HTTP downloads)
allow_patterns = await resolve_allow_patterns(shard)
logger.debug(f"Torrent download allow patterns: {allow_patterns}")
# Track aggregate progress across all variants
total_downloaded = 0
total_size = 0
# Download each variant
for variant_name, torrent_data in variants:
# Get file list from torrent
file_list = self._get_torrent_file_list(torrent_data)
if not file_list:
logger.warning(f"No files in variant '{variant_name}', skipping")
continue
# Filter files using allow patterns
file_indices = self._filter_files(file_list, allow_patterns)
if not file_indices:
logger.debug(f"No matching files in variant '{variant_name}' after filtering")
continue
logger.info(
f"Adding torrent variant '{variant_name}' for {model_id}@{revision} "
f"({len(file_indices)} of {len(file_list)} files)"
)
try:
info_hash = self.session.add_torrent(
torrent_data, str(save_path), file_indices
)
except Exception as e:
# Torrent may already be added
logger.debug(f"Could not add variant '{variant_name}': {e}")
continue
# Wait for download with progress reporting
logger.info(f"Starting download for variant '{variant_name}' ({info_hash})")
while True:
progress_dict = self.session.get_progress(info_hash)
# Update aggregate progress
variant_downloaded = progress_dict["downloaded_bytes"]
variant_total = progress_dict["total_bytes"]
download_speed = progress_dict["download_speed"] # bytes/sec from rqbit
# Calculate ETA based on remaining bytes and current speed
current_downloaded = total_downloaded + variant_downloaded
current_total = total_size + variant_total
remaining_bytes = current_total - current_downloaded
eta = (
timedelta(seconds=remaining_bytes / download_speed)
if download_speed > 0
else timedelta(0)
)
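                # Worked example with made-up numbers: 2 GiB remaining
                # (2_147_483_648 bytes) at 25 MiB/s (26_214_400 bytes/s)
                # gives timedelta(seconds=81.92), i.e. roughly 1m22s.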
# Report combined progress
progress = self._make_progress(
shard=shard,
model_id=model_id,
revision=revision,
downloaded=current_downloaded,
total=current_total,
status="in_progress",
speed=download_speed,
eta=eta,
)
self._report_progress(shard, progress)
if progress_dict["is_finished"]:
logger.info(f"Variant '{variant_name}' download completed")
total_downloaded += variant_total
total_size += variant_total
break
await asyncio.sleep(0.5)
# Enable persistent seeding for this variant
try:
self.session.enable_seeding(info_hash)
except Exception as e:
logger.warning(f"Could not enable seeding for '{variant_name}': {e}")
return save_path
async def get_shard_download_status(
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
"""Get download status for all shards."""
if self.session is None:
return
# List all torrents in the session
info_hashes = self.session.list_torrents()
        # There is no info_hash -> shard/path mapping yet (it would need to be
        # tracked in the session or in sidecar metadata), so this method cannot
        # yield meaningful per-shard progress for torrents.
for info_hash in info_hashes:
try:
_progress_dict = self.session.get_progress(info_hash)
# Can't create proper progress without shard context
# Skip yielding until we have proper tracking
logger.debug(f"Torrent {info_hash} status check - not implemented")
except Exception as e:
logger.error(f"Error getting status for {info_hash}: {e}")
# Yield nothing - torrents don't support this method well yet
return
yield # Make this an async generator
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
"""Get download status for a specific shard."""
model_id = str(shard.model_meta.model_id)
if self.session is None:
return self._make_progress(
shard=shard,
model_id=model_id,
revision="unknown",
downloaded=0,
total=0,
status="not_started",
)
# We would need to track info_hash -> shard mapping
# For now, return empty progress
return self._make_progress(
shard=shard,
model_id=model_id,
revision="unknown",
downloaded=0,
total=0,
status="not_started",
)
async def _resolve_revision(self, model_id: str, branch: str = "main") -> str:
"""Resolve branch name to commit hash using HuggingFace API.
Args:
model_id: Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
branch: Branch name (default: "main")
Returns:
Git commit hash (SHA)
"""
try:
from huggingface_hub import model_info
info = model_info(model_id, revision=branch)
return str(info.sha)
except Exception as e:
logger.error(f"Failed to resolve revision for {model_id}@{branch}: {e}")
            # Fall back to using the branch name itself as the revision
logger.warning(f"Using branch name '{branch}' as revision")
return branch
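    # A quick illustration of the resolution above (the model ID is one from
    # this repo's torrent set; the hash is invented). huggingface_hub's
    # model_info(...) does expose a .sha attribute:
    #
    #     info = model_info("mlx-community/Qwen3-30B-A3B-4bit", revision="main")
    #     info.sha  # -> "9f1c0d4e..." (a 40-character git commit hash)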
def _load_embedded_torrents(
self, model_id: str, revision: str
) -> list[tuple[str, bytes]]:
"""Load all embedded torrent variants from Rust binary.
Args:
model_id: Model identifier
revision: Git commit hash
Returns:
List of (variant_name, torrent_bytes) tuples, e.g., [("small", ...), ("large", ...)]
"""
try:
from exo_pyo3_bindings import get_embedded_torrents # type: ignore
result: list[tuple[str, bytes]] = get_embedded_torrents(model_id, revision)
return result
except ImportError:
logger.warning("exo_pyo3_bindings not available, cannot load torrents")
return []
except Exception as e:
logger.error(f"Error loading torrents for {model_id}@{revision}: {e}")
return []
def _get_torrent_file_list(
self, torrent_data: bytes
) -> list[tuple[int, str, int]]:
"""Get file list from torrent data.
Args:
torrent_data: Raw torrent file bytes
Returns:
List of (index, path, size_bytes) tuples
"""
try:
from exo_pyo3_bindings import get_torrent_file_list # type: ignore
return get_torrent_file_list(torrent_data)
except ImportError:
logger.warning("exo_pyo3_bindings not available")
return []
except Exception as e:
logger.error(f"Error parsing torrent file list: {e}")
return []
def _filter_files(
self, file_list: list[tuple[int, str, int]], allow_patterns: list[str]
) -> list[int]:
"""Filter file list using glob patterns.
Args:
file_list: List of (index, path, size) tuples
allow_patterns: Glob patterns to match (e.g., ["*.json", "*.safetensors"])
Returns:
List of file indices that match at least one pattern
"""
matching_indices = []
for index, path, _size in file_list:
for pattern in allow_patterns:
if fnmatch(path, pattern):
matching_indices.append(index)
break
return matching_indices
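    # Worked example with made-up entries:
    #
    #     files = [
    #         (0, "config.json", 1_204),
    #         (1, "model-00001-of-00002.safetensors", 5_368_709_120),
    #         (2, "README.md", 2_048),
    #     ]
    #     _filter_files(files, ["*.json", "*.safetensors"])  # -> [0, 1]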
def _make_progress(
self,
shard: ShardMetadata,
model_id: str,
revision: str,
downloaded: int,
total: int,
status: str,
speed: float = 0.0,
eta: timedelta | None = None,
) -> RepoDownloadProgress:
"""Create a RepoDownloadProgress object with proper fields.
Args:
shard: Shard metadata
model_id: Model identifier
revision: Git revision
downloaded: Downloaded bytes
total: Total bytes
status: Download status ("not_started", "in_progress", "complete")
speed: Download speed in bytes/second
eta: Estimated time remaining
Returns:
RepoDownloadProgress object
"""
return RepoDownloadProgress(
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=1 if status == "complete" else 0,
total_files=1,
downloaded_bytes=Memory.from_bytes(downloaded),
downloaded_bytes_this_session=Memory.from_bytes(downloaded),
total_bytes=Memory.from_bytes(total),
overall_speed=speed,
overall_eta=eta if eta is not None else timedelta(0),
status=status, # type: ignore
)
def _report_progress(
self, shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
"""Report progress to all registered callbacks."""
for callback in self._progress_callbacks:
try:
callback(shard, progress)
except Exception as e:
logger.error(f"Error in progress callback: {e}")

View File

@@ -217,7 +217,9 @@ class Worker:
             )
             if initial_progress.status == "complete":
                 progress = DownloadCompleted(
-                    shard_metadata=shard, node_id=self.node_id
+                    shard_metadata=shard,
+                    node_id=self.node_id,
+                    total_bytes=initial_progress.total_bytes,
                 )
                 self.download_status[shard.model_meta.model_id] = progress
                 await self.event_sender.send(
@@ -364,7 +366,11 @@ class Worker:
             nonlocal self
             nonlocal last_progress_time
             if progress.status == "complete":
-                status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
+                status = DownloadCompleted(
+                    shard_metadata=shard,
+                    node_id=self.node_id,
+                    total_bytes=progress.total_bytes,
+                )
                 self.download_status[shard.model_meta.model_id] = status
                 # Footgun!
                 self.event_sender.send_nowait(
@@ -457,7 +463,9 @@ class Worker:
        ) in self.shard_downloader.get_shard_download_status():
            if progress.status == "complete":
                status = DownloadCompleted(
-                    node_id=self.node_id, shard_metadata=progress.shard
+                    node_id=self.node_id,
+                    shard_metadata=progress.shard,
+                    total_bytes=progress.total_bytes,
                )
            elif progress.status in ["in_progress", "not_started"]:
                if progress.downloaded_bytes_this_session.in_bytes == 0:
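The net effect of the `total_bytes` plumbing above, as a sketch (the size is invented; `Memory.from_bytes` is the constructor used elsewhere in this diff, and the dashboard wiring is per the commit message rather than shown here):

    completed = DownloadCompleted(
        shard_metadata=shard,
        node_id=node_id,
        total_bytes=Memory.from_bytes(251_100_000_000),
    )
    # The dashboard can now render "Completed (251.1GB)" per model and sum
    # total_bytes across models for each node's "... on disk" line.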

View File

@@ -1,5 +1,6 @@
 import exo.worker.plan as plan_mod
 from exo.shared.types.common import NodeId
+from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import LoadModel
 from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
     # Local node has already marked its shard as downloaded (not actually used by _load_model)
     local_download_status = {
-        MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
+        MODEL_A_ID: DownloadCompleted(
+            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
+        )
     }
 
     # Global view has completed downloads for both nodes
     global_download_status = {
-        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
-        NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
+        NODE_A: [
+            DownloadCompleted(
+                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
+            )
+        ],
+        NODE_B: [
+            DownloadCompleted(
+                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
+            )
+        ],
     }
 
     result = plan_mod.plan(
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
     # Local status claims the shard is downloaded already
     local_download_status = {
-        MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
+        MODEL_A_ID: DownloadCompleted(
+            shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
+        )
     }
 
     # Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
     # Only NODE_A's shard is recorded as downloaded globally
     local_download_status = {
-        MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
+        MODEL_A_ID: DownloadCompleted(
+            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
+        )
     }
 
     global_download_status = {
-        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
+        NODE_A: [
+            DownloadCompleted(
+                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
+            )
+        ],
         NODE_B: [],  # NODE_B has no downloads completed yet
     }
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
     assert result is None
 
     global_download_status = {
-        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
+        NODE_A: [
+            DownloadCompleted(
+                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
+            )
+        ],
         NODE_B: [
-            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
+            DownloadCompleted(
+                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
+            )
         ],  # NODE_B has no downloads completed yet
     }