Compare commits

..

10 Commits

Author SHA1 Message Date
Evan
93dab5b960 yay 2026-01-27 10:12:03 +00:00
Alex Cheema
bd4f0bf048 Fix download speed/ETA display for re-downloads (#1294)
## Motivation

After the download verification fix, when files are re-downloaded due to
upstream changes (size mismatch), the download progress displays
correctly (completion %, bytes, file counts), but speed shows 0 B/s and
ETA shows "--" for both overall and per-file progress.

## Changes

- Modified `on_progress_wrapper` in `src/exo/download/download_utils.py`
to detect re-download scenarios
- Added re-download detection: when `curr_bytes < previous_downloaded`,
the file was deleted and download restarted
- On re-download: reset `start_time` to current time and set
`downloaded_this_session = curr_bytes`
- Added two tests to `test_download_verification.py` covering
re-download and continuing download scenarios

## Why It Works

The bug occurred because:
1. `file_progress` is initialized with the OLD local file size (e.g.,
1.5GB)
2. When `_download_file` detects size mismatch, it deletes the file and
starts fresh
3. Progress callback receives small `curr_bytes` (e.g., 8KB) but
compares against old size
4. `downloaded_this_session = 0 + (8KB - 1.5GB) = -1.5GB` (negative!)
5. Negative session bytes → 0 or negative speed → ETA shows "--"

The fix detects when `curr_bytes < previous_downloaded` (indicating
re-download started) and resets tracking to treat it as a fresh
download.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Download a model, modify a file to change its size, restart exo,
verify speed/ETA display correctly during re-download

### Automated Testing
- Added `TestProgressResetOnRedownload` class with two tests:
- `test_progress_resets_correctly_on_redownload`: Verifies progress
resets correctly when re-download starts
- `test_progress_accumulates_on_continuing_download`: Verifies
continuing downloads still accumulate correctly
- All 11 download tests pass
- Type checking (basedpyright): 0 errors
- Linting (ruff): All checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 21:56:58 +00:00
rltakashige
cd8c01b7c8 Fix kv prefix cache (#1262)
## Motivation

OpenCode sends very large prompts, most of which are repeated on the
next call.

## Changes

Add prefix caching, reducing average time in prefill (in testing) from
40 seconds to 4. This massively improves user experience.

Also evicts KV caches from this prefix cache in a LRU-style manner. 

## Why It Works

We no longer prefill repeatedly but rather use kv cache stored in
memory. A future update may want to use storage to make the prefix cache
larger.

## Test Plan

### Manual Testing
Tested speedup on OpenCode

### Automated Testing
Added a lot of tests

---------

Co-authored-by: David Hind <davehind@yahoo.co.uk>
2026-01-26 20:13:58 +00:00
rltakashige
59e991ce15 Only ignore message if actually empty (#1292)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-26 19:33:23 +00:00
ciaranbor
ffba340e70 Ciaran/image quantization (#1272)
## Motivation

Enable users to select and use quantized variants (8-bit, 4-bit) of
image models

## Changes

Use exolabs HF org for image models

## Why It Works

Quantized versions have been uploaded to exolabs HF org

## Test Plan

Loaded and ran different quantized variants. Confirmed lower memory
usage and different outputs for the same seed. Verified chat completion
still works.
2026-01-26 19:25:05 +00:00
rltakashige
9968abe816 Leo/fix basic model shard (#1291)
## Motivation

Some models, on some configurations, would have several issues that
caused the model to be stuck on loading.

## Changes

Several loading issues were with upstream mlx lm shard loading for
tensor parallel.
GLM 4.7 Flash now uses GLM 4.7 Lite.
A final portion of the issues were from mlx memory not being properly
released before calling mx.eval(model), causing the system to run out of
memory.

## Test Plan

### Manual Testing
Done a bunch (thanks @AlexCheema), hopefully exhaustive. 

### Automated Testing
A bunch of automated testing is imminent but not landed yet.

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
2026-01-26 17:49:09 +00:00
Alex Cheema
0e30b0830f Fix download system for upstream file changes (#1290)
## Motivation

When upstream files change on Hugging Face, exo's download system
doesn't detect the change and downloads get stuck. The only workaround
is deleting `~/.exo/models/` and the cache.

Root causes:
1. Existing files are never re-verified against remote metadata
2. File list cache is never invalidated, causing stale sizes to be used

## Changes

1. **Verify existing files against remote size** (`_download_file`):
Before returning early for existing files, verify the local file size
matches remote. If mismatched, delete and re-download. If network fails
(offline), fall back to trusting local file.

2. **Always try fresh file list first** (`fetch_file_list_with_cache`):
Always attempt to fetch fresh data from Hugging Face. On success, update
the cache. On failure, fall back to cached data if available.

3. **Clear cache on model delete** (`delete_model`): When a model is
deleted, also delete its cache entry to prevent stale metadata.

## Why It Works

- **Online**: Stale local files are detected via size mismatch and
re-downloaded. Fresh file list is always fetched and cache is updated.
- **Offline with cache**: Existing files are trusted. Cached file list
is used as fallback.
- **Offline without cache**: Fails gracefully (can't download without
knowing what files to get).

The size check is O(1) so there's no performance impact. Hash
verification still happens after download completes (existing behavior).

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Download a model, manually modify a local file's content, restart exo,
verify it re-downloads

### Automated Testing
Added 9 new tests in
`src/exo/download/tests/test_download_verification.py`:
- Re-download when file size changes upstream
- Skip download when file size matches
- Offline fallback uses local file
- Fetch fresh file list and update cache
- Fall back to cache when fetch fails
- Error propagates when no cache exists
- Model delete clears cache
- Delete when only cache exists
- Delete nonexistent model

All tests pass: `uv run pytest src/exo/download/tests/ -v`

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 09:14:58 -08:00
Alex Cheema
44453c4c8b Remove change-detection checks from info gatherer monitors (#1283)
## Summary
- When a node times out, its info gets cleared from state. The monitor
functions only sent data when something changed, leaving no mechanism to
re-populate this info after a timeout.
- Removes change-detection checks from `_monitor_misc`,
`_monitor_system_profiler_thunderbolt_data`, `_watch_system_info`, and
`_monitor_thunderbolt_bridge_status` so data is sent periodically
regardless of whether it changed.

## Test plan
- [ ] Verify type checker passes: `uv run basedpyright`
- [ ] Verify linter passes: `uv run ruff check`
- [ ] Verify tests pass: `uv run pytest`
- [ ] Manually test that node info is re-populated after a timeout by
observing cluster behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 12:23:22 +00:00
Jake Hillion
1290e8ed9f dashboard: fix prettier-svelte rebuilding on every file change
The prettier-svelte package was rebuilding whenever any file in the
repository changed because dashboardStubSrc referenced inputs.self
directly. Since inputs.self's store path hash is computed from the
entire repository contents, any file modification invalidated the
derivation.

Added dashboardLockfileSrc using lib.cleanSourceWith to filter
inputs.self to only include package.json and package-lock.json from
the dashboard directory. Updated dashboardStubSrc to reference this
filtered source instead of inputs.self directly.

This ensures prettier-svelte only rebuilds when the lockfiles actually
change, significantly improving build caching for unrelated changes.

Test plan:
- Built prettier-svelte with nix build .#prettier-svelte
- Modified src/exo/main.py and rebuilt - same store path (no rebuild)
- Modified dashboard/package.json and rebuilt - different store path (rebuild triggered)
- Ran nix flake check successfully
2026-01-26 12:02:05 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
30 changed files with 1875 additions and 534 deletions

View File

@@ -5,18 +5,18 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)

View File

@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -3,12 +3,28 @@
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to ONLY include package.json and package-lock.json
# This ensures prettier-svelte only rebuilds when lockfiles change
dashboardLockfileSrc = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
isDashboardDir = baseName == "dashboard" && type == "directory";
isPackageFile =
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
&& (baseName == "package.json" || baseName == "package-lock.json");
in
isDashboardDir || isPackageFile;
};
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src

View File

@@ -19,7 +19,7 @@ dependencies = [
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"mlx-lm==0.30.5",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",

View File

@@ -121,11 +121,20 @@ async def ensure_models_dir() -> Path:
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
models_dir = await ensure_models_dir()
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
deleted = False
if await aios.path.exists(model_dir):
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
deleted = True
# Also clear cache
if await aios.path.exists(cache_dir):
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
return deleted
async def seed_models(seed_dir: str | Path):
@@ -151,16 +160,28 @@ async def fetch_file_list_with_cache(
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# No cache available, propagate the error
raise
async def fetch_file_list_with_retry(
@@ -332,8 +353,28 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
if await aios.path.exists(target_dir / path):
return target_dir / path
target_path = target_dir / path
if await aios.path.exists(target_path):
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
try:
remote_size, _ = await file_meta(model_id, revision, path)
if local_size != remote_size:
logger.info(
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
)
await aios.remove(target_path)
else:
return target_path
except Exception as e:
# Offline or network error - trust local file
logger.debug(
f"Could not verify {path} against remote (offline?): {e}, using local file"
)
return target_path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -542,17 +583,26 @@ async def download_shard(
async def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
else time.time()
)
downloaded_this_session = (
file_progress[file.path].downloaded_this_session.in_bytes
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
if file.path in file_progress
else curr_bytes
previous_progress = file_progress.get(file.path)
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
speed = (
downloaded_this_session / (time.time() - start_time)
if time.time() - start_time > 0

View File

View File

@@ -0,0 +1,451 @@
"""Tests for download verification and cache behavior."""
import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter
from exo.download.download_utils import (
delete_model,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
"""Set up a temporary models directory for testing."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestFileVerification:
"""Tests for file size verification in _download_file."""
async def test_redownload_when_file_size_changes_upstream(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with mismatched sizes are re-downloaded."""
# Import inside test to allow patching
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file with wrong size
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content") # 13 bytes
remote_size = 1000 # Different from local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
# Set up mock HTTP response for re-download
mock_response = MagicMock()
mock_response.status = 200
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
side_effect=[b"x" * remote_size, b""]
)
mock_session = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_response
)
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_session
)
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
# Mock calc_hash to return the expected hash
with patch(
"exo.download.download_utils.calc_hash",
new_callable=AsyncMock,
return_value=remote_hash,
):
await _download_file(model_id, "main", "test.safetensors", target_dir)
# file_meta should be called twice: once for verification, once for download
assert mock_file_meta.call_count == 2
async def test_skip_download_when_file_size_matches(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with matching sizes are not re-downloaded."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
local_content = b"local content"
async with aiofiles.open(local_file, "wb") as f:
await f.write(local_content)
remote_size = len(local_content) # Same as local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return immediately without downloading
assert result == local_file
mock_file_meta.assert_called_once()
mock_session_factory.assert_not_called()
async def test_offline_fallback_uses_local_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that local files are used when network is unavailable."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content")
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return local file without attempting download
assert result == local_file
mock_session_factory.assert_not_called()
class TestFileListCache:
"""Tests for file list caching behavior."""
async def test_fetch_fresh_and_update_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that fresh data is fetched and cache is updated."""
models_dir = tmp_path / "models"
file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=100),
]
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
return_value=file_list,
) as mock_fetch,
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == file_list
mock_fetch.assert_called_once()
# Verify cache was written
cache_file = (
models_dir
/ "caches"
/ model_id.normalize()
/ f"{model_id.normalize()}--main--file_list.json"
)
assert await aios.path.exists(cache_file)
async with aiofiles.open(cache_file, "r") as f:
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
await f.read()
)
assert cached_data == file_list
async def test_fallback_to_cache_when_fetch_fails(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that cached data is used when fetch fails."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
# Create cache file
cached_file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
)
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == cached_file_list
async def test_error_propagates_when_no_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that errors propagate when fetch fails and no cache exists."""
models_dir = tmp_path / "models"
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
pytest.raises(Exception, match="Network error"),
):
await fetch_file_list_with_cache(model_id, "main")
class TestModelDeletion:
"""Tests for model deletion including cache cleanup."""
async def test_delete_model_clears_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that deleting a model also deletes its cache."""
models_dir = tmp_path / "models"
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
# Create model and cache directories
await aios.makedirs(model_dir, exist_ok=True)
await aios.makedirs(cache_dir, exist_ok=True)
# Add some files
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
await f.write("model data")
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is True
assert not await aios.path.exists(model_dir)
assert not await aios.path.exists(cache_dir)
async def test_delete_model_only_cache_exists(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting when only cache exists (model already deleted)."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
# Only create cache directory
await aios.makedirs(cache_dir, exist_ok=True)
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
# Returns False because model dir didn't exist
assert result is False
# But cache should still be cleaned up
assert not await aios.path.exists(cache_dir)
async def test_delete_nonexistent_model(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting a model that doesn't exist."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is False
class TestProgressResetOnRedownload:
"""Tests for progress tracking when files are re-downloaded."""
async def test_progress_resets_correctly_on_redownload(
self, model_id: ModelId
) -> None:
"""Test that progress tracking resets when a file is re-downloaded.
When a file is deleted and re-downloaded (due to size mismatch),
the progress tracking should reset rather than calculating negative
downloaded_this_session values.
"""
# Simulate file_progress dict as it exists in download_shard
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with old file progress (simulating existing large file)
old_file_size = 1_500_000_000 # 1.5 GB
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(old_file_size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(old_file_size),
speed=0,
eta=timedelta(0),
status="not_started",
start_time=time.time() - 10, # Started 10 seconds ago
)
# Simulate the logic from on_progress_wrapper after re-download starts
# This is the exact logic from the fixed on_progress_wrapper
curr_bytes = 100_000 # 100 KB - new download just started
previous_progress = file_progress.get("model.safetensors")
# Detect re-download: curr_bytes < previous downloaded
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is True, "Should detect re-download scenario"
assert downloaded_this_session == curr_bytes, (
"downloaded_this_session should equal curr_bytes on re-download"
)
assert downloaded_this_session > 0, (
"downloaded_this_session should be positive, not negative"
)
# Calculate speed (should be positive)
elapsed = time.time() - start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
assert speed >= 0, "Speed should be non-negative"
async def test_progress_accumulates_on_continuing_download(
self, model_id: ModelId
) -> None:
"""Test that progress accumulates correctly for continuing downloads.
When a download continues from where it left off (resume),
the progress should accumulate correctly.
"""
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with partial download progress
initial_downloaded = 500_000 # 500 KB already downloaded
start_time = time.time() - 5 # Started 5 seconds ago
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(initial_downloaded),
downloaded_this_session=Memory.from_bytes(initial_downloaded),
total=Memory.from_bytes(1_000_000),
speed=100_000,
eta=timedelta(seconds=5),
status="in_progress",
start_time=start_time,
)
# Progress callback with more bytes downloaded
curr_bytes = 600_000 # 600 KB - continuing download
previous_progress = file_progress.get("model.safetensors")
# This is NOT a re-download (curr_bytes > previous downloaded)
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
downloaded_this_session = curr_bytes
used_start_time = time.time()
else:
used_start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is False, (
"Should NOT detect re-download for continuing download"
)
assert used_start_time == start_time, "Should preserve original start_time"
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
assert downloaded_this_session == expected_session, (
f"Should accumulate: {downloaded_this_session} == {expected_session}"
)
assert downloaded_this_session == 600_000, (
"downloaded_this_session should equal total downloaded so far"
)

View File

@@ -88,7 +88,6 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -509,14 +508,16 @@ class API:
break
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
@@ -900,11 +901,6 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -986,11 +982,6 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -248,7 +246,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
@@ -260,7 +258,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -270,7 +268,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -280,18 +278,6 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -300,9 +286,10 @@ class Master:
]
)
)
self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -20,15 +20,9 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -186,7 +180,6 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -202,18 +195,6 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -428,7 +428,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +442,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -457,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
model_id=ModelId("exolabs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -470,7 +470,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +484,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -499,7 +499,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
@@ -541,9 +541,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -551,10 +551,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -575,9 +575,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -585,10 +585,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -610,6 +610,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _generate_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
# quantizations = [8, 6, 5, 4, 3]
quantizations = [8, 4]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
model_id = ModelId(base_card.model_id + f"-{quant}bit")
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)

View File

@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -88,7 +84,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -0,0 +1,12 @@
"""Shared types for MLX-related functionality."""
from collections.abc import Sequence
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]

View File

@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -61,10 +60,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -92,7 +87,6 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -349,13 +349,8 @@ class InfoGatherer:
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await self.info_sender.send(await MiscData.gather())
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
@@ -365,15 +360,12 @@ class InfoGatherer:
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
@@ -398,22 +390,17 @@ class InfoGatherer:
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = await get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
if curr is not None:
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)

View File

@@ -19,6 +19,8 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
@@ -145,6 +147,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -252,10 +258,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call
@@ -334,15 +336,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -351,7 +345,6 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -367,6 +360,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GLM4MoeLiteModel):
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
@@ -441,7 +442,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -516,6 +517,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
return model
@@ -533,6 +536,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GLM4MoeLiteModel, model)
for layer in model.layers: # type: ignore
layer = cast(Glm4MoeLiteDecoderLayer, layer)
eval_with_timeout(
layer.parameters(),
timeout_seconds / len(model.layers), # type: ignore
on_timeout,
)
if layer.self_attn.q_lora_rank is None: # type: ignore
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
if isinstance(layer.mlp, Glm4MoeLiteMLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # type: ignore
mx.eval(layer)
return model
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -566,7 +647,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
@@ -607,6 +688,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -661,7 +743,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model

View File

@@ -1,39 +1,81 @@
# type: ignore
# TODO: Fix this file, including types!
import os
from copy import deepcopy
from typing import Callable
from typing import Any, cast
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.85
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
def __init__(self, tokenizer: TokenizerWrapper):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._last_used.clear()
def add_kv_cache(self, prompt: str, cache: KVCacheType):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
def update_kv_cache(
self,
index: int,
prompt: str,
cache: KVCacheType,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts[index] = tokenized_prompt
self.caches[index] = deepcopy(cache)
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
Returns:
Tuple of (cache, remaining_tokens, matched_index) where:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
"""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +84,127 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
return self.caches[i]
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
return prompt_cache
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while len(self.caches) > 0:
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]

View File

@@ -4,7 +4,7 @@
KV_GROUP_SIZE: int | None = 32
KV_BITS: int | None = None
ATTENTION_KV_BITS: int | None = 4
MAX_TOKENS: int = 8192
MAX_TOKENS: int = 32168
MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"

View File

@@ -1,12 +1,12 @@
import time
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,34 +14,78 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
def warmup_inference(
@@ -89,6 +133,10 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -115,6 +163,7 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -126,7 +175,22 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
caches = make_kv_cache(model=model)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
prefix_hit_length = 0
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -139,11 +203,19 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -153,12 +225,13 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -180,4 +253,26 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
else:
kv_prefix_cache.add_kv_cache(full_prompt, caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -18,15 +18,12 @@ try:
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TRUST_REMOTE_CODE,
)
@@ -70,6 +67,8 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -87,6 +86,30 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -379,7 +402,11 @@ def apply_chat_template(
continue
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
if (
message.content is None
and message.thinking is None
and message.tool_calls is None
):
continue
# Null values are not valid when applying templates in tokenizer
@@ -436,31 +463,6 @@ class NullKVCache(KVCache):
raise NotImplementedError("We should not be setting a NullKVCache.")
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
def mlx_force_oom(size: int = 40000) -> None:
"""
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
@@ -510,23 +512,3 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,7 +33,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -116,9 +115,8 @@ class Worker:
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
async with create_task_group() as tg:
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -222,22 +220,15 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await runner.start_task(task)
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -360,6 +351,8 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -60,8 +59,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _cancel_tasks(runners, tasks)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -272,7 +270,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -286,7 +284,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len(input_chunk_buffer.get(cmd_id, {}))
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -294,31 +292,16 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id, cancelled_task_id=task.task_id
)

View File

@@ -3,10 +3,11 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.tests.patches import load_null_model
logger: "loguru.Logger" = loguru.logger
@@ -15,8 +16,9 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
*,
_load_null_models: bool = False,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -30,6 +32,13 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
p = None
if _load_null_models:
from unittest.mock import patch
p = patch("mlx_lm.utils.load_model", new=load_null_model)
p.start()
global logger
logger = _logger
@@ -39,7 +48,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
@@ -53,6 +62,8 @@ def entrypoint(
)
)
finally:
if p is not None:
p.stop()
try:
event_sender.close()
task_receiver.close()

View File

@@ -37,7 +37,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -71,6 +70,7 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -78,7 +78,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -87,7 +86,6 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -102,13 +100,11 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -117,7 +113,6 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -162,25 +157,28 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -191,11 +189,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert not isinstance(model, DistributedImageModel)
assert tokenizer
toks = warmup_inference(
model=inference_model,
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -207,8 +205,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -227,7 +225,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert inference_model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
@@ -239,10 +237,11 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# For other thinking models (GLM, etc.), check if we need to
@@ -262,11 +261,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -278,19 +277,7 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
cancel_every = 5
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -357,7 +344,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -371,9 +358,7 @@ def main(
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -420,7 +405,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -432,9 +417,7 @@ def main(
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -497,7 +480,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -49,12 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -65,8 +63,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -74,7 +72,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +86,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,41 +93,37 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with self._tg as tg:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
@@ -139,7 +131,6 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -149,13 +140,7 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:
@@ -221,4 +206,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()

View File

@@ -0,0 +1,50 @@
# type: ignore
import importlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from exo.worker.engines.mlx import Model
def load_null_model(path: Path, **_: object) -> "tuple[Model, dict[str, Any]]":
with open(path / "config.json", "r") as f:
cfg = json.load(f)
model, args = _get_classes(cfg)
model = model(args.from_dict(cfg))
return model, cfg
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
raise ValueError(msg) from None
return arch.Model, arch.ModelArgs
MODEL_REMAPPING = {
"mistral": "llama",
"llava": "mistral3",
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
"kimi_k2": "deepseek_v3",
"qwen2_5_vl": "qwen2_vl",
"minimax_m2": "minimax",
"iquestcoder": "llama",
}

View File

@@ -0,0 +1,545 @@
# type: ignore
import time
from typing import cast
from unittest.mock import patch
import mlx.core as mx
import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
encode_prompt,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
DEFAULT_GPT_OSS_MODEL_ID,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
class TestKVPrefix:
@pytest.fixture
def mock_tokenizer(self):
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
from unittest.mock import MagicMock
tokenizer = MagicMock()
tokenizer.encode.return_value = [1, 2, 3]
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache.clear()
assert len(cache.prompts) == 0
def _load_gpt_oss() -> tuple[Model, object]:
from mlx_lm.utils import load_model
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
model, _ = load_model(model_path, lazy=False)
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
return cast(Model, model), tokenizer
@pytest.mark.slow
@pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
)
class TestKVPrefixCacheWithModel:
@pytest.fixture(scope="class")
def model_and_tokenizer(self):
model, tokenizer = _load_gpt_oss()
return model, tokenizer
def test_prefill_populates_cache(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Test exact")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
model, tokenizer = model_and_tokenizer
short_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hi")],
max_tokens=1,
)
short_prompt = apply_chat_template(tokenizer, short_task)
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(short_prompt, cache)
# Query with longer prompt that shares the chat template prefix
long_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content="Hi there, how are you?")
],
max_tokens=1,
)
long_prompt = apply_chat_template(tokenizer, long_task)
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, long_prompt
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
):
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra_keys, extra_values)
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
):
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra = mx.random.normal((1, num_heads, 30, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# Consume the entire generator so the cache-saving code after yield runs
generated_tokens = 0
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
generated_tokens += 1
assert len(kv_prefix_cache.prompts) == 1
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
assert len(kv_prefix_cache.prompts) == 1
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
repeats = (1200 // len(base_tokens)) + 2
long_content = base_text * repeats
task1 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=long_content)],
max_tokens=5,
)
prompt1 = apply_chat_template(tokenizer, task1)
prompt1_tokens = encode_prompt(tokenizer, prompt1)
assert len(prompt1_tokens) > 1000, (
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
)
# First generation populates the cache (must prefill all tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content=long_content),
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
ChatCompletionMessage(role="user", content="Tell me more."),
],
max_tokens=5,
)
prompt2 = apply_chat_template(tokenizer, task2)
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
):
pass
second_gen_time = time.perf_counter() - t0
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
assert second_gen_time < first_gen_time * 0.5, (
f"Expected prefix cache speedup: "
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
)
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
for i, content in enumerate(prompts):
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=content)],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
assert len(kv_prefix_cache.prompts) == 3
# Access the third entry to make it most recently used
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
):
# Trigger eviction by adding a new entry
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="New entry")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since fake_active stays above threshold after each eviction (we don't change it),
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -1,7 +1,6 @@
import multiprocessing as mp
import socket
import time
import typing
import anyio
from fastapi import FastAPI
@@ -11,16 +10,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,18 +31,17 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import RunnerFailed, RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
MODEL_CARDS = {"haha": MODEL_CARDS["qwen3-coder-480b-a35b-8bit"]}
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
mp.set_start_method("spawn", force=True)
@@ -56,16 +50,14 @@ logger_setup(None)
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:8000"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/run_test")(run_test)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
@@ -87,28 +79,7 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):
async def run_test(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
@@ -117,10 +88,30 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
for card in MODEL_CARDS.values():
for instance in (
ring_instance(test, card.model_id, iid, hn),
jaccl_instance(test, card.model_id, iid),
):
recv = await execute_test(test, instance, hn)
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, RunnerFailed
):
return
except anyio.ClosedResourceError:
pass
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
def ring_instance(test: Tests, model_id: ModelId, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
@@ -135,13 +126,13 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,7 +154,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
@@ -171,60 +162,33 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
for card in MODEL_CARDS.values():
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=card.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
return recv
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests, model_id: ModelId, iid: InstanceId):
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
world_size = len(test.devs)
return MlxJacclInstance(
@@ -235,7 +199,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
@@ -270,6 +234,7 @@ def new_runner(
task_recv,
logger,
),
kwargs={"_load_null_models": True},
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()

View File

@@ -6,19 +6,8 @@ query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
if [[ $# -lt 1 ]]; then
echo "USAGE: $0 [host1] [host2] ..."
exit 1
fi
@@ -34,23 +23,12 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
for i in "${!ips[@]}"; do
{
curl -sN \
-X POST "http://${ips[$i]}:8000/run_test" \
-H "Content-Type: application/json" -d "{\"devs\": ${devs}}" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait

16
uv.lock generated
View File

@@ -415,7 +415,7 @@ requires-dist = [
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "mlx-lm", specifier = "==0.30.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
@@ -1072,8 +1072,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
version = "0.30.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1083,6 +1083,10 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
]
[[package]]
name = "mlx-metal"
@@ -2281,7 +2285,7 @@ wheels = [
[[package]]
name = "transformers"
version = "5.0.0rc2"
version = "5.0.0rc3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2296,9 +2300,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" }
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" },
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
]
[[package]]