Mirror of https://github.com/exo-explore/exo.git
Synced 2025-12-23 22:27:50 -05:00
Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
Co-authored-by: Seth Howes <71157822+sethhowes@users.noreply.github.com>
Co-authored-by: Matt Beton <matthew.beton@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
36 lines · 1.0 KiB · Python
from collections.abc import Callable

import pytest

from shared.models.model_meta import get_model_meta
from shared.types.models import ModelMetadata
from shared.types.worker.shards import PipelineShardMetadata
|
|
|
|
|
|
@pytest.fixture
async def model_meta() -> ModelMetadata:
    """Provide the ``ModelMetadata`` shared by the tests in this module.

    The metadata is obtained by awaiting ``get_model_meta`` for the model id
    ``'mlx-community/Llama-3.2-1B-Instruct-4bit'``.
    """
    # NOTE(review): an ``async def`` fixture under a plain ``@pytest.fixture``
    # presumably relies on pytest-asyncio running in "auto" mode; in "strict"
    # mode it would need ``@pytest_asyncio.fixture`` — confirm project config.
    repo_id = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
    metadata = await get_model_meta(repo_id)
    return metadata
|
|
|
|
|
|
@pytest.fixture
def pipeline_shard_meta(
    model_meta: ModelMetadata,
) -> Callable[..., PipelineShardMetadata]:
    """Return a factory that builds ``PipelineShardMetadata`` for ``model_meta``.

    The factory splits ``total_layers`` layers into ``num_nodes`` contiguous
    ranges of ``total_layers // num_nodes`` layers. Each rank's ``end_layer``
    equals the next rank's ``start_layer``. The last rank's range is extended
    to ``total_layers``, so it absorbs any remainder and every layer is covered.
    For example, 16 layers over 3 nodes yields ranks 0, 1 and 2 with
    (start_layer, end_layer) of (0, 5), (5, 10) and (10, 16).
    """

    def _pipeline_shard_meta(
        num_nodes: int = 1,
        device_rank: int = 0,
        *,
        total_layers: int = 16,
    ) -> PipelineShardMetadata:
        """Build the shard for ``device_rank`` in a ``num_nodes``-way pipeline.

        Args:
            num_nodes: Number of pipeline stages; stored as ``world_size``.
            device_rank: Zero-based rank of the stage to build.
            total_layers: Layer count to split across the stages (keyword-only).
                The default of 16 presumably matches the model returned by the
                ``model_meta`` fixture — confirm if that model changes.

        Returns:
            The ``PipelineShardMetadata`` for ``device_rank``.

        Raises:
            ValueError: If ``num_nodes`` is less than 1, or ``device_rank`` is
                not in ``[0, num_nodes)``. Without these checks, the first
                case is an opaque ``ZeroDivisionError`` and the second silently
                yields an empty or inverted layer range.
        """
        if num_nodes < 1:
            raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")
        if not 0 <= device_rank < num_nodes:
            raise ValueError(
                f"device_rank must be in [0, {num_nodes}), got {device_rank}"
            )

        layers_per_node = total_layers // num_nodes
        start_layer = device_rank * layers_per_node
        # The final rank runs to the end so that the remainder of the integer
        # division is not dropped.
        is_last_rank = device_rank == num_nodes - 1
        end_layer = total_layers if is_last_rank else start_layer + layers_per_node

        return PipelineShardMetadata(
            model_meta=model_meta,
            device_rank=device_rank,
            n_layers=total_layers,
            start_layer=start_layer,
            end_layer=end_layer,
            world_size=num_nodes,
        )

    return _pipeline_shard_meta