mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-27 07:20:14 -05:00
Compare commits
10 Commits
runner-can
...
improve-di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93dab5b960 | ||
|
|
bd4f0bf048 | ||
|
|
cd8c01b7c8 | ||
|
|
59e991ce15 | ||
|
|
ffba340e70 | ||
|
|
9968abe816 | ||
|
|
0e30b0830f | ||
|
|
44453c4c8b | ||
|
|
1290e8ed9f | ||
|
|
d93db3d6bf |
@@ -5,18 +5,18 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
|
||||
@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
|
||||
@@ -3,12 +3,28 @@
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to ONLY include package.json and package-lock.json
|
||||
# This ensures prettier-svelte only rebuilds when lockfiles change
|
||||
dashboardLockfileSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
isDashboardDir = baseName == "dashboard" && type == "directory";
|
||||
isPackageFile =
|
||||
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
&& (baseName == "package.json" || baseName == "package-lock.json");
|
||||
in
|
||||
isDashboardDir || isPackageFile;
|
||||
};
|
||||
|
||||
# Stub source with lockfiles and minimal files for build to succeed
|
||||
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
|
||||
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
|
||||
mkdir -p $out
|
||||
cp ${inputs.self}/dashboard/package.json $out/
|
||||
cp ${inputs.self}/dashboard/package-lock.json $out/
|
||||
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
|
||||
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
|
||||
# Minimal files so vite build succeeds (produces empty output)
|
||||
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
|
||||
mkdir -p $out/src
|
||||
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"mlx-lm==0.30.5",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
|
||||
@@ -121,11 +121,20 @@ async def ensure_models_dir() -> Path:
|
||||
|
||||
|
||||
async def delete_model(model_id: ModelId) -> bool:
|
||||
model_dir = await ensure_models_dir() / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
return True
|
||||
models_dir = await ensure_models_dir()
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
deleted = False
|
||||
if await aios.path.exists(model_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
deleted = True
|
||||
|
||||
# Also clear cache
|
||||
if await aios.path.exists(cache_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
async def seed_models(seed_dir: str | Path):
|
||||
@@ -151,16 +160,28 @@ async def fetch_file_list_with_cache(
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
|
||||
return file_list
|
||||
|
||||
# Always try fresh first
|
||||
try:
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
# Update cache with fresh data
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
|
||||
)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
# Fetch failed - try cache fallback
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
# No cache available, propagate the error
|
||||
raise
|
||||
|
||||
|
||||
async def fetch_file_list_with_retry(
|
||||
@@ -332,8 +353,28 @@ async def _download_file(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
) -> Path:
|
||||
if await aios.path.exists(target_dir / path):
|
||||
return target_dir / path
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
try:
|
||||
remote_size, _ = await file_meta(model_id, revision, path)
|
||||
if local_size != remote_size:
|
||||
logger.info(
|
||||
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
|
||||
)
|
||||
await aios.remove(target_path)
|
||||
else:
|
||||
return target_path
|
||||
except Exception as e:
|
||||
# Offline or network error - trust local file
|
||||
logger.debug(
|
||||
f"Could not verify {path} against remote (offline?): {e}, using local file"
|
||||
)
|
||||
return target_path
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -542,17 +583,26 @@ async def download_shard(
|
||||
async def on_progress_wrapper(
|
||||
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
|
||||
) -> None:
|
||||
start_time = (
|
||||
file_progress[file.path].start_time
|
||||
if file.path in file_progress
|
||||
else time.time()
|
||||
)
|
||||
downloaded_this_session = (
|
||||
file_progress[file.path].downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
|
||||
if file.path in file_progress
|
||||
else curr_bytes
|
||||
previous_progress = file_progress.get(file.path)
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
speed = (
|
||||
downloaded_this_session / (time.time() - start_time)
|
||||
if time.time() - start_time > 0
|
||||
|
||||
0
src/exo/download/tests/__init__.py
Normal file
0
src/exo/download/tests/__init__.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""Tests for download verification and cache behavior."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from exo.download.download_utils import (
|
||||
delete_model,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
"""Set up a temporary models directory for testing."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestFileVerification:
|
||||
"""Tests for file size verification in _download_file."""
|
||||
|
||||
async def test_redownload_when_file_size_changes_upstream(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with mismatched sizes are re-downloaded."""
|
||||
# Import inside test to allow patching
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file with wrong size
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content") # 13 bytes
|
||||
|
||||
remote_size = 1000 # Different from local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
# Set up mock HTTP response for re-download
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
|
||||
side_effect=[b"x" * remote_size, b""]
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_session
|
||||
)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
# Mock calc_hash to return the expected hash
|
||||
with patch(
|
||||
"exo.download.download_utils.calc_hash",
|
||||
new_callable=AsyncMock,
|
||||
return_value=remote_hash,
|
||||
):
|
||||
await _download_file(model_id, "main", "test.safetensors", target_dir)
|
||||
|
||||
# file_meta should be called twice: once for verification, once for download
|
||||
assert mock_file_meta.call_count == 2
|
||||
|
||||
async def test_skip_download_when_file_size_matches(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with matching sizes are not re-downloaded."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
local_content = b"local content"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(local_content)
|
||||
|
||||
remote_size = len(local_content) # Same as local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return immediately without downloading
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_called_once()
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
async def test_offline_fallback_uses_local_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that local files are used when network is unavailable."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return local file without attempting download
|
||||
assert result == local_file
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFileListCache:
|
||||
"""Tests for file list caching behavior."""
|
||||
|
||||
async def test_fetch_fresh_and_update_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that fresh data is fetched and cache is updated."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=100),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
return_value=file_list,
|
||||
) as mock_fetch,
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == file_list
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
# Verify cache was written
|
||||
cache_file = (
|
||||
models_dir
|
||||
/ "caches"
|
||||
/ model_id.normalize()
|
||||
/ f"{model_id.normalize()}--main--file_list.json"
|
||||
)
|
||||
assert await aios.path.exists(cache_file)
|
||||
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
|
||||
await f.read()
|
||||
)
|
||||
assert cached_data == file_list
|
||||
|
||||
async def test_fallback_to_cache_when_fetch_fails(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that cached data is used when fetch fails."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create cache file
|
||||
cached_file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
|
||||
)
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == cached_file_list
|
||||
|
||||
async def test_error_propagates_when_no_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that errors propagate when fetch fails and no cache exists."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
pytest.raises(Exception, match="Network error"),
|
||||
):
|
||||
await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
|
||||
class TestModelDeletion:
|
||||
"""Tests for model deletion including cache cleanup."""
|
||||
|
||||
async def test_delete_model_clears_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that deleting a model also deletes its cache."""
|
||||
models_dir = tmp_path / "models"
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Create model and cache directories
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Add some files
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
|
||||
await f.write("model data")
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is True
|
||||
assert not await aios.path.exists(model_dir)
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_model_only_cache_exists(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting when only cache exists (model already deleted)."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Only create cache directory
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
# Returns False because model dir didn't exist
|
||||
assert result is False
|
||||
# But cache should still be cleaned up
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_nonexistent_model(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting a model that doesn't exist."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestProgressResetOnRedownload:
|
||||
"""Tests for progress tracking when files are re-downloaded."""
|
||||
|
||||
async def test_progress_resets_correctly_on_redownload(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress tracking resets when a file is re-downloaded.
|
||||
|
||||
When a file is deleted and re-downloaded (due to size mismatch),
|
||||
the progress tracking should reset rather than calculating negative
|
||||
downloaded_this_session values.
|
||||
"""
|
||||
# Simulate file_progress dict as it exists in download_shard
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with old file progress (simulating existing large file)
|
||||
old_file_size = 1_500_000_000 # 1.5 GB
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(old_file_size),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(old_file_size),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="not_started",
|
||||
start_time=time.time() - 10, # Started 10 seconds ago
|
||||
)
|
||||
|
||||
# Simulate the logic from on_progress_wrapper after re-download starts
|
||||
# This is the exact logic from the fixed on_progress_wrapper
|
||||
curr_bytes = 100_000 # 100 KB - new download just started
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is True, "Should detect re-download scenario"
|
||||
assert downloaded_this_session == curr_bytes, (
|
||||
"downloaded_this_session should equal curr_bytes on re-download"
|
||||
)
|
||||
assert downloaded_this_session > 0, (
|
||||
"downloaded_this_session should be positive, not negative"
|
||||
)
|
||||
|
||||
# Calculate speed (should be positive)
|
||||
elapsed = time.time() - start_time
|
||||
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
|
||||
assert speed >= 0, "Speed should be non-negative"
|
||||
|
||||
async def test_progress_accumulates_on_continuing_download(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress accumulates correctly for continuing downloads.
|
||||
|
||||
When a download continues from where it left off (resume),
|
||||
the progress should accumulate correctly.
|
||||
"""
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with partial download progress
|
||||
initial_downloaded = 500_000 # 500 KB already downloaded
|
||||
start_time = time.time() - 5 # Started 5 seconds ago
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(initial_downloaded),
|
||||
downloaded_this_session=Memory.from_bytes(initial_downloaded),
|
||||
total=Memory.from_bytes(1_000_000),
|
||||
speed=100_000,
|
||||
eta=timedelta(seconds=5),
|
||||
status="in_progress",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# Progress callback with more bytes downloaded
|
||||
curr_bytes = 600_000 # 600 KB - continuing download
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# This is NOT a re-download (curr_bytes > previous downloaded)
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
downloaded_this_session = curr_bytes
|
||||
used_start_time = time.time()
|
||||
else:
|
||||
used_start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is False, (
|
||||
"Should NOT detect re-download for continuing download"
|
||||
)
|
||||
assert used_start_time == start_time, "Should preserve original start_time"
|
||||
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
|
||||
assert downloaded_this_session == expected_session, (
|
||||
f"Should accumulate: {downloaded_this_session} == {expected_session}"
|
||||
)
|
||||
assert downloaded_this_session == 600_000, (
|
||||
"downloaded_this_session should equal total downloaded so far"
|
||||
)
|
||||
@@ -88,7 +88,6 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -509,14 +508,16 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
@@ -900,11 +901,6 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -986,11 +982,6 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
|
||||
@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -248,7 +246,7 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
@@ -260,7 +258,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -270,7 +268,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -280,18 +278,6 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -300,9 +286,10 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -20,15 +20,9 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -186,7 +180,6 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -202,18 +195,6 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -428,7 +428,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -442,7 +442,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -457,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -470,7 +470,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -484,7 +484,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -499,7 +499,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -541,9 +541,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
@@ -551,10 +551,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -575,9 +575,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
@@ -585,10 +585,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -610,6 +610,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -88,7 +84,6 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
12
src/exo/shared/types/mlx.py
Normal file
12
src/exo/shared/types/mlx.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Shared types for MLX-related functionality."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
|
||||
@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -61,10 +60,6 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -92,7 +87,6 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -349,13 +349,8 @@ class InfoGatherer:
|
||||
async def _monitor_misc(self):
|
||||
if self.misc_poll_interval is None:
|
||||
return
|
||||
prev = await MiscData.gather()
|
||||
await self.info_sender.send(prev)
|
||||
while True:
|
||||
curr = await MiscData.gather()
|
||||
if prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await self.info_sender.send(await MiscData.gather())
|
||||
await anyio.sleep(self.misc_poll_interval)
|
||||
|
||||
async def _monitor_system_profiler_thunderbolt_data(self):
|
||||
@@ -365,15 +360,12 @@ class InfoGatherer:
|
||||
if iface_map is None:
|
||||
return
|
||||
|
||||
old_idents = []
|
||||
while True:
|
||||
data = await ThunderboltConnectivity.gather()
|
||||
assert data is not None
|
||||
|
||||
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
|
||||
if idents != old_idents:
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
old_idents = idents
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
|
||||
conns = [it for i in data if (it := i.conn()) is not None]
|
||||
await self.info_sender.send(MacThunderboltConnections(conns=conns))
|
||||
@@ -398,22 +390,17 @@ class InfoGatherer:
|
||||
async def _watch_system_info(self):
|
||||
if self.interface_watcher_interval is None:
|
||||
return
|
||||
old_nics = []
|
||||
while True:
|
||||
nics = await get_network_interfaces()
|
||||
if nics != old_nics:
|
||||
old_nics = nics
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
if curr is not None:
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
@@ -145,6 +147,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -252,10 +258,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
@@ -334,15 +336,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -351,7 +345,6 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -367,6 +360,14 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GLM4MoeLiteModel):
|
||||
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
@@ -441,7 +442,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -516,6 +517,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -533,6 +536,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
timeout_seconds / len(model.layers), # type: ignore
|
||||
on_timeout,
|
||||
)
|
||||
if layer.self_attn.q_lora_rank is None: # type: ignore
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
if isinstance(layer.mlp, Glm4MoeLiteMLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -566,7 +647,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -607,6 +688,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -661,7 +743,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,81 @@
|
||||
# type: ignore
|
||||
# TODO: Fix this file, including types!
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
|
||||
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self):
|
||||
# Only one prefix cache per runner.
|
||||
def __init__(self, tokenizer: TokenizerWrapper):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[list[_BaseCache]] = []
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
|
||||
def add_kv_cache(
|
||||
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
|
||||
):
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt: str,
|
||||
cache: KVCacheType,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: str,
|
||||
) -> list[_BaseCache]:
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
Returns:
|
||||
Tuple of (cache, remaining_tokens, matched_index) where:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
"""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
@@ -42,63 +84,127 @@ class KVPrefixCache:
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
return self.caches[i]
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = _cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
|
||||
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(
|
||||
model,
|
||||
# max_kv_size=MAX_KV_SIZE,
|
||||
# keep=KEEP_KV_SIZE
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
|
||||
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
return prompt_cache
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = _cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
|
||||
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
tokenized_prompt = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return mx.array(tokenized_prompt)
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory pressure is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
active: int = mx.metal.get_active_memory()
|
||||
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while len(self.caches) > 0:
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
|
||||
)
|
||||
|
||||
active = mx.metal.get_active_memory()
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
break
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
For chat-templated prompts (which have their own structure markers like
|
||||
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def _cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
|
||||
"""Find the length of the common prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
return 0
|
||||
|
||||
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
|
||||
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
|
||||
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: mx.array,
|
||||
cache: list[_BaseCache],
|
||||
) -> None:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=0,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
pass
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
KV_GROUP_SIZE: int | None = 32
|
||||
KV_BITS: int | None = None
|
||||
ATTENTION_KV_BITS: int | None = 4
|
||||
MAX_TOKENS: int = 8192
|
||||
MAX_TOKENS: int = 32168
|
||||
MAX_KV_SIZE: int | None = 3200
|
||||
KEEP_KV_SIZE: int | None = 1600
|
||||
QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import time
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
@@ -14,34 +14,78 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
make_kv_cache,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[KVCache | Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> float:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
then trims off the extra generated token.
|
||||
|
||||
Returns:
|
||||
tokens_per_sec
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.time() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
logger.debug(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
@@ -89,6 +133,10 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
@@ -115,6 +163,7 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -126,7 +175,22 @@ def mlx_generate(
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
if is_bench:
|
||||
kv_prefix_cache = None
|
||||
|
||||
# Use prefix cache if available, otherwise create fresh cache
|
||||
prefix_hit_length = 0
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -139,11 +203,19 @@ def mlx_generate(
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
@@ -153,12 +225,13 @@ def mlx_generate(
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
@@ -180,4 +253,26 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
generated_tokens = len(generated_text_parts)
|
||||
generation_tps = (
|
||||
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
|
||||
)
|
||||
logger.debug(
|
||||
f"Generation complete: prefill {prompt_tokens} tokens @ "
|
||||
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -18,15 +18,12 @@ try:
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -70,6 +67,8 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -87,6 +86,30 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -379,7 +402,11 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
message.content = "\n".join(c.text for c in message.content).strip()
|
||||
if message.content is None and message.thinking is None:
|
||||
if (
|
||||
message.content is None
|
||||
and message.thinking is None
|
||||
and message.tool_calls is None
|
||||
):
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
@@ -436,31 +463,6 @@ class NullKVCache(KVCache):
|
||||
raise NotImplementedError("We should not be setting a NullKVCache.")
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
"""
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
@@ -510,23 +512,3 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -116,9 +115,8 @@ class Worker:
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with create_task_group() as tg:
|
||||
for runner in self.runners.values():
|
||||
tg.start_soon(runner.shutdown)
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -222,22 +220,15 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await runner.shutdown()
|
||||
case CancelTask(cancelled_task_id=cancelled_task_id):
|
||||
await self.runners[self._task_to_runner_id(task)].cancel_task(
|
||||
cancelled_task_id
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -360,6 +351,8 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
@@ -60,8 +59,7 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _cancel_tasks(runners, tasks)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -272,7 +270,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -286,7 +284,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -294,31 +292,16 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id, cancelled_task_id=task.task_id
|
||||
)
|
||||
|
||||
@@ -3,10 +3,11 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.tests.patches import load_null_model
|
||||
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
@@ -15,8 +16,9 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
*,
|
||||
_load_null_models: bool = False,
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
@@ -30,6 +32,13 @@ def entrypoint(
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
p = None
|
||||
if _load_null_models:
|
||||
from unittest.mock import patch
|
||||
|
||||
p = patch("mlx_lm.utils.load_model", new=load_null_model)
|
||||
p.start()
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
@@ -39,7 +48,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
@@ -53,6 +62,8 @@ def entrypoint(
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if p is not None:
|
||||
p.stop()
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
|
||||
@@ -37,7 +37,6 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -71,6 +70,7 @@ from exo.worker.engines.image import (
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
@@ -78,7 +78,6 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -87,7 +86,6 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
@@ -102,13 +100,11 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -117,7 +113,6 @@ def main(
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -162,25 +157,28 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -191,11 +189,11 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert inference_model
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
@@ -207,8 +205,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -227,7 +225,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert inference_model
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
@@ -239,10 +237,11 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=inference_model,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
@@ -262,11 +261,11 @@ def main(
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(inference_model, GptOssModel):
|
||||
elif isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
inference_model, GptOssModel
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -278,19 +277,7 @@ def main(
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
cancel_every = 5
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if (
|
||||
@@ -357,7 +344,7 @@ def main(
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert image_model
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -371,9 +358,7 @@ def main(
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
@@ -420,7 +405,7 @@ def main(
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert image_model
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -432,9 +417,7 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
@@ -497,7 +480,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del inference_model, image_model, tokenizer, group
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
|
||||
@@ -49,12 +49,10 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -65,8 +63,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -74,7 +72,6 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -89,7 +86,6 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -97,41 +93,37 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with self._tg as tg:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.close()
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 10)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
@@ -139,7 +131,6 @@ class RunnerSupervisor:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -149,13 +140,7 @@ class RunnerSupervisor:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
@@ -221,4 +206,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
await self.shutdown()
|
||||
self.shutdown()
|
||||
|
||||
50
src/exo/worker/tests/patches.py
Normal file
50
src/exo/worker/tests/patches.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# type: ignore
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from exo.worker.engines.mlx import Model
|
||||
|
||||
|
||||
def load_null_model(path: Path, **_: object) -> "tuple[Model, dict[str, Any]]":
|
||||
with open(path / "config.json", "r") as f:
|
||||
cfg = json.load(f)
|
||||
model, args = _get_classes(cfg)
|
||||
model = model(args.from_dict(cfg))
|
||||
return model, cfg
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
|
||||
Returns:
|
||||
A tuple containing the Model class and the ModelArgs class.
|
||||
"""
|
||||
model_type = config["model_type"]
|
||||
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
||||
try:
|
||||
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
|
||||
except ImportError:
|
||||
msg = f"Model type {model_type} not supported."
|
||||
raise ValueError(msg) from None
|
||||
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
MODEL_REMAPPING = {
|
||||
"mistral": "llama",
|
||||
"llava": "mistral3",
|
||||
"phi-msft": "phixtral",
|
||||
"falcon_mamba": "mamba",
|
||||
"kimi_k2": "deepseek_v3",
|
||||
"qwen2_5_vl": "qwen2_vl",
|
||||
"minimax_m2": "minimax",
|
||||
"iquestcoder": "llama",
|
||||
}
|
||||
545
src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py
Normal file
545
src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# type: ignore
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
_cache_length,
|
||||
_get_prefix_length,
|
||||
encode_prompt,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
||||
DEFAULT_GPT_OSS_CONFIG,
|
||||
DEFAULT_GPT_OSS_MODEL_ID,
|
||||
)
|
||||
|
||||
|
||||
def _check_model_exists() -> bool:
|
||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
||||
|
||||
|
||||
class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
class TestKVPrefix:
|
||||
@pytest.fixture
|
||||
def mock_tokenizer(self):
|
||||
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
|
||||
def _load_gpt_oss() -> tuple[Model, object]:
|
||||
from mlx_lm.utils import load_model
|
||||
|
||||
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
|
||||
|
||||
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
|
||||
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
|
||||
|
||||
model, _ = load_model(model_path, lazy=False)
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
)
|
||||
class TestKVPrefixCacheWithModel:
|
||||
@pytest.fixture(scope="class")
|
||||
def model_and_tokenizer(self):
|
||||
model, tokenizer = _load_gpt_oss()
|
||||
return model, tokenizer
|
||||
|
||||
def test_prefill_populates_cache(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert _cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Test exact")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
short_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hi")],
|
||||
max_tokens=1,
|
||||
)
|
||||
short_prompt = apply_chat_template(tokenizer, short_task)
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hi there, how are you?")
|
||||
],
|
||||
max_tokens=1,
|
||||
)
|
||||
long_prompt = apply_chat_template(tokenizer, long_task)
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
):
|
||||
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
for layer_cache in result_cache:
|
||||
layer_cache.update_and_fetch(extra_keys, extra_values)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
||||
self, model_and_tokenizer
|
||||
):
|
||||
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
extra = mx.random.normal((1, num_heads, 30, head_dim))
|
||||
for layer_cache in result_cache:
|
||||
layer_cache.update_and_fetch(extra, extra)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
f"Failed on loop {i}"
|
||||
)
|
||||
|
||||
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hello")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
|
||||
# Consume the entire generator so the cache-saving code after yield runs
|
||||
generated_tokens = 0
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
generated_tokens += 1
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
assert len(kv_prefix_cache.caches) == 1
|
||||
# Cache should contain prompt + generated tokens
|
||||
expected_length = len(prompt_tokens) + generated_tokens
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
|
||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
|
||||
# First generation populates cache
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
|
||||
# Second call should find a prefix match (the stored cache contains
|
||||
# prompt + generated tokens, which shares the prompt prefix)
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
||||
# so this is a prefix match where our prompt is fully contained
|
||||
assert matched_index == 0
|
||||
# Exact match: remaining_tokens is just the last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
||||
|
||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
repeats = (1200 // len(base_tokens)) + 2
|
||||
long_content = base_text * repeats
|
||||
|
||||
task1 = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=long_content)],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt1 = apply_chat_template(tokenizer, task1)
|
||||
prompt1_tokens = encode_prompt(tokenizer, prompt1)
|
||||
assert len(prompt1_tokens) > 1000, (
|
||||
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
|
||||
)
|
||||
|
||||
# First generation populates the cache (must prefill all tokens)
|
||||
t0 = time.perf_counter()
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task1,
|
||||
prompt=prompt1,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
||||
task2 = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content=long_content),
|
||||
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
|
||||
ChatCompletionMessage(role="user", content="Tell me more."),
|
||||
],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt2 = apply_chat_template(tokenizer, task2)
|
||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
||||
|
||||
# Verify the prompts share a long prefix
|
||||
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
||||
|
||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
||||
t0 = time.perf_counter()
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task2,
|
||||
prompt=prompt2,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
second_gen_time = time.perf_counter() - t0
|
||||
|
||||
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
|
||||
assert second_gen_time < first_gen_time * 0.5, (
|
||||
f"Expected prefix cache speedup: "
|
||||
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
|
||||
)
|
||||
|
||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
||||
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert updated_cache_length > first_cache_length
|
||||
|
||||
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
# First generation populates cache
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation gets the cache and mutates it during generation
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
# The first stored cache must not have been mutated by the second generation
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
|
||||
|
||||
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
for i, content in enumerate(prompts):
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=content)],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 3
|
||||
|
||||
# Access the third entry to make it most recently used
|
||||
kv_prefix_cache._last_used[2] = 100.0
|
||||
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
|
||||
|
||||
# Simulate memory pressure: active memory exceeds threshold
|
||||
fake_limit = 1000
|
||||
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
|
||||
return_value=fake_active,
|
||||
),
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.device_info",
|
||||
return_value={"max_recommended_working_set_size": fake_limit},
|
||||
),
|
||||
):
|
||||
# Trigger eviction by adding a new entry
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="New entry")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
# Since fake_active stays above threshold after each eviction (we don't change it),
|
||||
# all old entries get evicted, leaving only the newly added one
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
import multiprocessing as mp
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
from fastapi import FastAPI
|
||||
@@ -11,16 +10,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.commands import CommandId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -36,18 +31,17 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.runners import RunnerFailed, RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
MODEL_CARDS = {"haha": MODEL_CARDS["qwen3-coder-480b-a35b-8bit"]}
|
||||
|
||||
class Tests(BaseModel):
|
||||
# list[hostname, ip addr]
|
||||
devs: list[list[str]]
|
||||
model_id: str
|
||||
kind: typing.Literal["init", "warmup", "inference"]
|
||||
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
@@ -56,16 +50,14 @@ logger_setup(None)
|
||||
|
||||
async def main():
|
||||
logger.info("starting cool server majig")
|
||||
await assert_downloads()
|
||||
cfg = Config()
|
||||
cfg.bind = "0.0.0.0:52415"
|
||||
cfg.bind = "0.0.0.0:8000"
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = "-"
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
app = FastAPI()
|
||||
app.post("/ring")(ring_backend)
|
||||
app.post("/jaccl")(jaccl_backend)
|
||||
app.post("/run_test")(run_test)
|
||||
app.post("/tb_detection")(tb_detection)
|
||||
shutdown = anyio.Event()
|
||||
await serve(
|
||||
@@ -87,28 +79,7 @@ async def tb_detection():
|
||||
return recv.collect()
|
||||
|
||||
|
||||
async def assert_downloads():
|
||||
sd = exo_shard_downloader()
|
||||
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
|
||||
)
|
||||
|
||||
|
||||
async def ring_backend(test: Tests):
|
||||
async def run_test(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
@@ -117,10 +88,30 @@ async def ring_backend(test: Tests):
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, ring_instance(test, iid, hn), hn)
|
||||
|
||||
async def run():
|
||||
for card in MODEL_CARDS.values():
|
||||
for instance in (
|
||||
ring_instance(test, card.model_id, iid, hn),
|
||||
jaccl_instance(test, card.model_id, iid),
|
||||
):
|
||||
recv = await execute_test(test, instance, hn)
|
||||
|
||||
with recv:
|
||||
try:
|
||||
async for item in recv:
|
||||
yield item.model_dump_json() + "\n"
|
||||
if isinstance(item, RunnerStatusUpdated) and isinstance(
|
||||
item.runner_status, RunnerFailed
|
||||
):
|
||||
return
|
||||
except anyio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(run())
|
||||
|
||||
|
||||
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
def ring_instance(test: Tests, model_id: ModelId, iid: InstanceId, hn: str) -> Instance:
|
||||
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
|
||||
world_size = len(test.devs)
|
||||
for i in range(world_size):
|
||||
@@ -135,13 +126,13 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
else:
|
||||
raise ValueError(f"{hn} not in {test.devs}")
|
||||
|
||||
card = MODEL_CARDS[test.model_id]
|
||||
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
|
||||
instance = MlxRingInstance(
|
||||
instance_id=iid,
|
||||
ephemeral_port=52416,
|
||||
hosts_by_node={NodeId(hn): hbn},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(test.model_id),
|
||||
model_id=model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(test.devs[i][0]): PipelineShardMetadata(
|
||||
@@ -163,7 +154,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
return instance
|
||||
|
||||
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str):
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
|
||||
world_size = len(test.devs)
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
_handle, recv, send = new_runner(instance, hn)
|
||||
@@ -171,60 +162,33 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
|
||||
send.send(ConnectToGroup(instance_id=iid))
|
||||
send.send(LoadModel(instance_id=iid))
|
||||
|
||||
match test.kind:
|
||||
case "init":
|
||||
pass
|
||||
case "warmup":
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
case "inference":
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
send.send(
|
||||
ChatCompletion(
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=test.model_id,
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="system", content="You are a helpful assistant"
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
),
|
||||
],
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
)
|
||||
for card in MODEL_CARDS.values():
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
send.send(
|
||||
ChatCompletion(
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=card.model_id,
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="system", content="You are a helpful assistant"
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
),
|
||||
],
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
)
|
||||
)
|
||||
|
||||
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
|
||||
|
||||
async def map_recv():
|
||||
with recv:
|
||||
try:
|
||||
async for item in recv:
|
||||
yield item.model_dump_json() + "\n"
|
||||
except anyio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
ret = StreamingResponse(map_recv())
|
||||
ret._pls_dont_gc = _handle # type: ignore
|
||||
return ret
|
||||
return recv
|
||||
|
||||
|
||||
async def jaccl_backend(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
hn = dev[0]
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, jaccl_instance(test, iid), hn)
|
||||
|
||||
|
||||
def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
card = MODEL_CARDS[test.model_id]
|
||||
def jaccl_instance(test: Tests, model_id: ModelId, iid: InstanceId):
|
||||
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
|
||||
world_size = len(test.devs)
|
||||
|
||||
return MlxJacclInstance(
|
||||
@@ -235,7 +199,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
|
||||
},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(test.model_id),
|
||||
model_id=model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(test.devs[i][0]): TensorShardMetadata(
|
||||
@@ -270,6 +234,7 @@ def new_runner(
|
||||
task_recv,
|
||||
logger,
|
||||
),
|
||||
kwargs={"_load_null_models": True},
|
||||
)
|
||||
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
|
||||
runner_process.start()
|
||||
|
||||
@@ -6,19 +6,8 @@ query() {
|
||||
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
|
||||
}
|
||||
|
||||
if [[ $# -lt 2 ]]; then
|
||||
echo "USAGE: $0 <test kind> [host1] [host2] ..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
kind=$1
|
||||
shift
|
||||
|
||||
test_kinds="ring jaccl"
|
||||
|
||||
if ! echo "$test_kinds" | grep -q "$kind"; then
|
||||
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
|
||||
if [[ $# -lt 1 ]]; then
|
||||
echo "USAGE: $0 [host1] [host2] ..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -34,23 +23,12 @@ done
|
||||
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
|
||||
devs="[${devs_raw%, }]"
|
||||
|
||||
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
|
||||
|
||||
for model_id in "${model_ids[@]}"; do
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"${model_id}\",
|
||||
\"devs\": ${devs},
|
||||
\"kind\": \"inference\"
|
||||
}"
|
||||
echo "req $req"
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:8000/run_test" \
|
||||
-H "Content-Type: application/json" -d "{\"devs\": ${devs}}" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
|
||||
wait
|
||||
|
||||
16
uv.lock
generated
16
uv.lock
generated
@@ -415,7 +415,7 @@ requires-dist = [
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
|
||||
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.5" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
@@ -1072,8 +1072,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.4"
|
||||
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -1083,6 +1083,10 @@ dependencies = [
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367, upload-time = "2026-01-25T15:29:30.125Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716, upload-time = "2026-01-25T15:29:28.29Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
@@ -2281,7 +2285,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "5.0.0rc2"
|
||||
version = "5.0.0rc3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -2296,9 +2300,9 @@ dependencies = [
|
||||
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/94/e2/86b1bd5264272953370a5e50a91da38d7a53a87c5faf3fd3ff62d7353879/transformers-5.0.0rc2.tar.gz", hash = "sha256:9f2fa5e132433dd7eb910dc224b32de0baf758f3b6ffc918dbb632e0af85c07a", size = 8362532, upload-time = "2026-01-07T16:58:02.603Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520, upload-time = "2026-01-14T16:49:02.952Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/eb/9526a77354a2126f5b220f4792dc8494d573773c098dac6a5ad1fc7a5f17/transformers-5.0.0rc2-py3-none-any.whl", hash = "sha256:f8f2a14060ab11f20a0eec39d827af54c1589c327c5799d82808ae3f4167418a", size = 10067329, upload-time = "2026-01-07T16:57:59.617Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087, upload-time = "2026-01-14T16:48:59.393Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user