Files
exo/master/tests/test_master.py
Alex Cheema 92c9688bf0 Remove rust
2025-08-02 08:16:39 -07:00

176 lines
6.5 KiB
Python

import asyncio
import tempfile
from logging import Logger
from pathlib import Path
from typing import List, Sequence
import pytest
from master.main import Master
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import NodeId
from shared.types.events import Event, EventFromEventLog, Heartbeat, TaskCreated
from shared.types.events._events import (
InstanceCreated,
NodePerformanceMeasured,
TopologyNodeCreated,
)
from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CommandId,
CreateInstanceCommand,
)
from shared.types.models import ModelMetadata
from shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
from shared.utils import Keypair
def _create_forwarder_dummy_binary() -> Path:
path = Path(tempfile.mktemp()) / "forwarder.bin"
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
path.chmod(0o755)
return path
@pytest.mark.asyncio
async def test_master():
logger = Logger(name='test_master_logger')
event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
await global_events.delete_all_events()
async def _get_events() -> Sequence[EventFromEventLog[Event]]:
orig_events = await global_events.get_events_since(0)
override_idx_in_log = 1
events: List[EventFromEventLog[Event]] = []
for e in orig_events:
if isinstance(e.event, Heartbeat):
continue
events.append(EventFromEventLog(
event=e.event,
origin=e.origin,
idx_in_log=override_idx_in_log
))
override_idx_in_log += 1
return events
command_buffer: List[Command] = []
forwarder_binary_path = _create_forwarder_dummy_binary()
node_id_keypair = Keypair.generate_ed25519()
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
master = Master(node_id_keypair, node_id, command_buffer=command_buffer, global_events=global_events,
forwarder_binary_path=forwarder_binary_path, logger=logger, worker_events=global_events)
asyncio.create_task(master.run())
# wait for initial topology event
while len(list(master.state.topology.list_nodes())) == 0:
print("waiting")
await asyncio.sleep(0.001)
# inject a NodePerformanceProfile event
await event_log_manager.global_events.append_events([
NodePerformanceMeasured(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(ram_total=678948*1024, ram_available=678948*1024, swap_total=0, swap_available=0),
network_interfaces=[],
system=SystemPerformanceProfile(flops_fp16=0)
)
)
], origin=node_id)
while len(master.state.node_profiles) == 0:
await asyncio.sleep(0.001)
command_buffer.append(CreateInstanceCommand(
command_id=CommandId(),
instance_id=InstanceId(),
model_meta=ModelMetadata(
model_id="llama-3.2-1b",
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size_kilobytes=678948
)
))
while len(master.state.instances.keys()) == 0:
await asyncio.sleep(0.001)
command_buffer.append(
ChatCompletionCommand(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
)
)
)
while len(await _get_events()) < 4:
await asyncio.sleep(0.001)
events = await _get_events()
print(events)
assert len(events) == 4
assert events[0].idx_in_log == 1
assert isinstance(events[0].event, TopologyNodeCreated)
assert isinstance(events[1].event, NodePerformanceMeasured)
assert isinstance(events[2].event, InstanceCreated)
runner_id = list(events[2].event.instance.shard_assignments.runner_to_shard.keys())[0]
assert events[2].event == InstanceCreated(
instance=Instance(
instance_id=events[2].event.instance.instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=ShardAssignments(
model_id="llama-3.2-1b",
runner_to_shard={
(runner_id): PipelineShardMetadata(
partition_strategy=PartitionStrategy.pipeline,
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id="llama-3.2-1b",
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size_kilobytes=678948
),
device_rank=0,
world_size=1
)
},
node_to_runner={node_id: runner_id}
),
hosts=[]
)
)
assert isinstance(events[3].event, TaskCreated)
assert events[3].event == TaskCreated(
task_id=events[3].event.task_id,
task=ChatCompletionTask(
task_id=events[3].event.task_id,
command_id=events[3].event.task.command_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=events[3].event.task.instance_id,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
)
)
)
assert len(command_buffer) == 0