Fix tests!

Matt Beton
2025-07-22 15:20:32 +01:00
committed by GitHub
parent 5adad08e09
commit 53c652c307
14 changed files with 127 additions and 91 deletions

View File

@@ -1,28 +1,26 @@
# type: ignore
import asyncio
import concurrent.futures
import os
from asyncio import AbstractEventLoop
from typing import Callable
from typing import Any, Callable
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer
from mlx_lm.utils import load_model
from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer # type: ignore
from mlx_lm.utils import load_model # type: ignore
from pydantic import RootModel
from engines.mlx.auto_parallel import auto_parallel
from shared.types.tasks.common import ChatCompletionTaskParams
from shared.types.worker.mlx import Host
from shared.types.worker.shards import ShardMeta
from shared.types.worker.shards import ShardMetadata
from worker.download.download_utils import build_model_path
from worker.runner.communication import runner_print
def mx_barrier():
mx.eval(
mx.eval( # type: ignore
mx.distributed.all_sum(
mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu))
)
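
Editor's note: the all_sum above doubles as a barrier because the collective cannot complete until every rank contributes, so forcing evaluation blocks each process until all peers arrive. A minimal sketch of that idea in isolation, assuming mx.distributed has already been initialised for the group:

import mlx.core as mx

def synchronised_step(local_value: float) -> float:
    # The reduction blocks until every rank has contributed, so forcing
    # evaluation here acts as a rendezvous across the whole group.
    total = mx.distributed.all_sum(
        mx.array(local_value), stream=mx.default_stream(mx.Device(mx.cpu))
    )
    mx.eval(total)
    return total.item()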
@@ -35,7 +33,7 @@ class HostList(RootModel[list[str]]):
return cls(root=[str(host) for host in hosts])
def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: # type: ignore
"""
Initialize the MLX distributed group (runs in the thread pool)
"""
@@ -62,7 +60,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
def initialize_mlx(
model_shard_meta: ShardMeta,
model_shard_meta: ShardMetadata,
hosts: list[Host],
) -> tuple[nn.Module, TokenizerWrapper, Callable[[mx.array], mx.array]]:
"""
@@ -71,19 +69,22 @@ def initialize_mlx(
mx.random.seed(42)
if len(hosts) > 1:
mlx_distributed_init(model_shard_meta.device_rank, hosts)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7) # type: ignore
model, tokenizer = shard_and_load(model_shard_meta)
return model, tokenizer, sampler
def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWrapper]:
runner_print(f"loading model from {model_shard_meta.model_path}")
def shard_and_load(model_shard_meta: ShardMetadata) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(model_shard_meta.model_meta.model_id)
model, config = load_model(model_shard_meta.model_path, lazy=True, strict=False)
runner_print(f"loading model from {model_path}")
tokenizer = load_tokenizer(model_shard_meta.model_path)
model, _ = load_model(model_path, lazy=True, strict=False) # type: ignore
assert isinstance(model, nn.Module)
tokenizer = load_tokenizer(model_path)
assert isinstance(tokenizer, TokenizerWrapper)
model = auto_parallel(model, model_shard_meta)
@@ -107,18 +108,18 @@ async def apply_chat_template(
# Filter out None values, keeping only 'role' and 'content' keys
formatted_messages = []
for message in messages_dicts:
filtered_message = {k: v for k, v in message.items() if v is not None}
filtered_message: dict[str, Any] = {k: v for k, v in message.items() if v is not None} # type: ignore
# Verify we have exactly the expected keys
assert set(filtered_message.keys()) == {"role", "content"}, (
f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}"
)
formatted_messages.append(filtered_message)
formatted_messages.append(filtered_message) # type: ignore
messages_dicts = formatted_messages
prompt: str = await loop.run_in_executor(
executor=mlx_executor,
func=lambda: tokenizer.apply_chat_template(
func=lambda: tokenizer.apply_chat_template( # type: ignore
messages_dicts,
tokenize=False,
add_generation_prompt=True,
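
Editor's note: the run_in_executor pattern above keeps the event loop responsive while the blocking tokenizer call runs on a worker thread. A self-contained sketch of the same pattern, with render_prompt standing in for tokenizer.apply_chat_template (the name is illustrative, not part of this codebase):

import asyncio
import concurrent.futures

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def render_prompt(messages: list[dict[str, str]]) -> str:
    # Stand-in for a blocking template call such as apply_chat_template.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

async def main() -> None:
    loop = asyncio.get_running_loop()
    messages = [{"role": "user", "content": "hello"}]
    # The lambda defers the blocking call onto the thread pool.
    prompt = await loop.run_in_executor(executor, lambda: render_prompt(messages))
    print(prompt)

asyncio.run(main())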

View File

@@ -7,10 +7,9 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence
from fastapi.responses import Response, StreamingResponse
from shared.event_loops.commands import ExternalCommand
from shared.types.events.components import Apply, EventFromEventLog
from shared.types.events.registry import Event
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.events.components import Apply
class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):

View File

@@ -8,13 +8,13 @@ from typing import (
if TYPE_CHECKING:
pass
from pydantic import BaseModel, Field, model_validator
from typing import Callable
from pydantic import BaseModel, Field, model_validator
from shared.types.common import NodeId
from shared.types.state import State
from shared.types.events.registry import Event
from shared.types.state import State
class EventFromEventLog[T: Event](BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from typing import Annotated, Generic, Literal, TypeAlias, TypeVar
from typing import Annotated, Generic, Literal, TypeVar
from pydantic import BaseModel, Field, TypeAdapter
from shared.types.common import NodeId
from shared.types.models import ModelId
from shared.types.models import ModelId, ModelMetadata
class PartitionStrategy(str, Enum):
@@ -20,10 +20,10 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
partition_strategy: PartitionStrategyT
device_rank: int
world_size: int
model_id: ModelId
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
@@ -47,7 +47,7 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
return self.end_layer == self.n_layers - 1
def __hash__(self) -> int:
return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
return hash((self.model_meta.model_id, self.start_layer, self.end_layer, self.n_layers))
ShardMetadata = Annotated[
@@ -57,17 +57,6 @@ ShardMetadataParser: TypeAdapter[ShardMetadata] = TypeAdapter(
ShardMetadata
)
# ---------------------------------------------------------------------------
# Convenience aliases
# ---------------------------------------------------------------------------
# "ShardMeta" is a widely-used alias for the concrete, fully-parameterised
# `ShardMetadata` type. Defining it here avoids repetitive generic
# parameters at call-sites and resolves unknown-import diagnostics in
# downstream modules.
ShardMeta: TypeAlias = ShardMetadata
class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
"""

View File

@@ -0,0 +1,38 @@
from pathlib import Path
import pytest
from shared.types.models import ModelMetadata
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.model_meta import _get_model_meta # type: ignore
@pytest.fixture
def model_meta() -> ModelMetadata:
return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit') # type: ignore
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path):
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
total_layers = 16
layers_per_node = total_layers // num_nodes
start_layer = device_rank * layers_per_node
end_layer = (
start_layer + layers_per_node
if device_rank < num_nodes - 1
else total_layers
)
return PipelineShardMetadata(
model_meta=model_meta,
device_rank=device_rank,
n_layers=total_layers,
start_layer=start_layer,
end_layer=end_layer,
world_size=num_nodes,
)
return _pipeline_shard_meta
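
Editor's note: a sketch of how a test might consume this factory fixture (test_even_split is illustrative, not a test from this commit):

from typing import Callable
from shared.types.worker.shards import PipelineShardMetadata

def test_even_split(
    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
) -> None:
    # Two nodes: rank 0 takes the first half of the 16 layers.
    shard = pipeline_shard_meta(2, 0)
    assert shard.start_layer == 0
    assert shard.end_layer == 8
    assert shard.world_size == 2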

View File

@@ -293,10 +293,10 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
try:
weight_map = await get_weight_map(str(shard.model_id))
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
print(f"Error getting weight map for {shard.model_id=}")
print(f"Error getting weight map for {shard.model_meta.model_id=}")
traceback.print_exc()
return ["*"]
@@ -360,27 +360,27 @@ async def download_shard(shard: ShardMetadata,
skip_download: bool = False,
allow_patterns: List[str] | None = None) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
print(f"Downloading {shard.model_id=}")
print(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_id)):
print(f"Using local model path {shard.model_id}")
local_path = Path(str(shard.model_id))
return local_path, await download_progress_for_local_path(str(shard.model_id), shard, local_path)
if await aios.path.exists(str(shard.model_meta.model_id)):
print(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(str(shard.model_meta.model_id), shard, local_path)
revision = "main"
target_dir = await ensure_models_dir()/str(shard.model_id).replace("/", "--")
target_dir = await ensure_models_dir()/str(shard.model_meta.model_id).replace("/", "--")
if not skip_download:
await aios.makedirs(target_dir, exist_ok=True)
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
print(f"Downloading {shard.model_id=} with {allow_patterns=}")
print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories, so this will need to change.
file_list = await fetch_file_list_with_cache(str(shard.model_id), revision, recursive=False)
file_list = await fetch_file_list_with_cache(str(shard.model_meta.model_id), revision, recursive=False)
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x.path))
file_progress: Dict[str, RepoFileDownloadProgress] = {}
def on_progress_wrapper(file: FileListEntry, curr_bytes: int, total_bytes: int):
@@ -389,7 +389,7 @@ async def download_shard(shard: ShardMetadata,
speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=curr_bytes,
@@ -400,11 +400,11 @@ async def download_shard(shard: ShardMetadata,
status="complete" if curr_bytes == total_bytes else "in_progress",
start_time=start_time,
)
on_progress(shard, calculate_repo_progress(shard, str(shard.model_id), revision, file_progress, all_start_time))
on_progress(shard, calculate_repo_progress(shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time))
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir/file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=downloaded_bytes,
@@ -419,10 +419,10 @@ async def download_shard(shard: ShardMetadata,
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file: FileListEntry):
async with semaphore:
await download_file_with_retry(str(shard.model_id), revision, file.path, target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
await download_file_with_retry(str(shard.model_meta.model_id), revision, file.path, target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download:
await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(shard, str(shard.model_id), revision, file_progress, all_start_time)
final_repo_progress = calculate_repo_progress(shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time)
on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
return target_dir/gguf.path, final_repo_progress
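
Editor's note: the asyncio.Semaphore above is the standard recipe for bounding download concurrency. A self-contained sketch, with fetch standing in for the real per-file download call:

import asyncio

async def fetch(path: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for an HTTP download
    return path

async def download_all(paths: list[str], max_parallel: int = 4) -> list[str]:
    semaphore = asyncio.Semaphore(max_parallel)

    async def with_limit(path: str) -> str:
        async with semaphore:  # at most max_parallel downloads in flight
            return await fetch(path)

    return await asyncio.gather(*[with_limit(p) for p in paths])

asyncio.run(download_all([f"model-{i:05d}.safetensors" for i in range(10)]))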

View File

@@ -20,7 +20,7 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
model_meta = await get_model_meta(model_id)
# print(f"build_base_shard {model_id=} {model_meta=}")
return PipelineShardMetadata(
model_id=model_id,
model_meta=model_meta,
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
@@ -34,7 +34,7 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
if base_shard is None:
return None
return PipelineShardMetadata(
model_id=base_shard.model_id,
model_meta=base_shard.model_meta,
partition_strategy=base_shard.partition_strategy,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
@@ -73,13 +73,13 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader.on_progress(callback)
async def ensure_shard(self, shard: ShardMetadata, config_only: bool = False) -> Path:
if (shard.model_id, shard) in self.cache:
if (shard.model_meta.model_id, shard) in self.cache:
# print(f"ensure_shard cache hit {shard=}")
return self.cache[(shard.model_id, shard)]
return self.cache[(shard.model_meta.model_id, shard)]
# print(f"ensure_shard cache miss {shard=}")
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_id, shard)] = target_dir
self.cache[(shard.model_meta.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(self) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
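
Editor's note: a minimal sketch of the memoisation pattern CachedShardDownloader applies. The cache key pairs the model id with the shard itself, which is why PipelineShardMetadata defines __hash__ earlier in this commit; the types below are simplified stand-ins:

from pathlib import Path
from typing import Any

class CachedEnsure:
    """Illustrative wrapper; the real class delegates to a ShardDownloader."""

    def __init__(self, inner: Any) -> None:
        self.inner = inner
        self.cache: dict[tuple[str, Any], Path] = {}

    async def ensure_shard(self, model_id: str, shard: Any) -> Path:
        key = (model_id, shard)  # shard must be hashable (see __hash__ above)
        if key in self.cache:
            return self.cache[key]  # cache hit: skip the download entirely
        path = await self.inner.ensure_shard(shard)
        self.cache[key] = path
        return path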

View File

@@ -3,6 +3,7 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from shared.types.models import ModelMetadata
from shared.types.worker.shards import (
PartitionStrategy,
PipelineShardMetadata,
@@ -11,6 +12,7 @@ from shared.types.worker.shards import (
from worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(self, shard: ShardMetadata, config_only: bool = False) -> Path:
@@ -42,7 +44,12 @@ class ShardDownloader(ABC):
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_id="noop",
model_meta=ModelMetadata(
model_id='noop',
pretty_name='noop',
storage_size_kilobytes=0,
n_layers=1
),
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
@@ -76,7 +83,12 @@ class NoopShardDownloader(ShardDownloader):
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_id="noop",
model_meta=ModelMetadata(
model_id='noop',
pretty_name='noop',
storage_size_kilobytes=0,
n_layers=1
),
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,

View File

@@ -57,7 +57,7 @@ class AssignedRunner(BaseModel):
@property
def is_downloaded(self) -> bool:
# TODO: Do this properly with huggingface validating each of the files.
return os.path.exists(build_model_path(self.shard_metadata.model_id))
return os.path.exists(build_model_path(self.shard_metadata.model_meta.model_id))
def status_update_event(self) -> RunnerStatusUpdated:
return RunnerStatusUpdated(

View File

@@ -185,7 +185,7 @@ class RunnerSupervisor:
yield TokenChunk(
task_id=task.task_id,
idx=token,
model=self.model_shard_meta.model_id,
model=self.model_shard_meta.model_meta.model_id,
chunk_data=TokenChunkData(
text=text,
token_id=token,

View File

@@ -7,7 +7,7 @@ from typing import Callable
import pytest
from shared.types.common import NodeId
from shared.types.models import ModelId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.tasks.common import (
ChatCompletionMessage,
@@ -30,7 +30,18 @@ from worker.main import Worker
@pytest.fixture
def pipeline_shard_meta(tmp_path: Path):
def model_meta() -> ModelMetadata:
# return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')  # we can't do this here, as it's an async function :(
return ModelMetadata(
model_id='mlx-community/Llama-3.2-1B-Instruct-4bit',
pretty_name='llama3.2',
storage_size_kilobytes=10**6,
n_layers=16
)
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
@@ -44,8 +55,8 @@ def pipeline_shard_meta(tmp_path: Path):
)
return PipelineShardMetadata(
model_meta=model_meta,
device_rank=device_rank,
model_id=ModelId(uuid.uuid4()),
n_layers=total_layers,
start_layer=start_layer,
end_layer=end_layer,
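
Editor's note: one way the commented-out call above could be made to work is to wrap the async helper in asyncio.run inside the synchronous fixture. A hedged sketch, assuming no event loop is already running at fixture time (which does not hold inside async test frameworks, hence the inline ModelMetadata fallback the commit uses):

import asyncio

async def get_model_meta_async(model_id: str):
    ...  # stand-in for the real async metadata lookup

def model_meta():
    # asyncio.run spins up a fresh event loop for the single call; it
    # raises if a loop is already running.
    return asyncio.run(get_model_meta_async("mlx-community/Llama-3.2-1B-Instruct-4bit"))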

View File

@@ -1,29 +1,21 @@
import time
from typing import Callable
import pytest
from shared.types.models import ModelId
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.impl_shard_downloader import exo_shard_downloader
from worker.download.shard_downloader import ShardDownloader
@pytest.mark.asyncio
async def test_shard_downloader():
async def test_shard_downloader(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata]):
shard_downloader: ShardDownloader = exo_shard_downloader()
shard_downloader.on_progress(
lambda shard, progress: print(f"Download progress: {progress}")
)
shard_metadata = PipelineShardMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=100,
n_layers=100,
)
shard_metadata = pipeline_shard_meta(1, 0)
path = await shard_downloader.ensure_shard(shard_metadata)
assert path.exists()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Final, List, Optional, Type
from typing import Callable, Final, List, Optional, Type
import pytest
@@ -125,10 +125,12 @@ class RunnerContext:
instance_params: InstanceParams
# TODO: generalize this; it's in conftest.
def _build_worker_state(
*,
tmp_path: Path,
node_id: NodeId,
pipeline_shard_metadata: PipelineShardMetadata,
runner_cases: List[RunnerCase],
) -> tuple[State, List[RunnerContext]]:
"""Construct a WorkerState plus per-runner context objects."""
@@ -145,18 +147,9 @@ def _build_worker_state(
model_subdir = tmp_path / f"runner_{idx}"
model_subdir.mkdir(exist_ok=True)
shard_metadata = PipelineShardMetadata(
device_rank=0,
world_size=1,
model_id=model_id,
start_layer=0,
end_layer=0,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={runner_id: shard_metadata},
runner_to_shard={runner_id: pipeline_shard_metadata},
node_to_runner={node_id: runner_id},
)
@@ -177,7 +170,7 @@ def _build_worker_state(
RunnerContext(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=shard_metadata,
shard_metadata=pipeline_shard_metadata,
instance_params=instance_params,
)
)
@@ -197,7 +190,7 @@ def _build_worker_state(
# Pre-compute readable identifiers for each case to avoid lambda typing issues.
@pytest.mark.parametrize("case", TEST_CASES, ids=[case.id() for case in TEST_CASES])
def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, pipeline_shard_meta: Callable[..., PipelineShardMetadata]) -> None:
"""Exercise Worker.plan across declarative scenarios."""
# Fresh identifier to isolate this node
@@ -207,6 +200,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
worker_state, runner_contexts = _build_worker_state(
tmp_path=tmp_path,
node_id=node_id,
pipeline_shard_metadata=pipeline_shard_meta(1, 0),
runner_cases=case.runners,
)
@@ -234,7 +228,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
)
worker.assigned_runners[ctx.runner_id] = assigned_runner
path_downloaded_map[str(build_model_path(ctx.shard_metadata.model_id))] = runner_case.downloaded
path_downloaded_map[str(build_model_path(ctx.shard_metadata.model_meta.model_id))] = runner_case.downloaded
# Stub filesystem existence check ------------------------------------------------------
from worker import main as worker_main # local import for module-scoped os
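
Editor's note: a self-contained sketch of the stubbing approach this test takes. monkeypatch swaps os.path.exists for a lookup into the prepared path_downloaded_map, so download state is controlled without touching the filesystem; the paths below are illustrative:

import os
import pytest

def test_existence_stub(monkeypatch: pytest.MonkeyPatch) -> None:
    path_downloaded_map = {"/models/a": True, "/models/b": False}

    def fake_exists(path) -> bool:
        # Unprepared paths default to "not downloaded".
        return path_downloaded_map.get(str(path), False)

    monkeypatch.setattr(os.path, "exists", fake_exists)
    assert os.path.exists("/models/a")
    assert not os.path.exists("/models/b")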