mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Fix tests!
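The core of this commit: `BaseShardMetadata` loses its flat `model_id: ModelId` field and instead embeds a `ModelMetadata` under `model_meta`, so every call site migrates from `shard.model_id` to `shard.model_meta.model_id`, and the `ShardMeta` alias is deleted in favor of `ShardMetadata`. A minimal sketch of the resulting shape, for reference while reading the hunks below (field names and sample values taken from this diff; the real pydantic models carry more structure):

    from pydantic import BaseModel

    class ModelMetadata(BaseModel):
        model_id: str
        pretty_name: str
        storage_size_kilobytes: int
        n_layers: int

    class BaseShardMetadata(BaseModel):
        model_meta: ModelMetadata  # replaces the old flat `model_id: ModelId`
        device_rank: int
        world_size: int

    shard = BaseShardMetadata(
        model_meta=ModelMetadata(
            model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
            pretty_name="llama3.2",
            storage_size_kilobytes=10**6,
            n_layers=16,
        ),
        device_rank=0,
        world_size=1,
    )
    print(shard.model_meta.model_id)  # previously: shard.model_id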
@@ -1,28 +1,26 @@
-# type: ignore
 import asyncio
 import concurrent.futures
 import os
 from asyncio import AbstractEventLoop
-from typing import Callable
+from typing import Any, Callable

 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.sample_utils import make_sampler
-from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer
-from mlx_lm.utils import load_model
+from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer  # type: ignore
+from mlx_lm.utils import load_model  # type: ignore
 from pydantic import RootModel

 from engines.mlx.auto_parallel import auto_parallel
 from shared.types.tasks.common import ChatCompletionTaskParams
 from shared.types.worker.mlx import Host
-from shared.types.worker.shards import ShardMeta
+from shared.types.worker.shards import ShardMetadata
+from worker.download.download_utils import build_model_path
 from worker.runner.communication import runner_print


 def mx_barrier():
-    mx.eval(
+    mx.eval(  # type: ignore
         mx.distributed.all_sum(
             mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu))
         )
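A note on the barrier above: a distributed `all_sum` can only return once every rank has contributed its operand, so forcing its evaluation doubles as a rendezvous point. A minimal single-process sketch (my illustration, not part of the commit; with no peers the group has size 1 and the barrier is a no-op):

    import mlx.core as mx

    mx.distributed.init()  # singleton group when launched without peers
    total = mx.distributed.all_sum(
        mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu))
    )
    mx.eval(total)  # blocks here until every rank has contributed
    print(f"ranks participating: {int(total.item())}")  # prints 1 standalone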
@@ -35,7 +33,7 @@ class HostList(RootModel[list[str]]):
         return cls(root=[str(host) for host in hosts])


-def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
+def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:  # type: ignore
     """
     Initialize the MLX distributed group (runs in a thread pool)
     """
@@ -62,7 +60,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:


 def initialize_mlx(
-    model_shard_meta: ShardMeta,
+    model_shard_meta: ShardMetadata,
     hosts: list[Host],
 ) -> tuple[nn.Module, TokenizerWrapper, Callable[[mx.array], mx.array]]:
     """
@@ -71,19 +69,22 @@ def initialize_mlx(
     mx.random.seed(42)
     if len(hosts) > 1:
         mlx_distributed_init(model_shard_meta.device_rank, hosts)
-    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
+    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)  # type: ignore

     model, tokenizer = shard_and_load(model_shard_meta)

     return model, tokenizer, sampler


-def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWrapper]:
-    runner_print(f"loading model from {model_shard_meta.model_path}")
-
-    model, config = load_model(model_shard_meta.model_path, lazy=True, strict=False)
-
-    tokenizer = load_tokenizer(model_shard_meta.model_path)
+def shard_and_load(model_shard_meta: ShardMetadata) -> tuple[nn.Module, TokenizerWrapper]:
+    model_path = build_model_path(model_shard_meta.model_meta.model_id)
+
+    runner_print(f"loading model from {model_path}")
+
+    model, _ = load_model(model_path, lazy=True, strict=False)  # type: ignore
+    assert isinstance(model, nn.Module)
+
+    tokenizer = load_tokenizer(model_path)
+    assert isinstance(tokenizer, TokenizerWrapper)
     model = auto_parallel(model, model_shard_meta)
@@ -107,18 +108,18 @@ async def apply_chat_template(
     # Filter out None values, keeping only 'role' and 'content' keys
     formatted_messages = []
     for message in messages_dicts:
-        filtered_message = {k: v for k, v in message.items() if v is not None}
+        filtered_message: dict[str, Any] = {k: v for k, v in message.items() if v is not None}  # type: ignore
         # Verify we have exactly the expected keys
         assert set(filtered_message.keys()) == {"role", "content"}, (
             f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}"
         )
-        formatted_messages.append(filtered_message)
+        formatted_messages.append(filtered_message)  # type: ignore

     messages_dicts = formatted_messages

     prompt: str = await loop.run_in_executor(
         executor=mlx_executor,
-        func=lambda: tokenizer.apply_chat_template(
+        func=lambda: tokenizer.apply_chat_template(  # type: ignore
             messages_dicts,
             tokenize=False,
             add_generation_prompt=True,
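The filtering in this hunk is a plain dict comprehension; a standalone illustration (hypothetical message fields, not from the codebase):

    message = {"role": "user", "content": "hi", "name": None, "tool_calls": None}
    filtered = {k: v for k, v in message.items() if v is not None}
    assert set(filtered.keys()) == {"role", "content"}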
@@ -7,10 +7,9 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence
 from fastapi.responses import Response, StreamingResponse

 from shared.event_loops.commands import ExternalCommand
-from shared.types.events.components import Apply, EventFromEventLog
-from shared.types.events.registry import Event
+from shared.types.events.components import EventFromEventLog
 from shared.types.state import State
+from shared.types.events.components import Apply


 class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
@@ -8,13 +8,13 @@ from typing import (
 if TYPE_CHECKING:
     pass

-from pydantic import BaseModel, Field, model_validator
+from typing import Callable
+
+from pydantic import BaseModel, Field, model_validator

 from shared.types.common import NodeId
-from shared.types.state import State
+from shared.types.events.registry import Event
+from shared.types.state import State


 class EventFromEventLog[T: Event](BaseModel):
@@ -1,10 +1,10 @@
 from enum import Enum
-from typing import Annotated, Generic, Literal, TypeAlias, TypeVar
+from typing import Annotated, Generic, Literal, TypeVar

 from pydantic import BaseModel, Field, TypeAdapter

 from shared.types.common import NodeId
-from shared.types.models import ModelId
+from shared.types.models import ModelId, ModelMetadata


 class PartitionStrategy(str, Enum):
@@ -20,10 +20,10 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
     Replaces previous `Shard` object.
     """

+    model_meta: ModelMetadata
     partition_strategy: PartitionStrategyT
     device_rank: int
     world_size: int
-    model_id: ModelId


 class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
@@ -47,7 +47,7 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
         return self.end_layer == self.n_layers - 1

     def __hash__(self) -> int:
-        return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
+        return hash((self.model_meta.model_id, self.start_layer, self.end_layer, self.n_layers))


 ShardMetadata = Annotated[
@@ -57,17 +57,6 @@ ShardMetadataParser: TypeAdapter[ShardMetadata] = TypeAdapter(
     ShardMetadata
 )

-# ---------------------------------------------------------------------------
-# Convenience aliases
-# ---------------------------------------------------------------------------
-
-# "ShardMeta" is a widely-used alias for the concrete, fully-parameterised
-# `ShardMetadata` type. Defining it here avoids repetitive generic
-# parameters at call-sites and resolves unknown-import diagnostics in
-# downstream modules.
-
-ShardMeta: TypeAlias = ShardMetadata
-

 class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
     """
worker/download/conftest.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from pathlib import Path
+
+import pytest
+
+from shared.types.models import ModelMetadata
+from shared.types.worker.shards import PipelineShardMetadata
+from worker.download.model_meta import _get_model_meta  # type: ignore
+
+
+@pytest.fixture
+def model_meta() -> ModelMetadata:
+    return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')  # type: ignore
+
+
+@pytest.fixture
+def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path):
+    def _pipeline_shard_meta(
+        num_nodes: int = 1, device_rank: int = 0
+    ) -> PipelineShardMetadata:
+        total_layers = 16
+        layers_per_node = total_layers // num_nodes
+        start_layer = device_rank * layers_per_node
+        end_layer = (
+            start_layer + layers_per_node
+            if device_rank < num_nodes - 1
+            else total_layers
+        )
+
+        return PipelineShardMetadata(
+            model_meta=model_meta,
+            device_rank=device_rank,
+            n_layers=total_layers,
+            start_layer=start_layer,
+            end_layer=end_layer,
+            world_size=num_nodes,
+        )
+
+    return _pipeline_shard_meta
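The fixture's layer split gives each rank `total_layers // num_nodes` layers and lets the last rank absorb any remainder. A worked example of that arithmetic (standalone restatement of the logic above):

    def split(total_layers: int, num_nodes: int, device_rank: int) -> tuple[int, int]:
        layers_per_node = total_layers // num_nodes
        start = device_rank * layers_per_node
        end = start + layers_per_node if device_rank < num_nodes - 1 else total_layers
        return start, end

    assert split(16, 2, 0) == (0, 8)
    assert split(16, 2, 1) == (8, 16)
    assert split(16, 3, 2) == (10, 16)  # last rank picks up the leftover layer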
@@ -293,10 +293,10 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]


 async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
     try:
-        weight_map = await get_weight_map(str(shard.model_id))
+        weight_map = await get_weight_map(str(shard.model_meta.model_id))
         return get_allow_patterns(weight_map, shard)
     except Exception:
-        print(f"Error getting weight map for {shard.model_id=}")
+        print(f"Error getting weight map for {shard.model_meta.model_id=}")
         traceback.print_exc()
         return ["*"]
@@ -360,27 +360,27 @@ async def download_shard(shard: ShardMetadata,
                          skip_download: bool = False,
                          allow_patterns: List[str] | None = None) -> tuple[Path, RepoDownloadProgress]:
     if not skip_download:
-        print(f"Downloading {shard.model_id=}")
+        print(f"Downloading {shard.model_meta.model_id=}")

     # Handle local paths
-    if await aios.path.exists(str(shard.model_id)):
-        print(f"Using local model path {shard.model_id}")
-        local_path = Path(str(shard.model_id))
-        return local_path, await download_progress_for_local_path(str(shard.model_id), shard, local_path)
+    if await aios.path.exists(str(shard.model_meta.model_id)):
+        print(f"Using local model path {shard.model_meta.model_id}")
+        local_path = Path(str(shard.model_meta.model_id))
+        return local_path, await download_progress_for_local_path(str(shard.model_meta.model_id), shard, local_path)

     revision = "main"
-    target_dir = await ensure_models_dir()/str(shard.model_id).replace("/", "--")
+    target_dir = await ensure_models_dir()/str(shard.model_meta.model_id).replace("/", "--")
     if not skip_download:
         await aios.makedirs(target_dir, exist_ok=True)

     if not allow_patterns:
         allow_patterns = await resolve_allow_patterns(shard)

-    print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+    print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")

     all_start_time = time.time()
     # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
-    file_list = await fetch_file_list_with_cache(str(shard.model_id), revision, recursive=False)
+    file_list = await fetch_file_list_with_cache(str(shard.model_meta.model_id), revision, recursive=False)
     filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x.path))
     file_progress: Dict[str, RepoFileDownloadProgress] = {}
     def on_progress_wrapper(file: FileListEntry, curr_bytes: int, total_bytes: int):
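For reference, the `target_dir` above comes from flattening the repo id into a single directory name (model id borrowed from the tests in this commit):

    model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
    print(model_id.replace("/", "--"))  # mlx-community--Llama-3.2-1B-Instruct-4bit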
@@ -389,7 +389,7 @@ async def download_shard(shard: ShardMetadata,
         speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
         eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
         file_progress[file.path] = RepoFileDownloadProgress(
-            repo_id=str(shard.model_id),
+            repo_id=str(shard.model_meta.model_id),
             repo_revision=revision,
             file_path=file.path,
             downloaded=curr_bytes,
@@ -400,11 +400,11 @@ async def download_shard(shard: ShardMetadata,
             status="complete" if curr_bytes == total_bytes else "in_progress",
             start_time=start_time,
         )
-        on_progress(shard, calculate_repo_progress(shard, str(shard.model_id), revision, file_progress, all_start_time))
+        on_progress(shard, calculate_repo_progress(shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time))
     for file in filtered_file_list:
         downloaded_bytes = await get_downloaded_size(target_dir/file.path)
         file_progress[file.path] = RepoFileDownloadProgress(
-            repo_id=str(shard.model_id),
+            repo_id=str(shard.model_meta.model_id),
             repo_revision=revision,
             file_path=file.path,
             downloaded=downloaded_bytes,
@@ -419,10 +419,10 @@ async def download_shard(shard: ShardMetadata,
     semaphore = asyncio.Semaphore(max_parallel_downloads)
     async def download_with_semaphore(file: FileListEntry):
         async with semaphore:
-            await download_file_with_retry(str(shard.model_id), revision, file.path, target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+            await download_file_with_retry(str(shard.model_meta.model_id), revision, file.path, target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
     if not skip_download:
         await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
-    final_repo_progress = calculate_repo_progress(shard, str(shard.model_id), revision, file_progress, all_start_time)
+    final_repo_progress = calculate_repo_progress(shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time)
     on_progress(shard, final_repo_progress)
     if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
         return target_dir/gguf.path, final_repo_progress
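The download loop above uses the standard semaphore-plus-gather pattern: every coroutine is scheduled at once, but at most `max_parallel_downloads` of them hold the semaphore at any moment. A self-contained sketch of the pattern (a sleep stands in for the real download):

    import asyncio

    async def main() -> None:
        semaphore = asyncio.Semaphore(2)  # stand-in for max_parallel_downloads

        async def fetch(i: int) -> None:
            async with semaphore:  # at most 2 "downloads" in flight
                await asyncio.sleep(0.1)  # stands in for download_file_with_retry
                print(f"finished file {i}")

        await asyncio.gather(*[fetch(i) for i in range(5)])

    asyncio.run(main())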
@@ -20,7 +20,7 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
     model_meta = await get_model_meta(model_id)
     # print(f"build_base_shard {model_id=} {model_meta=}")
     return PipelineShardMetadata(
-        model_id=model_id,
+        model_meta=model_meta,
         partition_strategy=PartitionStrategy.pipeline,
         device_rank=0,
         world_size=1,
@@ -34,7 +34,7 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
     if base_shard is None:
         return None
     return PipelineShardMetadata(
-        model_id=base_shard.model_id,
+        model_meta=base_shard.model_meta,
         partition_strategy=base_shard.partition_strategy,
         device_rank=base_shard.device_rank,
         world_size=base_shard.world_size,
@@ -73,13 +73,13 @@ class CachedShardDownloader(ShardDownloader):
         self.shard_downloader.on_progress(callback)

     async def ensure_shard(self, shard: ShardMetadata, config_only: bool = False) -> Path:
-        if (shard.model_id, shard) in self.cache:
+        if (shard.model_meta.model_id, shard) in self.cache:
             # print(f"ensure_shard cache hit {shard=}")
-            return self.cache[(shard.model_id, shard)]
+            return self.cache[(shard.model_meta.model_id, shard)]

         # print(f"ensure_shard cache miss {shard=}")
         target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
-        self.cache[(shard.model_id, shard)] = target_dir
+        self.cache[(shard.model_meta.model_id, shard)] = target_dir
         return target_dir

     async def get_shard_download_status(self) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
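The cache keyed on `(model_meta.model_id, shard)` is why `PipelineShardMetadata.__hash__` was updated earlier in this diff: the shard itself is part of the dict key. A stripped-down sketch of the memoization (stand-in types, not the real class):

    cache: dict[tuple[str, int], str] = {}

    def ensure(model_id: str, device_rank: int) -> str:
        key = (model_id, device_rank)  # real key: (shard.model_meta.model_id, shard)
        if key not in cache:
            cache[key] = f"/models/{model_id.replace('/', '--')}"  # the "download"
        return cache[key]

    assert ensure("org/model", 0) == ensure("org/model", 0)  # second call is a hit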
@@ -3,6 +3,7 @@ from datetime import timedelta
 from pathlib import Path
 from typing import AsyncIterator, Callable

+from shared.types.models import ModelMetadata
 from shared.types.worker.shards import (
     PartitionStrategy,
     PipelineShardMetadata,
@@ -11,6 +12,7 @@ from shared.types.worker.shards import (
 from worker.download.download_utils import RepoDownloadProgress


+# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
 class ShardDownloader(ABC):
     @abstractmethod
     async def ensure_shard(self, shard: ShardMetadata, config_only: bool = False) -> Path:
@@ -42,7 +44,12 @@ class ShardDownloader(ABC):
             repo_id="noop",
             repo_revision="noop",
             shard=PipelineShardMetadata(
-                model_id="noop",
+                model_meta=ModelMetadata(
+                    model_id='noop',
+                    pretty_name='noop',
+                    storage_size_kilobytes=0,
+                    n_layers=1
+                ),
                 partition_strategy=PartitionStrategy.pipeline,
                 device_rank=0,
                 world_size=1,
@@ -76,7 +83,12 @@ class NoopShardDownloader(ShardDownloader):
             repo_id="noop",
             repo_revision="noop",
             shard=PipelineShardMetadata(
-                model_id="noop",
+                model_meta=ModelMetadata(
+                    model_id='noop',
+                    pretty_name='noop',
+                    storage_size_kilobytes=0,
+                    n_layers=1
+                ),
                 partition_strategy=PartitionStrategy.pipeline,
                 device_rank=0,
                 world_size=1,
@@ -57,7 +57,7 @@ class AssignedRunner(BaseModel):
     @property
     def is_downloaded(self) -> bool:
         # TODO: Do this properly by having huggingface validate each of the files.
-        return os.path.exists(build_model_path(self.shard_metadata.model_id))
+        return os.path.exists(build_model_path(self.shard_metadata.model_meta.model_id))

     def status_update_event(self) -> RunnerStatusUpdated:
         return RunnerStatusUpdated(
@@ -185,7 +185,7 @@ class RunnerSupervisor:
             yield TokenChunk(
                 task_id=task.task_id,
                 idx=token,
-                model=self.model_shard_meta.model_id,
+                model=self.model_shard_meta.model_meta.model_id,
                 chunk_data=TokenChunkData(
                     text=text,
                     token_id=token,
@@ -7,7 +7,7 @@ from typing import Callable
 import pytest

 from shared.types.common import NodeId
-from shared.types.models import ModelId
+from shared.types.models import ModelId, ModelMetadata
 from shared.types.state import State
 from shared.types.tasks.common import (
     ChatCompletionMessage,
@@ -30,7 +30,18 @@ from worker.main import Worker


 @pytest.fixture
-def pipeline_shard_meta(tmp_path: Path):
+def model_meta() -> ModelMetadata:
+    # return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')  # we can't do this, as it's an async function :(
+    return ModelMetadata(
+        model_id='mlx-community/Llama-3.2-1B-Instruct-4bit',
+        pretty_name='llama3.2',
+        storage_size_kilobytes=10**6,
+        n_layers=16
+    )
+
+
+@pytest.fixture
+def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]:
     def _pipeline_shard_meta(
         num_nodes: int = 1, device_rank: int = 0
     ) -> PipelineShardMetadata:
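On the workaround in the comment above ("we can't do this, as it's an async function"): if pytest-asyncio is available (an assumption; this diff doesn't confirm it), an async fixture could await the helper directly instead of hand-building `ModelMetadata`. A hedged sketch:

    import pytest_asyncio  # assumes pytest-asyncio is a dependency

    from shared.types.models import ModelMetadata
    from worker.download.model_meta import get_model_meta  # assumed async variant

    @pytest_asyncio.fixture
    async def model_meta() -> ModelMetadata:
        # Await the async helper rather than duplicating its fields by hand.
        return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')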
@@ -44,8 +55,8 @@ def pipeline_shard_meta(tmp_path: Path):
         )

         return PipelineShardMetadata(
+            model_meta=model_meta,
             device_rank=device_rank,
-            model_id=ModelId(uuid.uuid4()),
             n_layers=total_layers,
             start_layer=start_layer,
             end_layer=end_layer,
@@ -1,29 +1,21 @@
 import time
+from typing import Callable

 import pytest

-from shared.types.models import ModelId
-from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
+from shared.types.worker.shards import PipelineShardMetadata
 from worker.download.impl_shard_downloader import exo_shard_downloader
 from worker.download.shard_downloader import ShardDownloader


 @pytest.mark.asyncio
-async def test_shard_downloader():
+async def test_shard_downloader(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata]):
     shard_downloader: ShardDownloader = exo_shard_downloader()
     shard_downloader.on_progress(
         lambda shard, progress: print(f"Download progress: {progress}")
     )

-    shard_metadata = PipelineShardMetadata(
-        model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
-        partition_strategy=PartitionStrategy.pipeline,
-        device_rank=0,
-        world_size=1,
-        start_layer=0,
-        end_layer=100,
-        n_layers=100,
-    )
+    shard_metadata = pipeline_shard_meta(1, 0)
     path = await shard_downloader.ensure_shard(shard_metadata)
     assert path.exists()
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Final, List, Optional, Type
+from typing import Callable, Final, List, Optional, Type

 import pytest
@@ -125,10 +125,12 @@ class RunnerContext:
     instance_params: InstanceParams


+# TODO: generalize this; it's in conftest.
 def _build_worker_state(
     *,
     tmp_path: Path,
     node_id: NodeId,
+    pipeline_shard_metadata: PipelineShardMetadata,
     runner_cases: List[RunnerCase],
 ) -> tuple[State, List[RunnerContext]]:
     """Construct a WorkerState plus per-runner context objects."""
@@ -145,18 +147,9 @@ def _build_worker_state(
         model_subdir = tmp_path / f"runner_{idx}"
         model_subdir.mkdir(exist_ok=True)

-        shard_metadata = PipelineShardMetadata(
-            device_rank=0,
-            world_size=1,
-            model_id=model_id,
-            start_layer=0,
-            end_layer=0,
-            n_layers=1,
-        )
-
         shard_assignments = ShardAssignments(
             model_id=model_id,
-            runner_to_shard={runner_id: shard_metadata},
+            runner_to_shard={runner_id: pipeline_shard_metadata},
             node_to_runner={node_id: runner_id},
         )
@@ -177,7 +170,7 @@ def _build_worker_state(
             RunnerContext(
                 runner_id=runner_id,
                 instance_id=instance_id,
-                shard_metadata=shard_metadata,
+                shard_metadata=pipeline_shard_metadata,
                 instance_params=instance_params,
             )
         )
@@ -197,7 +190,7 @@ def _build_worker_state(

 # Pre-compute readable identifiers for each case to avoid lambda typing issues.
 @pytest.mark.parametrize("case", TEST_CASES, ids=[case.id() for case in TEST_CASES])
-def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, pipeline_shard_meta: Callable[..., PipelineShardMetadata]) -> None:
     """Exercise Worker.plan across declarative scenarios."""

     # Fresh identifier for node isolation
@@ -207,6 +200,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
     worker_state, runner_contexts = _build_worker_state(
         tmp_path=tmp_path,
         node_id=node_id,
+        pipeline_shard_metadata=pipeline_shard_meta(1, 0),
         runner_cases=case.runners,
     )
@@ -234,7 +228,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
         )
         worker.assigned_runners[ctx.runner_id] = assigned_runner

-        path_downloaded_map[str(build_model_path(ctx.shard_metadata.model_id))] = runner_case.downloaded
+        path_downloaded_map[str(build_model_path(ctx.shard_metadata.model_meta.model_id))] = runner_case.downloaded

         # Stub filesystem existence check ------------------------------------------------------
         from worker import main as worker_main  # local import for module-scoped os