Compare commits


1 Commit

Author: Evan | SHA1: cc09ba01e4 | Message: yay | Date: 2026-01-26 18:19:00 +00:00
13 changed files with 219 additions and 886 deletions

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 30
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm==0.30.5",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",

View File

@@ -121,20 +121,11 @@ async def ensure_models_dir() -> Path:
async def delete_model(model_id: ModelId) -> bool:
models_dir = await ensure_models_dir()
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
deleted = False
if await aios.path.exists(model_dir):
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
deleted = True
# Also clear cache
if await aios.path.exists(cache_dir):
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
return deleted
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
async def seed_models(seed_dir: str | Path):
@@ -160,28 +151,16 @@ async def fetch_file_list_with_cache(
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# No cache available, propagate the error
raise
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
async def fetch_file_list_with_retry(
@@ -353,28 +332,8 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
try:
remote_size, _ = await file_meta(model_id, revision, path)
if local_size != remote_size:
logger.info(
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
)
await aios.remove(target_path)
else:
return target_path
except Exception as e:
# Offline or network error - trust local file
logger.debug(
f"Could not verify {path} against remote (offline?): {e}, using local file"
)
return target_path
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -583,26 +542,17 @@ async def download_shard(
async def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
previous_progress = file_progress.get(file.path)
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
else time.time()
)
downloaded_this_session = (
file_progress[file.path].downloaded_this_session.in_bytes
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
if file.path in file_progress
else curr_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
speed = (
downloaded_this_session / (time.time() - start_time)
if time.time() - start_time > 0
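
Note on the download_utils changes above: the replacement fetch_file_list_with_cache is cache-first — it returns the on-disk JSON when present and only fetches (then writes the cache) on a miss, whereas the removed version always fetched fresh and fell back to the cache on error. A minimal standalone sketch of the cache-first shape, with hypothetical helper names (cached_file_list and fetch_remote are stand-ins, not functions from this repo):

import json
from pathlib import Path
from typing import Any, Callable

def cached_file_list(
    cache_file: Path,
    fetch_remote: Callable[[], list[dict[str, Any]]],
) -> list[dict[str, Any]]:
    # Prefer the on-disk cache; only hit the network on a cache miss.
    if cache_file.exists():
        return json.loads(cache_file.read_text())
    file_list = fetch_remote()
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(file_list))
    return file_list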

View File

View File

@@ -1,451 +0,0 @@
"""Tests for download verification and cache behavior."""
import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter
from exo.download.download_utils import (
delete_model,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
"""Set up a temporary models directory for testing."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestFileVerification:
"""Tests for file size verification in _download_file."""
async def test_redownload_when_file_size_changes_upstream(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with mismatched sizes are re-downloaded."""
# Import inside test to allow patching
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file with wrong size
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content") # 13 bytes
remote_size = 1000 # Different from local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
# Set up mock HTTP response for re-download
mock_response = MagicMock()
mock_response.status = 200
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
side_effect=[b"x" * remote_size, b""]
)
mock_session = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_response
)
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_session
)
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
# Mock calc_hash to return the expected hash
with patch(
"exo.download.download_utils.calc_hash",
new_callable=AsyncMock,
return_value=remote_hash,
):
await _download_file(model_id, "main", "test.safetensors", target_dir)
# file_meta should be called twice: once for verification, once for download
assert mock_file_meta.call_count == 2
async def test_skip_download_when_file_size_matches(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with matching sizes are not re-downloaded."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
local_content = b"local content"
async with aiofiles.open(local_file, "wb") as f:
await f.write(local_content)
remote_size = len(local_content) # Same as local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return immediately without downloading
assert result == local_file
mock_file_meta.assert_called_once()
mock_session_factory.assert_not_called()
async def test_offline_fallback_uses_local_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that local files are used when network is unavailable."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content")
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return local file without attempting download
assert result == local_file
mock_session_factory.assert_not_called()
class TestFileListCache:
"""Tests for file list caching behavior."""
async def test_fetch_fresh_and_update_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that fresh data is fetched and cache is updated."""
models_dir = tmp_path / "models"
file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=100),
]
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
return_value=file_list,
) as mock_fetch,
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == file_list
mock_fetch.assert_called_once()
# Verify cache was written
cache_file = (
models_dir
/ "caches"
/ model_id.normalize()
/ f"{model_id.normalize()}--main--file_list.json"
)
assert await aios.path.exists(cache_file)
async with aiofiles.open(cache_file, "r") as f:
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
await f.read()
)
assert cached_data == file_list
async def test_fallback_to_cache_when_fetch_fails(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that cached data is used when fetch fails."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
# Create cache file
cached_file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
)
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == cached_file_list
async def test_error_propagates_when_no_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that errors propagate when fetch fails and no cache exists."""
models_dir = tmp_path / "models"
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
pytest.raises(Exception, match="Network error"),
):
await fetch_file_list_with_cache(model_id, "main")
class TestModelDeletion:
"""Tests for model deletion including cache cleanup."""
async def test_delete_model_clears_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that deleting a model also deletes its cache."""
models_dir = tmp_path / "models"
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
# Create model and cache directories
await aios.makedirs(model_dir, exist_ok=True)
await aios.makedirs(cache_dir, exist_ok=True)
# Add some files
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
await f.write("model data")
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is True
assert not await aios.path.exists(model_dir)
assert not await aios.path.exists(cache_dir)
async def test_delete_model_only_cache_exists(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting when only cache exists (model already deleted)."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
# Only create cache directory
await aios.makedirs(cache_dir, exist_ok=True)
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
# Returns False because model dir didn't exist
assert result is False
# But cache should still be cleaned up
assert not await aios.path.exists(cache_dir)
async def test_delete_nonexistent_model(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting a model that doesn't exist."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is False
class TestProgressResetOnRedownload:
"""Tests for progress tracking when files are re-downloaded."""
async def test_progress_resets_correctly_on_redownload(
self, model_id: ModelId
) -> None:
"""Test that progress tracking resets when a file is re-downloaded.
When a file is deleted and re-downloaded (due to size mismatch),
the progress tracking should reset rather than calculating negative
downloaded_this_session values.
"""
# Simulate file_progress dict as it exists in download_shard
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with old file progress (simulating existing large file)
old_file_size = 1_500_000_000 # 1.5 GB
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(old_file_size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(old_file_size),
speed=0,
eta=timedelta(0),
status="not_started",
start_time=time.time() - 10, # Started 10 seconds ago
)
# Simulate the logic from on_progress_wrapper after re-download starts
# This is the exact logic from the fixed on_progress_wrapper
curr_bytes = 100_000 # 100 KB - new download just started
previous_progress = file_progress.get("model.safetensors")
# Detect re-download: curr_bytes < previous downloaded
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is True, "Should detect re-download scenario"
assert downloaded_this_session == curr_bytes, (
"downloaded_this_session should equal curr_bytes on re-download"
)
assert downloaded_this_session > 0, (
"downloaded_this_session should be positive, not negative"
)
# Calculate speed (should be positive)
elapsed = time.time() - start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
assert speed >= 0, "Speed should be non-negative"
async def test_progress_accumulates_on_continuing_download(
self, model_id: ModelId
) -> None:
"""Test that progress accumulates correctly for continuing downloads.
When a download continues from where it left off (resume),
the progress should accumulate correctly.
"""
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with partial download progress
initial_downloaded = 500_000 # 500 KB already downloaded
start_time = time.time() - 5 # Started 5 seconds ago
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(initial_downloaded),
downloaded_this_session=Memory.from_bytes(initial_downloaded),
total=Memory.from_bytes(1_000_000),
speed=100_000,
eta=timedelta(seconds=5),
status="in_progress",
start_time=start_time,
)
# Progress callback with more bytes downloaded
curr_bytes = 600_000 # 600 KB - continuing download
previous_progress = file_progress.get("model.safetensors")
# This is NOT a re-download (curr_bytes > previous downloaded)
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
downloaded_this_session = curr_bytes
used_start_time = time.time()
else:
used_start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is False, (
"Should NOT detect re-download for continuing download"
)
assert used_start_time == start_time, "Should preserve original start_time"
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
assert downloaded_this_session == expected_session, (
f"Should accumulate: {downloaded_this_session} == {expected_session}"
)
assert downloaded_this_session == 600_000, (
"downloaded_this_session should equal total downloaded so far"
)

View File

@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("exolabs/FLUX.1-schnell"),
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -428,7 +428,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +442,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -457,7 +457,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"flux1-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-dev"),
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -470,7 +470,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +484,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -499,7 +499,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
@@ -541,9 +541,9 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image": ModelCard(
model_id=ModelId("exolabs/Qwen-Image"),
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -551,10 +551,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -575,9 +575,9 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -585,10 +585,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -610,92 +610,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _generate_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
# quantizations = [8, 6, 5, 4, 3]
quantizations = [8, 4]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
model_id = ModelId(base_card.model_id + f"-{quant}bit")
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
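
Note on the removed _generate_image_model_quant_variants above: only the transformer component was resized, at (base_bytes * bits) // 16, with the text encoders kept at bf16. A small sketch of that arithmetic using the FLUX.1-schnell sizes from this file (plain ints stand in for the Memory type):

def quant_variant_total(transformer_bytes: int, other_bytes: int, bits: int) -> int:
    # Only the transformer is quantized; the remaining components stay at 16-bit.
    return other_bytes + (transformer_bytes * bits) // 16

# 8-bit FLUX.1-schnell: 9524621312 + (23782357120 * 8) // 16 == 21415799872 bytes
print(quant_variant_total(23782357120, 9524621312, 8))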

View File

@@ -19,8 +19,6 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
@@ -147,10 +145,6 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -258,6 +252,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call
@@ -336,7 +334,15 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -345,6 +351,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -360,14 +367,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GLM4MoeLiteModel):
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
@@ -442,7 +441,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -517,8 +516,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
return model
@@ -536,84 +533,6 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GLM4MoeLiteModel, model)
for layer in model.layers: # type: ignore
layer = cast(Glm4MoeLiteDecoderLayer, layer)
eval_with_timeout(
layer.parameters(),
timeout_seconds / len(model.layers), # type: ignore
on_timeout,
)
if layer.self_attn.q_lora_rank is None: # type: ignore
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
if isinstance(layer.mlp, Glm4MoeLiteMLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # type: ignore
mx.eval(layer)
return model
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -647,7 +566,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
@@ -688,7 +607,6 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -743,7 +661,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
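
Note on the tensor_auto_parallel change above: the new branch tries the model's own shard(group) first and only falls back to the manual sharding strategies (which now log a warning when hit) if that call raises. A minimal sketch of the control flow, with placeholder types (manual_shard is a stand-in for the strategy dispatch below):

from typing import Any, Callable

def auto_parallel(model: Any, group: Any, manual_shard: Callable[[Any, Any], Any]) -> Any:
    if hasattr(model, "shard"):
        try:
            model.shard(group)  # prefer the upstream sharding implementation
            return model
        except (AttributeError, TypeError, NameError):
            pass  # upstream path failed; fall through to the manual strategy
    return manual_shard(model, group)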

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -405,11 +405,7 @@ def apply_chat_template(
continue
message.content = "\n".join(c.text for c in message.content).strip()
if (
message.content is None
and message.thinking is None
and message.tool_calls is None
):
if message.content is None and message.thinking is None:
continue
# Null values are not valid when applying templates in tokenizer

View File

@@ -7,6 +7,7 @@ from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.tests.patches import load_null_model
logger: "loguru.Logger" = loguru.logger
@@ -16,6 +17,8 @@ def entrypoint(
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
*,
_load_null_models: bool = False,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -29,6 +32,13 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
p = None
if _load_null_models:
from unittest.mock import patch
p = patch("mlx_lm.utils.load_model", new=load_null_model)
p.start()
global logger
logger = _logger
@@ -52,6 +62,8 @@ def entrypoint(
)
)
finally:
if p is not None:
p.stop()
try:
event_sender.close()
task_receiver.close()
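
Note on the _load_null_models flag above: the patch is started and stopped by hand (rather than via a with-block) so it covers the entire runner entrypoint and is still undone in the finally clause. A self-contained sketch of that pattern; the patched target here (json.loads) is purely illustrative:

import json
from unittest.mock import patch

def entrypoint(use_fake: bool = False) -> None:
    p = None
    if use_fake:
        p = patch("json.loads", new=lambda s: {"faked": True})
        p.start()  # stays active for the lifetime of the entrypoint
    try:
        print(json.loads('{"a": 1}'))
    finally:
        if p is not None:
            p.stop()  # always undo the patch on the way out

entrypoint(use_fake=True)  # prints {'faked': True}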

View File

@@ -0,0 +1,50 @@
# type: ignore
import importlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from exo.worker.engines.mlx import Model
def load_null_model(path: Path, **_: object) -> "tuple[Model, dict[str, Any]]":
with open(path / "config.json", "r") as f:
cfg = json.load(f)
model, args = _get_classes(cfg)
model = model(args.from_dict(cfg))
return model, cfg
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
raise ValueError(msg) from None
return arch.Model, arch.ModelArgs
MODEL_REMAPPING = {
"mistral": "llama",
"llava": "mistral3",
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
"kimi_k2": "deepseek_v3",
"qwen2_5_vl": "qwen2_vl",
"minimax_m2": "minimax",
"iquestcoder": "llama",
}
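
Note on the new patches module above: load_null_model builds the model class from config.json alone (resolving the architecture via MODEL_REMAPPING and importlib), so no weight files are read. A rough usage sketch under the same assumptions as the runner change (mlx_lm installed, model_dir containing only a config.json); the helper name here is hypothetical:

from pathlib import Path
from unittest.mock import patch

from exo.worker.tests.patches import load_null_model

def load_without_weights(model_dir: Path):
    # Route mlx_lm's loader to the null loader, as entrypoint(_load_null_models=True) does.
    with patch("mlx_lm.utils.load_model", new=load_null_model):
        import mlx_lm.utils
        return mlx_lm.utils.load_model(model_dir)  # (model, config), no weights loaded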

View File

@@ -1,7 +1,6 @@
import multiprocessing as mp
import socket
import time
import typing
import anyio
from fastapi import FastAPI
@@ -11,16 +10,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,18 +31,17 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import RunnerFailed, RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
MODEL_CARDS = {"haha": MODEL_CARDS["qwen3-coder-480b-a35b-8bit"]}
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
mp.set_start_method("spawn", force=True)
@@ -56,16 +50,14 @@ logger_setup(None)
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:8000"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/run_test")(run_test)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
@@ -87,28 +79,7 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):
async def run_test(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
@@ -117,10 +88,30 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
for card in MODEL_CARDS.values():
for instance in (
ring_instance(test, card.model_id, iid, hn),
jaccl_instance(test, card.model_id, iid),
):
recv = await execute_test(test, instance, hn)
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, RunnerFailed
):
return
except anyio.ClosedResourceError:
pass
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
def ring_instance(test: Tests, model_id: ModelId, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
@@ -135,13 +126,13 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,7 +154,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
@@ -171,60 +162,33 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
for card in MODEL_CARDS.values():
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=card.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
return recv
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests, model_id: ModelId, iid: InstanceId):
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
world_size = len(test.devs)
return MlxJacclInstance(
@@ -235,7 +199,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
@@ -270,6 +234,7 @@ def new_runner(
task_recv,
logger,
),
kwargs={"_load_null_models": True},
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
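
Note on the rewritten run_test endpoint above: runner events are streamed back as newline-delimited JSON by handing an async generator to StreamingResponse. A minimal self-contained sketch of that shape; the event source here is a stand-in for the MpReceiver used in the real code:

from collections.abc import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class Event(BaseModel):
    kind: str

async def fake_events() -> AsyncIterator[Event]:
    for kind in ("connected", "warmup", "done"):
        yield Event(kind=kind)

@app.post("/run_test")
async def run_test() -> StreamingResponse:
    async def run() -> AsyncIterator[str]:
        async for item in fake_events():
            yield item.model_dump_json() + "\n"  # one JSON document per line
    return StreamingResponse(run(), media_type="application/x-ndjson")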

View File

@@ -6,19 +6,8 @@ query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
if [[ $# -lt 1 ]]; then
echo "USAGE: $0 [host1] [host2] ..."
exit 1
fi
@@ -34,23 +23,12 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
for i in "${!ips[@]}"; do
{
curl -sN \
-X POST "http://${ips[$i]}:8000/run_test" \
-H "Content-Type: application/json" -d "{\"devs\": ${devs}}" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait

uv.lock generated
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -412,10 +412,10 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", specifier = "==0.30.5" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
@@ -994,8 +994,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1022,12 +1022,18 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1040,14 +1046,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1074,20 +1072,26 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.5"
source = { registry = "https://pypi.org/simple" }
version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
@@ -2277,7 +2281,7 @@ wheels = [
[[package]]
name = "transformers"
version = "5.0.0rc3"
version = "5.0.0rc2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2292,9 +2296,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
{ url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" },
]
[[package]]