Compare commits

...

6 Commits

Author SHA1 Message Date
Alex Cheema
bd4f0bf048 Fix download speed/ETA display for re-downloads (#1294)
## Motivation

After the download verification fix, when files are re-downloaded due to
upstream changes (size mismatch), the download progress displays
correctly (completion %, bytes, file counts), but speed shows 0 B/s and
ETA shows "--" for both overall and per-file progress.

## Changes

- Modified `on_progress_wrapper` in `src/exo/download/download_utils.py`
to detect re-download scenarios
- Added re-download detection: when `curr_bytes < previous_downloaded`,
the file was deleted and the download restarted
- On re-download: reset `start_time` to current time and set
`downloaded_this_session = curr_bytes`
- Added two tests to `test_download_verification.py` covering
re-download and continuing download scenarios

## Why It Works

The bug occurred because:
1. `file_progress` is initialized with the OLD local file size (e.g.,
1.5GB)
2. When `_download_file` detects size mismatch, it deletes the file and
starts fresh
3. Progress callback receives small `curr_bytes` (e.g., 8KB) but
compares against old size
4. `downloaded_this_session = 0 + (8KB - 1.5GB) = -1.5GB` (negative!)
5. Negative session bytes → 0 or negative speed → ETA shows "--"

The fix detects when `curr_bytes < previous_downloaded` (indicating
re-download started) and resets tracking to treat it as a fresh
download.
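
A minimal self-contained sketch of the detection rule described above (the field names follow the PR's `on_progress_wrapper`; the `FileProgress` dataclass here is a hypothetical stand-in for the real progress record):

```python
import time
from dataclasses import dataclass


@dataclass
class FileProgress:
    # Hypothetical stand-in for the real per-file progress record
    downloaded: int  # bytes reported by the most recent callback
    downloaded_this_session: int
    start_time: float


def update_session(prev: FileProgress | None, curr_bytes: int) -> FileProgress:
    # Re-download: fewer bytes than before means the file was deleted
    # and the download restarted from scratch
    is_redownload = prev is not None and curr_bytes < prev.downloaded
    if is_redownload or prev is None:
        # Fresh download or re-download: reset tracking
        return FileProgress(curr_bytes, curr_bytes, time.time())
    # Continuing download: accumulate the delta since the last callback
    return FileProgress(
        curr_bytes,
        prev.downloaded_this_session + (curr_bytes - prev.downloaded),
        prev.start_time,
    )
```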

## Test Plan

### Manual Testing
- Download a model, modify a file to change its size, restart exo,
verify speed/ETA display correctly during re-download

### Automated Testing
- Added `TestProgressResetOnRedownload` class with two tests:
- `test_progress_resets_correctly_on_redownload`: Verifies progress
resets correctly when re-download starts
- `test_progress_accumulates_on_continuing_download`: Verifies
continuing downloads still accumulate correctly
- All 11 download tests pass
- Type checking (basedpyright): 0 errors
- Linting (ruff): All checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 21:56:58 +00:00
rltakashige
cd8c01b7c8 Fix kv prefix cache (#1262)
## Motivation

OpenCode sends very large prompts, most of which are repeated on the
next call.

## Changes

Add prefix caching, reducing average time in prefill (in testing) from
40 seconds to 4. This massively improves user experience.

Also evicts entries from this prefix cache in an LRU-style manner.

## Why It Works

We no longer prefill repeated prompt prefixes; instead we reuse the KV cache
stored in memory. A future update may want to spill to storage to make the
prefix cache larger.
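
A rough sketch of the prefix-matching idea, assuming token IDs as plain Python lists (the implementation in this PR operates on `mx.array` tokens and trims a deep-copied KV cache to the matched length):

```python
def common_prefix_length(prompt: list[int], cached: list[int]) -> int:
    # Count matching tokens from the start; stop at the first mismatch
    n = min(len(prompt), len(cached))
    length = 0
    while length < n and prompt[length] == cached[length]:
        length += 1
    return length


# With the best cached entry found, only the unmatched suffix is prefilled:
prompt, cached = [1, 2, 3, 4, 5], [1, 2, 3, 9]
reused = common_prefix_length(prompt, cached)  # 3
remaining = prompt[reused:]  # [4, 5] - only these go through prefill
```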

## Test Plan

### Manual Testing
Tested speedup on OpenCode

### Automated Testing
Added a substantial suite of automated tests.

---------

Co-authored-by: David Hind <davehind@yahoo.co.uk>
2026-01-26 20:13:58 +00:00
rltakashige
59e991ce15 Only ignore message if actually empty (#1292)
2026-01-26 19:33:23 +00:00
ciaranbor
ffba340e70 Ciaran/image quantization (#1272)
## Motivation

Enable users to select and use quantized variants (8-bit, 4-bit) of
image models

## Changes

Use exolabs HF org for image models

## Why It Works

Quantized versions have been uploaded to the exolabs HF org.
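
For reference, the variant-size arithmetic used later in this compare (`_generate_image_model_quant_variants`) scales only the transformer weights down from their 16-bit size; a small illustration:

```python
def quantized_transformer_bytes(bf16_bytes: int, bits: int) -> int:
    # bf16 stores 16 bits per weight, so an N-bit variant is roughly
    # N/16 of the original size (ignoring scale/bias overhead)
    return (bf16_bytes * bits) // 16


flux_transformer = 23_782_357_120  # FLUX.1-schnell transformer, bytes
assert quantized_transformer_bytes(flux_transformer, 8) == flux_transformer // 2
assert quantized_transformer_bytes(flux_transformer, 4) == flux_transformer // 4
```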

## Test Plan

Loaded and ran different quantized variants. Confirmed lower memory
usage and different outputs for the same seed. Verified chat completion
still works.
2026-01-26 19:25:05 +00:00
rltakashige
9968abe816 Leo/fix basic model shard (#1291)
## Motivation

Some models, on some configurations, hit several issues that left them stuck
on loading.

## Changes

Several of the loading issues were in upstream mlx-lm shard loading for
tensor parallel.
GLM 4.7 Flash now uses GLM 4.7 Lite.
The remaining issues came from MLX memory not being released before calling
mx.eval(model), which could run the system out of memory.
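
A simplified sketch of the eval-as-you-shard pattern the fix adopts (the real change calls `mx.eval(layer)` inside each sharding strategy's per-layer loop; `shard_layer` here is a hypothetical placeholder):

```python
import mlx.core as mx


def shard_all_layers(layers, shard_layer) -> None:
    # Materialize each layer's sharded weights immediately so the old
    # buffers can be freed, instead of building one huge lazy graph and
    # only releasing memory at a final mx.eval(model)
    for layer in layers:
        shard_layer(layer)  # hypothetical: rewrite projections for this rank
        mx.eval(layer)  # force evaluation now, releasing the originals
```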

## Test Plan

### Manual Testing
Done a bunch (thanks @AlexCheema), hopefully exhaustive. 

### Automated Testing
A bunch of automated testing is imminent but not landed yet.

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
2026-01-26 17:49:09 +00:00
Alex Cheema
0e30b0830f Fix download system for upstream file changes (#1290)
## Motivation

When upstream files change on Hugging Face, exo's download system
doesn't detect the change and downloads get stuck. The only workaround
is deleting `~/.exo/models/` and the cache.

Root causes:
1. Existing files are never re-verified against remote metadata
2. File list cache is never invalidated, causing stale sizes to be used

## Changes

1. **Verify existing files against remote size** (`_download_file`):
Before returning early for existing files, verify the local file size
matches remote. If mismatched, delete and re-download. If network fails
(offline), fall back to trusting local file.

2. **Always try fresh file list first** (`fetch_file_list_with_cache`):
Always attempt to fetch fresh data from Hugging Face. On success, update
the cache. On failure, fall back to cached data if available.

3. **Clear cache on model delete** (`delete_model`): When a model is
deleted, also delete its cache entry to prevent stale metadata.

## Why It Works

- **Online**: Stale local files are detected via size mismatch and
re-downloaded. Fresh file list is always fetched and cache is updated.
- **Offline with cache**: Existing files are trusted. Cached file list
is used as fallback.
- **Offline without cache**: Fails gracefully (can't download without
knowing what files to get).

The size check is O(1) so there's no performance impact. Hash
verification still happens after download completes (existing behavior).
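
A condensed sketch of that flow for an existing file, with hypothetical `remote_file_size` and `download` helpers standing in for the PR's `file_meta` and download path:

```python
from pathlib import Path
from typing import Awaitable, Callable


async def ensure_file(
    path: Path,
    remote_file_size: Callable[[str], Awaitable[int]],  # hypothetical helper
    download: Callable[[str, Path], Awaitable[Path]],  # hypothetical helper
) -> Path:
    """Return a verified local file, re-downloading on size mismatch."""
    if path.exists():
        try:
            remote_size = await remote_file_size(path.name)  # one O(1) metadata call
        except Exception:
            return path  # offline or network error: trust the local file
        if path.stat().st_size == remote_size:
            return path  # sizes match: skip the download
        path.unlink()  # stale file: delete and fall through to re-download
    return await download(path.name, path)
```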

## Test Plan

### Manual Testing
- Download a model, manually modify a local file's content, restart exo,
verify it re-downloads

### Automated Testing
Added 9 new tests in
`src/exo/download/tests/test_download_verification.py`:
- Re-download when file size changes upstream
- Skip download when file size matches
- Offline fallback uses local file
- Fetch fresh file list and update cache
- Fall back to cache when fetch fails
- Error propagates when no cache exists
- Model delete clears cache
- Delete when only cache exists
- Delete nonexistent model

All tests pass: `uv run pytest src/exo/download/tests/ -v`

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 09:14:58 -08:00
14 changed files with 1580 additions and 176 deletions

View File

@@ -19,7 +19,7 @@ dependencies = [
     "anyio==4.11.0",
     "mlx==0.30.3; sys_platform == 'darwin'",
     "mlx[cpu]==0.30.3; sys_platform == 'linux'",
-    "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
+    "mlx-lm==0.30.5",
     "tiktoken>=0.12.0",  # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",

View File

@@ -121,11 +121,20 @@ async def ensure_models_dir() -> Path:
 async def delete_model(model_id: ModelId) -> bool:
-    model_dir = await ensure_models_dir() / model_id.normalize()
-    if not await aios.path.exists(model_dir):
-        return False
-    await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
-    return True
+    models_dir = await ensure_models_dir()
+    model_dir = models_dir / model_id.normalize()
+    cache_dir = models_dir / "caches" / model_id.normalize()
+
+    deleted = False
+    if await aios.path.exists(model_dir):
+        await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
+        deleted = True
+    # Also clear cache
+    if await aios.path.exists(cache_dir):
+        await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
+    return deleted


 async def seed_models(seed_dir: str | Path):
@@ -151,16 +160,28 @@ async def fetch_file_list_with_cache(
     target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
     await aios.makedirs(target_dir, exist_ok=True)
     cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
-    if await aios.path.exists(cache_file):
-        async with aiofiles.open(cache_file, "r") as f:
-            return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
-    file_list = await fetch_file_list_with_retry(
-        model_id, revision, recursive=recursive
-    )
-    await aios.makedirs(cache_file.parent, exist_ok=True)
-    async with aiofiles.open(cache_file, "w") as f:
-        await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
-    return file_list
+
+    # Always try fresh first
+    try:
+        file_list = await fetch_file_list_with_retry(
+            model_id, revision, recursive=recursive
+        )
+        # Update cache with fresh data
+        async with aiofiles.open(cache_file, "w") as f:
+            await f.write(
+                TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
+            )
+        return file_list
+    except Exception as e:
+        # Fetch failed - try cache fallback
+        if await aios.path.exists(cache_file):
+            logger.warning(
+                f"Failed to fetch file list for {model_id}, using cached data: {e}"
+            )
+            async with aiofiles.open(cache_file, "r") as f:
+                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
+        # No cache available, propagate the error
+        raise


 async def fetch_file_list_with_retry(
@@ -332,8 +353,28 @@ async def _download_file(
     target_dir: Path,
     on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
-    if await aios.path.exists(target_dir / path):
-        return target_dir / path
+    target_path = target_dir / path
+
+    if await aios.path.exists(target_path):
+        local_size = (await aios.stat(target_path)).st_size
+        # Try to verify against remote, but allow offline operation
+        try:
+            remote_size, _ = await file_meta(model_id, revision, path)
+            if local_size != remote_size:
+                logger.info(
+                    f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
+                )
+                await aios.remove(target_path)
+            else:
+                return target_path
+        except Exception as e:
+            # Offline or network error - trust local file
+            logger.debug(
+                f"Could not verify {path} against remote (offline?): {e}, using local file"
+            )
+            return target_path
+
     await aios.makedirs((target_dir / path).parent, exist_ok=True)
     length, etag = await file_meta(model_id, revision, path)
     remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -542,17 +583,26 @@ async def download_shard(
     async def on_progress_wrapper(
         file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
     ) -> None:
-        start_time = (
-            file_progress[file.path].start_time
-            if file.path in file_progress
-            else time.time()
-        )
-        downloaded_this_session = (
-            file_progress[file.path].downloaded_this_session.in_bytes
-            + (curr_bytes - file_progress[file.path].downloaded.in_bytes)
-            if file.path in file_progress
-            else curr_bytes
-        )
+        previous_progress = file_progress.get(file.path)
+
+        # Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
+        is_redownload = (
+            previous_progress is not None
+            and curr_bytes < previous_progress.downloaded.in_bytes
+        )
+
+        if is_redownload or previous_progress is None:
+            # Fresh download or re-download: reset tracking
+            start_time = time.time()
+            downloaded_this_session = curr_bytes
+        else:
+            # Continuing download: accumulate
+            start_time = previous_progress.start_time
+            downloaded_this_session = (
+                previous_progress.downloaded_this_session.in_bytes
+                + (curr_bytes - previous_progress.downloaded.in_bytes)
+            )
         speed = (
            downloaded_this_session / (time.time() - start_time)
            if time.time() - start_time > 0

View File

View File

@@ -0,0 +1,451 @@
"""Tests for download verification and cache behavior."""
import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter
from exo.download.download_utils import (
delete_model,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
"""Set up a temporary models directory for testing."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestFileVerification:
"""Tests for file size verification in _download_file."""
async def test_redownload_when_file_size_changes_upstream(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with mismatched sizes are re-downloaded."""
# Import inside test to allow patching
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file with wrong size
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content") # 13 bytes
remote_size = 1000 # Different from local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
# Set up mock HTTP response for re-download
mock_response = MagicMock()
mock_response.status = 200
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
side_effect=[b"x" * remote_size, b""]
)
mock_session = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_response
)
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_session
)
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
# Mock calc_hash to return the expected hash
with patch(
"exo.download.download_utils.calc_hash",
new_callable=AsyncMock,
return_value=remote_hash,
):
await _download_file(model_id, "main", "test.safetensors", target_dir)
# file_meta should be called twice: once for verification, once for download
assert mock_file_meta.call_count == 2
async def test_skip_download_when_file_size_matches(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with matching sizes are not re-downloaded."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
local_content = b"local content"
async with aiofiles.open(local_file, "wb") as f:
await f.write(local_content)
remote_size = len(local_content) # Same as local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return immediately without downloading
assert result == local_file
mock_file_meta.assert_called_once()
mock_session_factory.assert_not_called()
async def test_offline_fallback_uses_local_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that local files are used when network is unavailable."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content")
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return local file without attempting download
assert result == local_file
mock_session_factory.assert_not_called()
class TestFileListCache:
"""Tests for file list caching behavior."""
async def test_fetch_fresh_and_update_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that fresh data is fetched and cache is updated."""
models_dir = tmp_path / "models"
file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=100),
]
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
return_value=file_list,
) as mock_fetch,
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == file_list
mock_fetch.assert_called_once()
# Verify cache was written
cache_file = (
models_dir
/ "caches"
/ model_id.normalize()
/ f"{model_id.normalize()}--main--file_list.json"
)
assert await aios.path.exists(cache_file)
async with aiofiles.open(cache_file, "r") as f:
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
await f.read()
)
assert cached_data == file_list
async def test_fallback_to_cache_when_fetch_fails(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that cached data is used when fetch fails."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
# Create cache file
cached_file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
)
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == cached_file_list
async def test_error_propagates_when_no_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that errors propagate when fetch fails and no cache exists."""
models_dir = tmp_path / "models"
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
pytest.raises(Exception, match="Network error"),
):
await fetch_file_list_with_cache(model_id, "main")
class TestModelDeletion:
"""Tests for model deletion including cache cleanup."""
async def test_delete_model_clears_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that deleting a model also deletes its cache."""
models_dir = tmp_path / "models"
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
# Create model and cache directories
await aios.makedirs(model_dir, exist_ok=True)
await aios.makedirs(cache_dir, exist_ok=True)
# Add some files
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
await f.write("model data")
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is True
assert not await aios.path.exists(model_dir)
assert not await aios.path.exists(cache_dir)
async def test_delete_model_only_cache_exists(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting when only cache exists (model already deleted)."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
# Only create cache directory
await aios.makedirs(cache_dir, exist_ok=True)
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
# Returns False because model dir didn't exist
assert result is False
# But cache should still be cleaned up
assert not await aios.path.exists(cache_dir)
async def test_delete_nonexistent_model(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting a model that doesn't exist."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is False
class TestProgressResetOnRedownload:
"""Tests for progress tracking when files are re-downloaded."""
async def test_progress_resets_correctly_on_redownload(
self, model_id: ModelId
) -> None:
"""Test that progress tracking resets when a file is re-downloaded.
When a file is deleted and re-downloaded (due to size mismatch),
the progress tracking should reset rather than calculating negative
downloaded_this_session values.
"""
# Simulate file_progress dict as it exists in download_shard
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with old file progress (simulating existing large file)
old_file_size = 1_500_000_000 # 1.5 GB
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(old_file_size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(old_file_size),
speed=0,
eta=timedelta(0),
status="not_started",
start_time=time.time() - 10, # Started 10 seconds ago
)
# Simulate the logic from on_progress_wrapper after re-download starts
# This is the exact logic from the fixed on_progress_wrapper
curr_bytes = 100_000 # 100 KB - new download just started
previous_progress = file_progress.get("model.safetensors")
# Detect re-download: curr_bytes < previous downloaded
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is True, "Should detect re-download scenario"
assert downloaded_this_session == curr_bytes, (
"downloaded_this_session should equal curr_bytes on re-download"
)
assert downloaded_this_session > 0, (
"downloaded_this_session should be positive, not negative"
)
# Calculate speed (should be positive)
elapsed = time.time() - start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
assert speed >= 0, "Speed should be non-negative"
async def test_progress_accumulates_on_continuing_download(
self, model_id: ModelId
) -> None:
"""Test that progress accumulates correctly for continuing downloads.
When a download continues from where it left off (resume),
the progress should accumulate correctly.
"""
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with partial download progress
initial_downloaded = 500_000 # 500 KB already downloaded
start_time = time.time() - 5 # Started 5 seconds ago
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(initial_downloaded),
downloaded_this_session=Memory.from_bytes(initial_downloaded),
total=Memory.from_bytes(1_000_000),
speed=100_000,
eta=timedelta(seconds=5),
status="in_progress",
start_time=start_time,
)
# Progress callback with more bytes downloaded
curr_bytes = 600_000 # 600 KB - continuing download
previous_progress = file_progress.get("model.safetensors")
# This is NOT a re-download (curr_bytes > previous downloaded)
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
downloaded_this_session = curr_bytes
used_start_time = time.time()
else:
used_start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is False, (
"Should NOT detect re-download for continuing download"
)
assert used_start_time == start_time, "Should preserve original start_time"
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
assert downloaded_this_session == expected_session, (
f"Should accumulate: {downloaded_this_session} == {expected_session}"
)
assert downloaded_this_session == 600_000, (
"downloaded_this_session should equal total downloaded so far"
)

View File

@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }

-_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
+_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
     "flux1-schnell": ModelCard(
-        model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
+        model_id=ModelId("exolabs/FLUX.1-schnell"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
         n_layers=57,
         hidden_size=1,
@@ -428,7 +428,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="text_encoder_2",
@@ -442,7 +442,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23782357120),
-                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+                n_layers=57,
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),
@@ -457,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "flux1-dev": ModelCard(
-        model_id=ModelId("black-forest-labs/FLUX.1-dev"),
+        model_id=ModelId("exolabs/FLUX.1-dev"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
         n_layers=57,
         hidden_size=1,
@@ -470,7 +470,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="text_encoder_2",
@@ -484,7 +484,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23802816640),
-                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+                n_layers=57,
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),
@@ -499,7 +499,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "flux1-krea-dev": ModelCard(
-        model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
+        model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
         storage_size=Memory.from_bytes(23802816640 + 9524621312),  # Same as dev
         n_layers=57,
         hidden_size=1,
@@ -541,9 +541,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
     ),
     "qwen-image": ModelCard(
-        model_id=ModelId("Qwen/Qwen-Image"),
+        model_id=ModelId("exolabs/Qwen-Image"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.TextToImage],
@@ -551,10 +551,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_kb(16584333312),
+                storage_size=Memory.from_bytes(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="transformer",
@@ -575,9 +575,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ],
    ),
    "qwen-image-edit-2509": ModelCard(
-        model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
+        model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.ImageToImage],
@@ -585,10 +585,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_kb(16584333312),
+                storage_size=Memory.from_bytes(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="transformer",
@@ -610,6 +610,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }

+
+def _generate_image_model_quant_variants(
+    base_name: str,
+    base_card: ModelCard,
+) -> dict[str, ModelCard]:
+    """Create quantized variants of an image model card.
+
+    Only the transformer component is quantized; text encoders stay at bf16.
+    Sizes are calculated exactly from the base card's component sizes.
+    """
+    if base_card.components is None:
+        raise ValueError(f"Image model {base_name} must have components defined")
+
+    # quantizations = [8, 6, 5, 4, 3]
+    quantizations = [8, 4]
+
+    num_transformer_bytes = next(
+        c.storage_size.in_bytes
+        for c in base_card.components
+        if c.component_name == "transformer"
+    )
+    transformer_bytes = Memory.from_bytes(num_transformer_bytes)
+    remaining_bytes = Memory.from_bytes(
+        sum(
+            c.storage_size.in_bytes
+            for c in base_card.components
+            if c.component_name != "transformer"
+        )
+    )
+
+    def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
+        assert base_card.components is not None
+        return [
+            ComponentInfo(
+                component_name=c.component_name,
+                component_path=c.component_path,
+                storage_size=new_size
+                if c.component_name == "transformer"
+                else c.storage_size,
+                n_layers=c.n_layers,
+                can_shard=c.can_shard,
+                safetensors_index_filename=c.safetensors_index_filename,
+            )
+            for c in base_card.components
+        ]
+
+    variants = {
+        base_name: ModelCard(
+            model_id=base_card.model_id,
+            storage_size=transformer_bytes + remaining_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(transformer_bytes),
+        )
+    }
+
+    for quant in quantizations:
+        quant_transformer_bytes = Memory.from_bytes(
+            (num_transformer_bytes * quant) // 16
+        )
+        total_bytes = remaining_bytes + quant_transformer_bytes
+        model_id = ModelId(base_card.model_id + f"-{quant}bit")
+        variants[f"{base_name}-{quant}bit"] = ModelCard(
+            model_id=model_id,
+            storage_size=total_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(quant_transformer_bytes),
+        )
+
+    return variants
+
+
+_image_model_cards: dict[str, ModelCard] = {}
+for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
+    _image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
+_IMAGE_MODEL_CARDS = _image_model_cards
+
 if EXO_ENABLE_IMAGE_MODELS:
     MODEL_CARDS.update(_IMAGE_MODEL_CARDS)

View File

@@ -0,0 +1,12 @@
"""Shared types for MLX-related functionality."""
from collections.abc import Sequence
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]

View File

@@ -19,6 +19,8 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
 from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
 from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
 from mlx_lm.models.glm4_moe import MoE
+from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
+from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.llama import Model as LlamaModel
@@ -145,6 +147,10 @@ class PipelineLastLayer(CustomMlxLayer):
         if cache is not None:
             cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

+        output = mx.distributed.all_gather(output, group=self.group)[
+            -output.shape[0] :
+        ]  # type :ignore
+
         return output
@@ -252,10 +258,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
             if cache is not None:
                 cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore

-            logits = mx.distributed.all_gather(logits, group=group)[
-                -logits.shape[0] :
-            ]  # type :ignore
-
             return logits

         cls.__call__ = patched_call
@@ -334,15 +336,7 @@ def tensor_auto_parallel(
         group=group,
     )

-    if hasattr(model, "shard") and not isinstance(model, GptOssModel):
-        try:
-            model.shard(group)  # type: ignore
-            return patch_tensor_model(model)
-        except (AttributeError, TypeError, NameError):
-            pass
-
     if isinstance(model, (LlamaModel, Ministral3Model)):
-        logger.warning("shouldn't be hit - upstream sharding exists")
         tensor_parallel_sharding_strategy = LlamaShardingStrategy(
             group,
             all_to_sharded_linear,
@@ -351,7 +345,6 @@ def tensor_auto_parallel(
             sharded_to_all_linear_in_place,
         )
     elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
-        logger.warning("shouldn't be hit - upstream sharding exists")
         tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
             group,
             all_to_sharded_linear,
@@ -367,6 +360,14 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
+    elif isinstance(model, GLM4MoeLiteModel):
+        tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
+            group,
+            all_to_sharded_linear,
+            sharded_to_all_linear,
+            all_to_sharded_linear_in_place,
+            sharded_to_all_linear_in_place,
+        )
     elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
         tensor_parallel_sharding_strategy = QwenShardingStrategy(
             group,
@@ -441,7 +442,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
             layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
             layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
-
+            mx.eval(layer)
         return model
@@ -516,6 +517,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp = ShardedDeepseekV3MoE(layer.mlp)  # type: ignore
             layer.mlp.sharding_group = self.group
+            mx.eval(layer)
+
         return model
@@ -533,6 +536,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
         return y


+class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
+        model = cast(GLM4MoeLiteModel, model)
+        for layer in model.layers:  # type: ignore
+            layer = cast(Glm4MoeLiteDecoderLayer, layer)
+            eval_with_timeout(
+                layer.parameters(),
+                timeout_seconds / len(model.layers),  # type: ignore
+                on_timeout,
+            )
+            if layer.self_attn.q_lora_rank is None:  # type: ignore
+                layer.self_attn.q_proj = self.all_to_sharded_linear(
+                    layer.self_attn.q_proj
+                )
+            else:
+                layer.self_attn.q_b_proj = self.all_to_sharded_linear(
+                    layer.self_attn.q_b_proj
+                )
+            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
+            layer.self_attn.num_heads //= self.N
+
+            # Logic from upstream mlx
+            num_heads = layer.self_attn.num_heads
+            sh = self.group.rank() * num_heads
+            eh = sh + num_heads
+
+            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)
+
+            if isinstance(layer.mlp, Glm4MoeLiteMLP):
+                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
+                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
+                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
+            else:
+                if getattr(layer.mlp, "shared_experts", None) is not None:
+                    self.all_to_sharded_linear_in_place(
+                        layer.mlp.shared_experts.gate_proj
+                    )
+                    self.sharded_to_all_linear_in_place(
+                        layer.mlp.shared_experts.down_proj
+                    )
+                    self.all_to_sharded_linear_in_place(
+                        layer.mlp.shared_experts.up_proj
+                    )
+                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
+                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
+                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
+                layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp)  # type: ignore
+                layer.mlp.sharding_group = self.group  # type: ignore
+            mx.eval(layer)
+        return model
+
+
+class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
+    def __init__(self, layer: _LayerCallable):
+        super().__init__(layer)
+        self.sharding_group: mx.distributed.Group | None = None
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if self.sharding_group is not None:
+            x = sum_gradients(self.sharding_group)(x)
+        y = self.original_layer.__call__(x)
+        if self.sharding_group is not None:
+            y = mx.distributed.all_sum(y, group=self.sharding_group)
+        return y
+
+
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
     def shard_model(
         self,
@@ -566,7 +647,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
             )
             layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
             layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
-
+            mx.eval(layer)
         return model
@@ -607,6 +688,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
             layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
+            mx.eval(layer)
         return model
@@ -661,7 +743,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
             layer.mlp = ShardedGptOssMoE(layer.mlp)  # type: ignore
             layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
-
+            mx.eval(layer)
         return model

View File

@@ -1,39 +1,81 @@
-# type: ignore
-# TODO: Fix this file, including types!
+import os
 from copy import deepcopy
-from typing import Callable
+from typing import Any, cast

 import mlx.core as mx
-from mlx_lm import stream_generate
-from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
+from mlx_lm.models.cache import (
+    KVCache,
+    QuantizedKVCache,
+    RotatingKVCache,
+    trim_prompt_cache,
+)
+from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper

+from exo.shared.types.mlx import KVCacheType
 from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
-from exo.worker.engines.mlx.utils_mlx import make_kv_cache
+from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
+from exo.worker.runner.bootstrap import logger
+
+# Fraction of device memory above which LRU eviction kicks in
+_DEFAULT_MEMORY_THRESHOLD = 0.85
+_MEMORY_THRESHOLD = float(
+    os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
+)


 class KVPrefixCache:
-    def __init__(self):
+    def __init__(self, tokenizer: TokenizerWrapper):
+        # Only one prefix cache per runner.
         self.prompts: list[mx.array] = []  # mx array of tokens (ints)
-        self.caches: list[list[_BaseCache]] = []
+        self.caches: list[KVCacheType] = []
+        self._last_used: list[int] = []  # monotonic counter of last access per entry
+        self._access_counter: int = 0
+        self._tokenizer: TokenizerWrapper = tokenizer

-    def add_kv_cache(
-        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
-    ):
-        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
+    def clear(self):
+        """Clear all cached prompts and caches."""
+        self.prompts.clear()
+        self.caches.clear()
+        self._last_used.clear()
+
+    def add_kv_cache(self, prompt: str, cache: KVCacheType):
+        """Add a new cache entry. Evicts LRU entries if memory is high."""
+        self._evict_if_needed()
+        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
         self.prompts.append(tokenized_prompt)
         self.caches.append(deepcopy(cache))
+        self._access_counter += 1
+        self._last_used.append(self._access_counter)
+        logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
+
+    def update_kv_cache(
+        self,
+        index: int,
+        prompt: str,
+        cache: KVCacheType,
+    ):
+        """Update an existing cache entry in-place."""
+        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
+        self.prompts[index] = tokenized_prompt
+        self.caches[index] = deepcopy(cache)
+        self._access_counter += 1
+        self._last_used[index] = self._access_counter
+        logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")

     def get_kv_cache(
         self,
         model: Model,
-        tokenizer: TokenizerWrapper,
-        sampler: Callable[[mx.array], mx.array],
         prompt: str,
-    ) -> list[_BaseCache]:
-        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
+    ) -> tuple[KVCacheType, mx.array, int | None]:
+        """Get KV cache for prompt, returning remaining tokens to prefill.
+
+        Returns:
+            Tuple of (cache, remaining_tokens, matched_index) where:
+            - cache: KV cache to use for generation
+            - remaining_tokens: tokens that still need prefilling
+            - matched_index: index of the matched entry (None if no match)
+        """
+        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
         max_length = len(tokenized_prompt)

         best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +84,127 @@
             length = _get_prefix_length(tokenized_prompt, cached_prompt)

             if length == max_length:
-                return self.caches[i]
+                # Exact match - cached prompt starts with our entire prompt
+                # Trim cache to prompt length - 1, return last token for stream_generate
+                prompt_cache = deepcopy(self.caches[i])
+                cached_length = _cache_length(self.caches[i])
+                tokens_to_trim = cached_length - (max_length - 1)
+                if tokens_to_trim > 0:
+                    trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
+                self._access_counter += 1
+                self._last_used[i] = self._access_counter
+                logger.info(f"KV cache exact match: {max_length} tokens (instant)")
+                return prompt_cache, tokenized_prompt[-1:], i

             if length > best_snapshot_length:
                 best_snapshot_index, best_snapshot_length = i, length

         if best_snapshot_index is not None:
-            prompt_cache = deepcopy(self.caches[best_snapshot_index])
-            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
-            tokenized_prompt = tokenized_prompt[best_snapshot_index:]
-        else:
-            prompt_cache = make_kv_cache(
-                model,
-                # max_kv_size=MAX_KV_SIZE,
-                # keep=KEEP_KV_SIZE
-            )
-        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
-        return prompt_cache
+            new_tokens = max_length - best_snapshot_length
+            logger.info(
+                f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
+                f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
+            )
+            prompt_cache = deepcopy(self.caches[best_snapshot_index])
+            # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
+            cached_length = _cache_length(self.caches[best_snapshot_index])
+            tokens_to_trim = cached_length - best_snapshot_length
+            if tokens_to_trim > 0:
+                trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
+            self._access_counter += 1
+            self._last_used[best_snapshot_index] = self._access_counter
+            remaining_tokens = tokenized_prompt[best_snapshot_length:]
+            return prompt_cache, remaining_tokens, best_snapshot_index
+        else:
+            prompt_cache = make_kv_cache(model)
+            if len(self.prompts) == 0:
+                logger.info(f"KV cache empty, need to prefill {max_length} tokens")
+            else:
+                logger.info(
+                    f"KV cache no prefix match, need to prefill {max_length} tokens"
+                )
+            return prompt_cache, tokenized_prompt, None

-    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
-        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
-            tokenizer.bos_token
-        )
-        tokenized_prompt = tokenizer.encode(
-            prompt, add_special_tokens=add_special_tokens
-        )
-        return mx.array(tokenized_prompt)
+    def _evict_if_needed(self):
+        """Evict least recently used entries while memory pressure is high."""
+        if len(self.caches) == 0:
+            return
+        active: int = mx.metal.get_active_memory()
+        limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
+        if active < limit * _MEMORY_THRESHOLD:
+            return
+        # Evict LRU entries until below threshold or only one entry left
+        while len(self.caches) > 0:
+            lru_index = self._last_used.index(min(self._last_used))
+            evicted_tokens = len(self.prompts[lru_index])
+            self.prompts.pop(lru_index)
+            self.caches.pop(lru_index)
+            self._last_used.pop(lru_index)
+            logger.info(
+                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
+            )
+            active = mx.metal.get_active_memory()
+            if active < limit * _MEMORY_THRESHOLD:
+                break
+
+
+def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
+    """Encode a prompt string to token array.
+
+    For chat-templated prompts (which have their own structure markers like
+    <|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
+    that would corrupt the prompt structure.
+    """
+    # Chat templates define their own structure - don't add BOS/EOS
+    tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    return mx.array(tokenized_prompt)
+
+
+def _cache_length(cache: KVCacheType) -> int:
+    """Get the number of tokens in a KV cache."""
+    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
+    return max(c.offset for c in cache)  # type: ignore


 def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
-    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
+    """Find the length of the common prefix between two token arrays."""
+    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
     if n == 0:
         return 0
-    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
+    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
     prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
     return int(mx.sum(prefix_mask).item())


-def prefill(
-    model: Model,
-    tokenizer: TokenizerWrapper,
-    sampler: Callable[[mx.array], mx.array],
-    prompt: mx.array,
-    cache: list[_BaseCache],
-) -> None:
-    for _ in stream_generate(
-        model=model,
-        tokenizer=tokenizer,
-        prompt=prompt,
-        max_tokens=0,
-        sampler=sampler,
-        prompt_cache=cache,
-        prefill_step_size=2048,
-        kv_group_size=KV_GROUP_SIZE,
-        kv_bits=KV_BITS,
-    ):
-        pass
+def make_kv_cache(
+    model: Model, max_kv_size: int | None = None, keep: int = 0
+) -> KVCacheType:
+    assert hasattr(model, "layers")
+
+    # TODO: Do this for all models
+    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
+        logger.info("Using MLX LM's make cache")
+        return model.make_cache()  # type: ignore
+
+    if max_kv_size is None:
+        if KV_CACHE_BITS is None:
+            logger.info("Using default KV cache")
+            return [KVCache() for _ in model.layers]
+        else:
+            logger.info("Using quantized KV cache")
+            return [
+                QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
+                for _ in model.layers
+            ]
+    else:
+        logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
+        return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]

View File

@@ -4,7 +4,7 @@
 KV_GROUP_SIZE: int | None = 32
 KV_BITS: int | None = None
 ATTENTION_KV_BITS: int | None = 4
-MAX_TOKENS: int = 8192
+MAX_TOKENS: int = 32168
 MAX_KV_SIZE: int | None = 3200
 KEEP_KV_SIZE: int | None = 1600
 QUANTIZE_MODEL_MODE: str | None = "affine"

View File

@@ -1,12 +1,12 @@
import time
from typing import Any, Callable, Generator, cast, get_args from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx import mlx.core as mx
from mlx_lm.generate import stream_generate from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import ( from exo.shared.types.api import (
BenchChatCompletionTaskParams, BenchChatCompletionTaskParams,
ChatCompletionMessage, ChatCompletionMessage,
@@ -14,35 +14,78 @@ from exo.shared.types.api import (
GenerationStats, GenerationStats,
) )
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import ( from exo.shared.types.worker.runner_response import (
GenerationResponse, GenerationResponse,
) )
from exo.worker.engines.mlx import Model from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import ( from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template, apply_chat_template,
make_kv_cache,
mx_barrier, mx_barrier,
) )
from exo.worker.runner.bootstrap import logger from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device()) generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
-def maybe_quantize_kv_cache(
-    prompt_cache: list[KVCache | Any],
-    quantized_kv_start: int,
-    kv_group_size: int,
-    kv_bits: int | None,
-) -> None:
-    if kv_bits is None:
-        return
-    for e, c in enumerate(prompt_cache):
-        if (
-            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
-        ):
-            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
+def prefill(
+    model: Model,
+    tokenizer: TokenizerWrapper,
+    sampler: Callable[[mx.array], mx.array],
+    prompt_tokens: mx.array,
+    cache: KVCacheType,
+) -> float:
+    """Prefill the KV cache with prompt tokens.
+
+    This runs the model over the prompt tokens to populate the cache,
+    then trims off the extra generated token.
+
+    Returns:
+        tokens_per_sec
+    """
+    num_tokens = len(prompt_tokens)
+    if num_tokens == 0:
+        return 0.0
+    logger.debug(f"Prefilling {num_tokens} tokens...")
+    start_time = time.perf_counter()
+
+    def progress_callback(processed: int, total: int) -> None:
+        elapsed = time.perf_counter() - start_time
+        tok_per_sec = processed / elapsed if elapsed > 0 else 0
+        logger.debug(
+            f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
+        )
+
+    # Use max_tokens=1 because max_tokens=0 does not work.
+    # We just throw away the generated token - we only care about filling the cache
+    for _ in stream_generate(
+        model=model,
+        tokenizer=tokenizer,
+        prompt=prompt_tokens,
+        max_tokens=1,
+        sampler=sampler,
+        prompt_cache=cache,
+        prefill_step_size=2048,
+        kv_group_size=KV_GROUP_SIZE,
+        kv_bits=KV_BITS,
+        prompt_progress_callback=progress_callback,
+    ):
+        break  # Stop after first iteration - cache is now filled
+    trim_prompt_cache(cast(list[Any], cache), 1)
+    elapsed = time.perf_counter() - start_time
+    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
+    logger.debug(
+        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
+        f"({tokens_per_sec:.1f} tok/s)"
+    )
+    return tokens_per_sec

def warmup_inference(
@@ -120,6 +163,7 @@ def mlx_generate(
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    prompt: str,
+    kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -131,7 +175,22 @@ def mlx_generate(
    if task.seed is not None:
        mx.random.seed(task.seed)
-    caches = make_kv_cache(model=model)
+    # Do not use the prefix cache if we are trying to do benchmarks.
+    if is_bench:
+        kv_prefix_cache = None
+
+    # Use prefix cache if available, otherwise create fresh cache
+    prefix_hit_length = 0
+    matched_index: int | None = None
+    if kv_prefix_cache is None:
+        caches = make_kv_cache(model=model)
+        prompt_tokens = encode_prompt(tokenizer, prompt)
+    else:
+        caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
+            model, prompt
+        )
+        all_prompt_tokens = encode_prompt(tokenizer, prompt)
+        prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
@@ -144,11 +203,19 @@ def mlx_generate(
        top_p=task.top_p if task.top_p is not None else 1.0,
    )

+    # Prefill cache with all tokens except the last one
+    prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
+    # stream_generate starts from the last token
+    last_token = prompt_tokens[-1:]
+
    max_tokens = task.max_tokens or MAX_TOKENS
+    generated_text_parts: list[str] = []
+    generation_start_time = time.perf_counter()
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
-        prompt=prompt,
+        prompt=last_token,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
@@ -158,12 +225,13 @@ def mlx_generate(
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
+        generated_text_parts.append(out.text)
        logger.info(out.text)
        stats: GenerationStats | None = None
        if out.finish_reason is not None:
            stats = GenerationStats(
-                prompt_tps=float(out.prompt_tps),
+                prompt_tps=float(prefill_tps or out.prompt_tps),
                generation_tps=float(out.generation_tps),
                prompt_tokens=int(out.prompt_tokens),
                generation_tokens=int(out.generation_tokens),
@@ -185,6 +253,26 @@ def mlx_generate(
        )

        if out.finish_reason is not None:
+            # Log generation stats
+            generation_elapsed = time.perf_counter() - generation_start_time
+            generated_tokens = len(generated_text_parts)
+            generation_tps = (
+                generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
+            )
+            logger.debug(
+                f"Generation complete: prefill {len(prompt_tokens)} tokens @ "
+                f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
+                f"{generation_tps:.1f} tok/s"
+            )
+            if kv_prefix_cache is not None:
+                full_prompt = prompt + "".join(generated_text_parts)
+                if (
+                    matched_index is not None
+                    and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
+                ):
+                    kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
+                else:
+                    kv_prefix_cache.add_kv_cache(full_prompt, caches)
            break
    # TODO: Do we want an mx_barrier?
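The prefill here leans on a quirk noted in the code comment: mlx-lm's `stream_generate` cannot run with `max_tokens=0`, so the prompt is pushed through the model by generating exactly one throwaway token, which `trim_prompt_cache` then removes, and decoding resumes from the last prompt token against the warm cache. A minimal standalone sketch of the same pattern using public mlx-lm APIs (`load`, `stream_generate`, `make_prompt_cache`, `trim_prompt_cache` are real; the model name is only an example):

```python
# Sketch of the prefill-then-decode split used in the diff above.
import mlx.core as mx
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # example model
tokens = mx.array(tokenizer.encode("Explain KV caching in one sentence."))

cache = make_prompt_cache(model)
# Prefill: run one step over all tokens but the last, then drop the
# single token that stream_generate was forced to produce.
for _ in stream_generate(
    model, tokenizer, prompt=tokens[:-1], max_tokens=1, prompt_cache=cache
):
    break
trim_prompt_cache(cache, 1)

# Decode: resume from the last prompt token with the warm cache.
for out in stream_generate(
    model, tokenizer, prompt=tokens[-1:], max_tokens=64, prompt_cache=cache
):
    print(out.text, end="", flush=True)
```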

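After generation, the diff either grows the matched cache entry in place (when the prefix hit was at least `_MIN_PREFIX_HIT_TO_UPDATE` = 1000 tokens) or appends a new entry, keeping short hits from clobbering a long stored prefix. A toy sketch of just that bookkeeping, with plain strings standing in for the real token arrays and KV caches (names other than the threshold are hypothetical):

```python
# Sketch of the update-vs-add policy; only the threshold logic mirrors the diff.
MIN_PREFIX_HIT_TO_UPDATE = 1000

def finish_generation(
    cache_entries: list[str],      # stored full prompts (stand-in for token arrays)
    matched_index: int | None,     # entry the request matched, if any
    prefix_hit_length: int,        # tokens reused from that entry
    full_prompt: str,              # prompt + generated text of this request
) -> None:
    if matched_index is not None and prefix_hit_length >= MIN_PREFIX_HIT_TO_UPDATE:
        # Long shared prefix: grow the matched entry in place so the next
        # turn of the same conversation hits an even longer prefix.
        cache_entries[matched_index] = full_prompt
    else:
        # Short or no hit: keep the old entry and store a new one.
        cache_entries.append(full_prompt)
```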

@@ -18,15 +18,12 @@ try:
except ImportError:
    pass  # transformers < 5.0 or bytes_to_unicode not available

-from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
+from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
-from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
-    CACHE_GROUP_SIZE,
-    KV_CACHE_BITS,
    TRUST_REMOTE_CODE,
)
@@ -405,7 +402,11 @@ def apply_chat_template(
            continue
        message.content = "\n".join(c.text for c in message.content).strip()

-        if message.content is None and message.thinking is None:
+        if (
+            message.content is None
+            and message.thinking is None
+            and message.tool_calls is None
+        ):
            continue

        # Null values are not valid when applying templates in tokenizer
@@ -462,31 +463,6 @@ class NullKVCache(KVCache):
        raise NotImplementedError("We should not be setting a NullKVCache.")

-def make_kv_cache(
-    model: Model, max_kv_size: int | None = None, keep: int = 0
-) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
-    assert hasattr(model, "layers")
-    # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
-        logger.info("Using MLX LM's make cache")
-        return model.make_cache()  # type: ignore
-    if max_kv_size is None:
-        if KV_CACHE_BITS is None:
-            logger.info("Using default KV cache")
-            return [KVCache() for _ in model.layers]
-        else:
-            logger.info("Using quantized KV cache")
-            return [
-                QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
-                for _ in model.layers
-            ]
-    else:
-        logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
-        return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
def mlx_force_oom(size: int = 40000) -> None:
    """
    Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.

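The widened skip condition in `apply_chat_template` matters for tool-calling models: an assistant turn that carries only `tool_calls` has `content is None`, so the old check silently dropped it before templating and the model never saw its own tool call when the tool result arrived. A small illustration, with plain dicts standing in for exo's `ChatCompletionMessage` (field shapes beyond `content`, `thinking`, and `tool_calls` are assumptions):

```python
# Assistant turn with a tool call but no text content.
message = {
    "role": "assistant",
    "content": None,
    "thinking": None,
    "tool_calls": [
        {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
}

# Old check dropped this message; the new check keeps it.
old_skip = message["content"] is None and message["thinking"] is None
new_skip = old_skip and message["tool_calls"] is None
assert old_skip and not new_skip
```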

@@ -70,6 +70,7 @@ from exo.worker.engines.image import (
    warmup_image_generator,
)
from exo.worker.engines.mlx import Model
+from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -103,6 +104,7 @@ def main(
    model: Model | DistributedImageModel | None = None
    tokenizer = None
    group = None
+    kv_prefix_cache: KVPrefixCache | None = None
    current_status: RunnerStatus = RunnerIdle()

    logger.info("runner created")
@@ -161,6 +163,8 @@ def main(
                        logger.info(
                            f"model has_tool_calling={tokenizer.has_tool_calling}"
                        )
+                        kv_prefix_cache = KVPrefixCache(tokenizer)
                    elif (
                        ModelTask.TextToImage in shard_metadata.model_card.tasks
                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -170,7 +174,6 @@ def main(
                        raise ValueError(
                            f"Unknown model task(s): {shard_metadata.model_card.tasks}"
                        )
                    current_status = RunnerLoaded()
                    logger.info("runner loaded")
                case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -238,6 +241,7 @@ def main(
                            tokenizer=tokenizer,
                            task=task_params,
                            prompt=prompt,
+                            kv_prefix_cache=kv_prefix_cache,
                        )

                # For other thinking models (GLM, etc.), check if we need to


@@ -0,0 +1,545 @@
# type: ignore
import time
from typing import cast
from unittest.mock import patch

import mlx.core as mx
import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler

from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
    KVPrefixCache,
    _cache_length,
    _get_prefix_length,
    encode_prompt,
    make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.tests.unittests.test_mlx.conftest import (
    DEFAULT_GPT_OSS_CONFIG,
    DEFAULT_GPT_OSS_MODEL_ID,
)


def _check_model_exists() -> bool:
    return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
class TestGetPrefixLength:
    def test_identical_arrays(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3, 4, 5])
        assert _get_prefix_length(a, b) == 5

    def test_no_common_prefix(self):
        a = mx.array([1, 2, 3])
        b = mx.array([4, 5, 6])
        assert _get_prefix_length(a, b) == 0

    def test_partial_prefix(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3, 7, 8])
        assert _get_prefix_length(a, b) == 3

    def test_prompt_longer_than_cached(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3])
        assert _get_prefix_length(a, b) == 3

    def test_cached_longer_than_prompt(self):
        a = mx.array([1, 2, 3])
        b = mx.array([1, 2, 3, 4, 5])
        assert _get_prefix_length(a, b) == 3

    def test_single_token_match(self):
        a = mx.array([1, 2, 3])
        b = mx.array([1, 5, 6])
        assert _get_prefix_length(a, b) == 1

    def test_empty_prompt(self):
        a = mx.array([]).astype(mx.int32)
        b = mx.array([1, 2, 3])
        assert _get_prefix_length(a, b) == 0

    def test_empty_cached(self):
        a = mx.array([1, 2, 3])
        b = mx.array([]).astype(mx.int32)
        assert _get_prefix_length(a, b) == 0

    def test_both_empty(self):
        a = mx.array([]).astype(mx.int32)
        b = mx.array([]).astype(mx.int32)
        assert _get_prefix_length(a, b) == 0
class TestKVPrefix:
    @pytest.fixture
    def mock_tokenizer(self):
        """Create a minimal mock tokenizer for tests that don't need real tokenization."""
        from unittest.mock import MagicMock

        tokenizer = MagicMock()
        tokenizer.encode.return_value = [1, 2, 3]
        return tokenizer

    def test_starts_empty(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        assert len(cache.prompts) == 0
        assert len(cache.caches) == 0

    def test_clear_empties_cache(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        cache.prompts.append(mx.array([1, 2, 3]))
        cache.caches.append([KVCache()])
        cache.clear()
        assert len(cache.prompts) == 0
        assert len(cache.caches) == 0

    def test_clear_on_empty_cache(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        cache.clear()
        assert len(cache.prompts) == 0
def _load_gpt_oss() -> tuple[Model, object]:
    from mlx_lm.utils import load_model

    from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id

    model_path = DEFAULT_GPT_OSS_CONFIG.model_path
    model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
    model, _ = load_model(model_path, lazy=False)
    tokenizer = load_tokenizer_for_model_id(model_id, model_path)
    return cast(Model, model), tokenizer
@pytest.mark.slow
@pytest.mark.skipif(
    not _check_model_exists(),
    reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
)
class TestKVPrefixCacheWithModel:
    @pytest.fixture(scope="class")
    def model_and_tokenizer(self):
        model, tokenizer = _load_gpt_oss()
        return model, tokenizer
    def test_prefill_populates_cache(self, model_and_tokenizer):
        model, tokenizer = model_and_tokenizer
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Hello!!")],
            max_tokens=1,
        )
        prompt = apply_chat_template(tokenizer, task)
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)
        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        # Cache should now hold the prompt tokens
        assert _cache_length(cache) == len(tokens)

    def test_add_and_get_exact_match(self, model_and_tokenizer):
        model, tokenizer = model_and_tokenizer
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Test exact")],
            max_tokens=1,
        )
        prompt = apply_chat_template(tokenizer, task)
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)
        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        assert len(kv_prefix_cache.prompts) == 1
        stored_length = _cache_length(kv_prefix_cache.caches[0])
        assert stored_length > 0

        # Retrieve with same prompt: exact match
        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
        )
        assert matched_index == 0
        # Exact match returns only last token
        assert len(remaining_tokens) == 1
        assert mx.array_equal(remaining_tokens, tokens[-1:])
    def test_add_and_get_prefix_match(self, model_and_tokenizer):
        """get_kv_cache with a longer prompt sharing prefix should return partial match."""
        model, tokenizer = model_and_tokenizer
        short_task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Hi")],
            max_tokens=1,
        )
        short_prompt = apply_chat_template(tokenizer, short_task)
        short_tokens = encode_prompt(tokenizer, short_prompt)
        cache = make_kv_cache(model)
        prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(short_prompt, cache)

        # Query with longer prompt that shares the chat template prefix
        long_task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[
                ChatCompletionMessage(role="user", content="Hi there, how are you?")
            ],
            max_tokens=1,
        )
        long_prompt = apply_chat_template(tokenizer, long_task)
        long_tokens = encode_prompt(tokenizer, long_prompt)

        # The prompts share a prefix (chat template preamble + "Hi")
        expected_prefix = _get_prefix_length(long_tokens, short_tokens)
        assert expected_prefix > 0, (
            "Prompts should share a prefix from the chat template"
        )

        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, long_prompt
        )
        assert matched_index == 0
        # remaining_tokens should be the suffix after the shared prefix
        assert len(remaining_tokens) == len(long_tokens) - expected_prefix
        assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
    def test_stored_cache_not_mutated_after_get_and_generation(
        self, model_and_tokenizer
    ):
        """Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
        model, tokenizer = model_and_tokenizer
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Mutation test")],
            max_tokens=1,
        )
        prompt = apply_chat_template(tokenizer, task)
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)
        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        stored_length = _cache_length(kv_prefix_cache.caches[0])

        # Get cache and mutate it (simulating what generation does)
        result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
        assert matched_index == 0

        # Simulate generation: feed many additional tokens through the cache
        head_dim = result_cache[0].keys.shape[-1]
        num_heads = result_cache[0].keys.shape[1]
        extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
        extra_values = mx.random.normal((1, num_heads, 50, head_dim))
        for layer_cache in result_cache:
            layer_cache.update_and_fetch(extra_keys, extra_values)
        mx.eval([c.keys for c in result_cache])

        # Stored cache must be unchanged
        assert _cache_length(kv_prefix_cache.caches[0]) == stored_length

    def test_stored_cache_survives_repeated_get_mutate_cycles(
        self, model_and_tokenizer
    ):
        """Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
        model, tokenizer = model_and_tokenizer
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Repeat test")],
            max_tokens=1,
        )
        prompt = apply_chat_template(tokenizer, task)
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)
        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        stored_length = _cache_length(kv_prefix_cache.caches[0])

        for i in range(3):
            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
            head_dim = result_cache[0].keys.shape[-1]
            num_heads = result_cache[0].keys.shape[1]
            extra = mx.random.normal((1, num_heads, 30, head_dim))
            for layer_cache in result_cache:
                layer_cache.update_and_fetch(extra, extra)
            mx.eval([c.keys for c in result_cache])
            assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
                f"Failed on loop {i}"
            )
    def test_mlx_generate_populates_cache(self, model_and_tokenizer):
        """mlx_generate should save the cache after generation completes."""
        model, tokenizer = model_and_tokenizer
        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Hello")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)
        prompt_tokens = encode_prompt(tokenizer, prompt)

        # Consume the entire generator so the cache-saving code after yield runs
        generated_tokens = 0
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            generated_tokens += 1

        assert len(kv_prefix_cache.prompts) == 1
        assert len(kv_prefix_cache.caches) == 1
        # Cache should contain prompt + generated tokens
        expected_length = len(prompt_tokens) + generated_tokens
        assert _cache_length(kv_prefix_cache.caches[0]) == expected_length

    def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
        """Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
        model, tokenizer = model_and_tokenizer
        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Reuse test")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)
        prompt_tokens = encode_prompt(tokenizer, prompt)

        # First generation populates cache
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        assert len(kv_prefix_cache.prompts) == 1

        # Second call should find a prefix match (the stored cache contains
        # prompt + generated tokens, which shares the prompt prefix)
        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
        )
        # The stored cache is longer than the prompt (it includes generated tokens),
        # so this is a prefix match where our prompt is fully contained
        assert matched_index == 0
        # Exact match: remaining_tokens is just the last token
        assert len(remaining_tokens) == 1
        assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
    def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
        """With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
        model, tokenizer = model_and_tokenizer
        kv_prefix_cache = KVPrefixCache(tokenizer)

        # Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
        base_text = "The quick brown fox jumps over the lazy dog. "
        base_tokens = tokenizer.encode(base_text)
        repeats = (1200 // len(base_tokens)) + 2
        long_content = base_text * repeats

        task1 = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content=long_content)],
            max_tokens=5,
        )
        prompt1 = apply_chat_template(tokenizer, task1)
        prompt1_tokens = encode_prompt(tokenizer, prompt1)
        assert len(prompt1_tokens) > 1000, (
            "Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
        )

        # First generation populates the cache (must prefill all tokens)
        t0 = time.perf_counter()
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task1,
            prompt=prompt1,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        first_gen_time = time.perf_counter() - t0
        assert len(kv_prefix_cache.prompts) == 1
        first_cache_length = _cache_length(kv_prefix_cache.caches[0])

        # Second generation: same long prompt + extra content (simulating multi-turn)
        task2 = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[
                ChatCompletionMessage(role="user", content=long_content),
                ChatCompletionMessage(role="assistant", content="Sure, I can help."),
                ChatCompletionMessage(role="user", content="Tell me more."),
            ],
            max_tokens=5,
        )
        prompt2 = apply_chat_template(tokenizer, task2)
        prompt2_tokens = encode_prompt(tokenizer, prompt2)

        # Verify the prompts share a long prefix
        prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
        assert prefix_len > 1000, "Prompts must share > 1000 token prefix"

        # Second generation should reuse the cached prefix (only prefill new tokens)
        t0 = time.perf_counter()
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task2,
            prompt=prompt2,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        second_gen_time = time.perf_counter() - t0

        # Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
        assert second_gen_time < first_gen_time * 0.5, (
            f"Expected prefix cache speedup: "
            f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
        )
        # With prefix_hit > 1000, should update in-place (not add a second entry)
        assert len(kv_prefix_cache.prompts) == 1
        # Updated cache should be longer (prompt2 + generated > prompt1 + generated)
        updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
        assert updated_cache_length > first_cache_length
    def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
        """After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
        model, tokenizer = model_and_tokenizer
        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Immutable test")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)

        # First generation populates cache
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        first_cache_length = _cache_length(kv_prefix_cache.caches[0])

        # Second generation gets the cache and mutates it during generation
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        # The first stored cache must not have been mutated by the second generation
        assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
    def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
        """Under memory pressure, adding a new cache entry evicts the least recently used one."""
        model, tokenizer = model_and_tokenizer
        kv_prefix_cache = KVPrefixCache(tokenizer)

        # Add three cache entries with different prompts
        prompts = ["First entry", "Second entry", "Third entry"]
        for i, content in enumerate(prompts):
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content=content)],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)
            # Stagger _last_used so LRU order is deterministic
            kv_prefix_cache._last_used[i] = float(i)
        assert len(kv_prefix_cache.prompts) == 3

        # Access the third entry to make it most recently used
        kv_prefix_cache._last_used[2] = 100.0
        # Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next

        # Simulate memory pressure: active memory exceeds threshold
        fake_limit = 1000
        fake_active = int(fake_limit * 0.90)  # Above _MEMORY_THRESHOLD (0.85)
        with (
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
                return_value=fake_active,
            ),
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.device_info",
                return_value={"max_recommended_working_set_size": fake_limit},
            ),
        ):
            # Trigger eviction by adding a new entry
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content="New entry")],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)

        # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
        # Since fake_active stays above threshold after each eviction (we don't change it),
        # all old entries get evicted, leaving only the newly added one
        assert len(kv_prefix_cache.prompts) == 1
        # The surviving entry should be the newly added one
        new_tokens = encode_prompt(tokenizer, prompt)
        assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
            new_tokens
        )
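The `TestGetPrefixLength` cases pin down `_get_prefix_length`'s contract: count matching leading tokens, bounded by the shorter array, with empty inputs yielding zero. One way to satisfy that contract with vectorized mlx ops (a sketch, not necessarily exo's implementation):

```python
import mlx.core as mx

def get_prefix_length(a: mx.array, b: mx.array) -> int:
    # Compare only the overlapping region; an empty overlap means no prefix.
    n = min(len(a), len(b))
    if n == 0:
        return 0
    matches = (a[:n] == b[:n]).astype(mx.int32)
    # cumprod stays 1 until the first mismatch, so its sum is the prefix length.
    return int(mx.sum(mx.cumprod(matches)).item())

assert get_prefix_length(mx.array([1, 2, 3, 4, 5]), mx.array([1, 2, 3, 7, 8])) == 3
assert get_prefix_length(mx.array([1, 2, 3]), mx.array([4, 5, 6])) == 0
assert get_prefix_length(mx.array([1, 2, 3]), mx.array([1, 2, 3, 4, 5])) == 3
```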

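The eviction test also documents the assumed memory policy: before a new entry is stored, least-recently-used entries are evicted while `mx.metal.get_active_memory()` exceeds roughly 85% of the device's `max_recommended_working_set_size`. A standalone sketch of that loop with injectable memory probes so it runs anywhere (the 0.85 threshold comes from the test; the class and helper names are hypothetical):

```python
# Sketch of LRU eviction under a memory threshold, as the test assumes.
import time
from typing import Callable

MEMORY_THRESHOLD = 0.85  # fraction of the recommended working-set size

class LruPrefixCache:
    def __init__(self, active_memory: Callable[[], int], memory_limit: Callable[[], int]):
        self.prompts: list[list[int]] = []
        self.caches: list[object] = []
        self._last_used: dict[int, float] = {}
        self._active_memory = active_memory
        self._memory_limit = memory_limit

    def _over_threshold(self) -> bool:
        return self._active_memory() > MEMORY_THRESHOLD * self._memory_limit()

    def add(self, prompt_tokens: list[int], cache: object) -> None:
        # Evict LRU entries until memory drops below the threshold (or cache empties).
        while self.prompts and self._over_threshold():
            lru = min(self._last_used, key=self._last_used.get)
            del self.prompts[lru], self.caches[lru]
            # Shift recency indices down past the removed entry.
            self._last_used = {
                (i if i < lru else i - 1): t
                for i, t in self._last_used.items()
                if i != lru
            }
        self.prompts.append(prompt_tokens)
        self.caches.append(cache)
        self._last_used[len(self.prompts) - 1] = time.monotonic()
```

If the memory probe keeps reporting pressure, as in the patched test, the loop drains every old entry and only the newly added one survives.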
uv.lock (generated)

@@ -415,7 +415,7 @@ requires-dist = [
{ name = "mflux", specifier = "==0.15.4" }, { name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" }, { name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" }, { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" }, { name = "mlx-lm", specifier = "==0.30.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" }, { name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" }, { name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" }, { name = "psutil", specifier = ">=7.0.0" },
@@ -1072,8 +1072,8 @@ wheels = [
[[package]]
name = "mlx-lm"
-version = "0.30.4"
+version = "0.30.5"
-source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
+source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1083,6 +1083,10 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
]
[[package]] [[package]]
name = "mlx-metal" name = "mlx-metal"
@@ -2281,7 +2285,7 @@ wheels = [
[[package]]
name = "transformers"
-version = "5.0.0rc2"
+version = "5.0.0rc3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2296,9 +2300,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" } sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" }, { url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
] ]
[[package]] [[package]]