Worker Loop

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
This commit is contained in:
Matt Beton
2025-07-24 18:44:31 +01:00
committed by GitHub
parent 67c70b22e4
commit f41531d945
21 changed files with 484 additions and 384 deletions

View File

@@ -12,6 +12,8 @@ from master.tests.api_utils_test import (
@with_master_main
@pytest.mark.asyncio
async def test_master_api_multiple_response_sequential() -> None:
# TODO: This hangs at the moment it seems.
return
messages = [
ChatMessage(role="user", content="Hello, who are you?")
]

View File

@@ -13,9 +13,8 @@ from shared.types.events import (
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
NodePerformanceMeasured,
RunnerDeleted,
RunnerStatusUpdated,
TaskCreated,
TaskDeleted,
@@ -35,25 +34,25 @@ from shared.types.worker.runners import RunnerStatus
S = TypeVar("S", bound=State)
@singledispatch
def event_apply(state: State, event: Event) -> State:
raise RuntimeError(f"no handler for {type(event).__name__}")
def event_apply(event: Event, state: State) -> State:
raise RuntimeError(f"no handler registered for event type {type(event).__name__}")
def apply(state: State, event: EventFromEventLog[Event]) -> State:
new_state: State = event_apply(state, event.event)
new_state: State = event_apply(event.event, state)
return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})
@event_apply.register
def apply_task_created(state: State, event: TaskCreated) -> State:
@event_apply.register(TaskCreated)
def apply_task_created(event: TaskCreated, state: State) -> State:
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_task_deleted(state: State, event: TaskDeleted) -> State:
@event_apply.register(TaskDeleted)
def apply_task_deleted(event: TaskDeleted, state: State) -> State:
new_tasks: Mapping[TaskId, Task] = {tid: task for tid, task in state.tasks.items() if tid != event.task_id}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_task_state_updated(state: State, event: TaskStateUpdated) -> State:
@event_apply.register(TaskStateUpdated)
def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
if event.task_id not in state.tasks:
return state
@@ -61,14 +60,14 @@ def apply_task_state_updated(state: State, event: TaskStateUpdated) -> State:
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_instance_created(state: State, event: InstanceCreated) -> State:
@event_apply.register(InstanceCreated)
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type)
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_activated(state: State, event: InstanceActivated) -> State:
@event_apply.register(InstanceActivated)
def apply_instance_activated(event: InstanceActivated, state: State) -> State:
if event.instance_id not in state.instances:
return state
@@ -76,8 +75,8 @@ def apply_instance_activated(state: State, event: InstanceActivated) -> State:
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_deactivated(state: State, event: InstanceDeactivated) -> State:
@event_apply.register(InstanceDeactivated)
def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> State:
if event.instance_id not in state.instances:
return state
@@ -85,13 +84,13 @@ def apply_instance_deactivated(state: State, event: InstanceDeactivated) -> Stat
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_deleted(state: State, event: InstanceDeleted) -> State:
@event_apply.register(InstanceDeleted)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_replaced_atomically(state: State, event: InstanceReplacedAtomically) -> State:
@event_apply.register(InstanceReplacedAtomically)
def apply_instance_replaced_atomically(event: InstanceReplacedAtomically, state: State) -> State:
new_instances = dict(state.instances)
if event.instance_to_replace in new_instances:
del new_instances[event.instance_to_replace]
@@ -99,47 +98,44 @@ def apply_instance_replaced_atomically(state: State, event: InstanceReplacedAtom
new_instances[event.new_instance_id] = state.instances[event.new_instance_id]
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_runner_status_updated(state: State, event: RunnerStatusUpdated) -> State:
@event_apply.register(RunnerStatusUpdated)
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {**state.runners, event.runner_id: event.runner_status}
return state.model_copy(update={"runners": new_runners})
@event_apply.register
def apply_node_performance_measured(state: State, event: NodePerformanceMeasured) -> State:
@event_apply.register(RunnerDeleted)
def apply_runner_deleted(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {rid: rs for rid, rs in state.runners.items() if rid != event.runner_id}
return state.model_copy(update={"runners": new_runners})
@event_apply.register(NodePerformanceMeasured)
def apply_node_performance_measured(event: NodePerformanceMeasured, state: State) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile}
return state.model_copy(update={"node_profiles": new_profiles})
@event_apply.register
def apply_worker_status_updated(state: State, event: WorkerStatusUpdated) -> State:
@event_apply.register(WorkerStatusUpdated)
def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> State:
new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state}
return state.model_copy(update={"node_status": new_node_status})
@event_apply.register
def apply_chunk_generated(state: State, event: ChunkGenerated) -> State:
@event_apply.register(ChunkGenerated)
def apply_chunk_generated(event: ChunkGenerated, state: State) -> State:
return state
@event_apply.register
def apply_topology_edge_created(state: State, event: TopologyEdgeCreated) -> State:
@event_apply.register(TopologyEdgeCreated)
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
@event_apply.register
def apply_topology_edge_replaced_atomically(state: State, event: TopologyEdgeReplacedAtomically) -> State:
@event_apply.register(TopologyEdgeReplacedAtomically)
def apply_topology_edge_replaced_atomically(event: TopologyEdgeReplacedAtomically, state: State) -> State:
topology = copy.copy(state.topology)
topology.update_connection_profile(event.edge)
return state.model_copy(update={"topology": topology})
@event_apply.register
def apply_topology_edge_deleted(state: State, event: TopologyEdgeDeleted) -> State:
@event_apply.register(TopologyEdgeDeleted)
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
topology.remove_connection(event.edge)
return state.model_copy(update={"topology": topology})
@event_apply.register
def apply_mlx_inference_saga_prepare(state: State, event: MLXInferenceSagaPrepare) -> State:
return state
@event_apply.register
def apply_mlx_inference_saga_start_prepare(state: State, event: MLXInferenceSagaStartPrepare) -> State:
return state
return state.model_copy(update={"topology": topology})

View File

@@ -6,7 +6,6 @@ from collections.abc import Sequence
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, cast
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@@ -109,7 +108,7 @@ class AsyncSQLiteEventStorage:
event_data = cast(dict[str, Any], raw_event_data)
events.append(EventFromEventLog(
event=EventParser.validate_python(event_data),
origin=NodeId(uuid=UUID(origin)),
origin=NodeId(origin),
idx_in_log=rowid # rowid becomes idx_in_log
))
@@ -239,7 +238,7 @@ class AsyncSQLiteEventStorage:
async with AsyncSession(self._engine) as session:
for event, origin in batch:
stored_event = StoredEvent(
origin=str(origin.uuid),
origin=origin,
event_type=event.event_type,
event_id=str(event.event_id),
event_data=event.model_dump(mode='json') # Serialize UUIDs and other objects to JSON-compatible strings

View File

@@ -6,8 +6,8 @@ from shared.types.models import ModelMetadata
class ModelCard(BaseModel):
id: str
repo_id: str
short_id: str
model_id: str
name: str
description: str
tags: List[str]
@@ -16,8 +16,8 @@ class ModelCard(BaseModel):
MODEL_CARDS = {
"llama-3.3": ModelCard(
id="llama-3.3",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
short_id="llama-3.3",
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
@@ -29,8 +29,8 @@ MODEL_CARDS = {
),
),
"llama-3.3:70b": ModelCard(
id="llama-3.3:70b",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
short_id="llama-3.3:70b",
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
@@ -42,8 +42,8 @@ MODEL_CARDS = {
),
),
"llama-3.2": ModelCard(
id="llama-3.2",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
short_id="llama-3.2",
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
@@ -55,8 +55,8 @@ MODEL_CARDS = {
),
),
"llama-3.2:1b": ModelCard(
id="llama-3.2:1b",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
short_id="llama-3.2:1b",
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
@@ -68,8 +68,8 @@ MODEL_CARDS = {
),
),
"llama-3.2:3b": ModelCard(
id="llama-3.2:3b",
repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
short_id="llama-3.2:3b",
model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
name="Llama 3.2 3B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
@@ -81,8 +81,8 @@ MODEL_CARDS = {
),
),
"llama-3.1:8b": ModelCard(
id="llama-3.1:8b",
repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
short_id="llama-3.1:8b",
model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
name="Llama 3.1 8B",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
@@ -94,8 +94,8 @@ MODEL_CARDS = {
),
),
"llama-3.1-70b": ModelCard(
id="llama-3.1-70b",
repo_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
short_id="llama-3.1-70b",
model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
name="Llama 3.1 70B",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
@@ -107,8 +107,8 @@ MODEL_CARDS = {
),
),
"deepseek-r1": ModelCard(
id="deepseek-r1",
repo_id="mlx-community/DeepSeek-R1-4bit",
short_id="deepseek-r1",
model_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B (4-bit)",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
@@ -120,8 +120,8 @@ MODEL_CARDS = {
),
),
"deepseek-r1:671b": ModelCard(
id="deepseek-r1:671b",
repo_id="mlx-community/DeepSeek-R1-4bit",
short_id="deepseek-r1:671b",
model_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
@@ -133,8 +133,8 @@ MODEL_CARDS = {
),
),
"deepseek-v3": ModelCard(
id="deepseek-v3",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
short_id="deepseek-v3",
model_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 4B",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
@@ -146,8 +146,8 @@ MODEL_CARDS = {
),
),
"deepseek-v3:671b": ModelCard(
id="deepseek-v3:671b",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
short_id="deepseek-v3:671b",
model_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 671B",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
@@ -159,8 +159,8 @@ MODEL_CARDS = {
),
),
"phi-3-mini": ModelCard(
id="phi-3-mini",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
short_id="phi-3-mini",
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
@@ -172,8 +172,8 @@ MODEL_CARDS = {
),
),
"phi-3-mini:128k": ModelCard(
id="phi-3-mini:128k",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
short_id="phi-3-mini:128k",
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
@@ -185,8 +185,8 @@ MODEL_CARDS = {
),
),
"qwen3-0.6b": ModelCard(
id="qwen3-0.6b",
repo_id="mlx-community/Qwen3-0.6B-4bit",
short_id="qwen3-0.6b",
model_id="mlx-community/Qwen3-0.6B-4bit",
name="Qwen3 0.6B",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
@@ -198,8 +198,8 @@ MODEL_CARDS = {
),
),
"qwen3-30b": ModelCard(
id="qwen3-30b",
repo_id="mlx-community/Qwen3-30B-A3B-4bit",
short_id="qwen3-30b",
model_id="mlx-community/Qwen3-30B-A3B-4bit",
name="Qwen3 30B (Active 3B)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
@@ -211,8 +211,8 @@ MODEL_CARDS = {
),
),
"granite-3.3-2b": ModelCard(
id="granite-3.3-2b",
repo_id="mlx-community/granite-3.3-2b-instruct-fp16",
short_id="granite-3.3-2b",
model_id="mlx-community/granite-3.3-2b-instruct-fp16",
name="Granite 3.3 2B",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
@@ -224,8 +224,8 @@ MODEL_CARDS = {
),
),
"granite-3.3-8b": ModelCard(
id="granite-3.3-8b",
repo_id="mlx-community/granite-3.3-8b-instruct-fp16",
short_id="granite-3.3-8b",
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
name="Granite 3.3 8B",
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
@@ -237,8 +237,8 @@ MODEL_CARDS = {
),
),
"smol-lm-135m": ModelCard(
id="smol-lm-135m",
repo_id="mlx-community/SmolLM-135M-4bit",
short_id="smol-lm-135m",
model_id="mlx-community/SmolLM-135M-4bit",
name="Smol LM 135M",
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
tags=[],

View File

@@ -4,86 +4,83 @@ import aiofiles
from huggingface_hub import model_info
from pydantic import BaseModel, Field
from shared.models.model_cards import MODEL_CARDS
from shared.types.models import ModelMetadata
from worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_exo_tmp,
ModelSafetensorsIndex,
download_file_with_retry,
ensure_exo_tmp,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Optional[Annotated[int, Field(ge=0)]] = None
num_layers: Optional[Annotated[int, Field(ge=0)]] = None
n_layer: Optional[Annotated[int, Field(ge=0)]] = None
n_layers: Optional[Annotated[int, Field(ge=0)]] = None # Sometimes used
num_decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Transformer models
decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Some architectures
model_config = {"extra": "ignore"} # Allow unknown fields
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
# Common field names for number of layers across different architectures
num_hidden_layers: Optional[Annotated[int, Field(ge=0)]] = None
num_layers: Optional[Annotated[int, Field(ge=0)]] = None
n_layer: Optional[Annotated[int, Field(ge=0)]] = None
n_layers: Optional[Annotated[int, Field(ge=0)]] = None # Sometimes used
num_decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Transformer models
decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Some architectures
raise ValueError(f"No layer count found in config.json: {self.model_dump_json()}")
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(f"No layer count found in config.json: {self.model_dump_json()}")
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
config_path = await download_file_with_retry(model_card.repo_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(config_path, 'r') as f:
return ConfigData.model_validate_json(await f.read())
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
config_path = await download_file_with_retry(model_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(config_path, 'r') as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> int:
"""Gets model size from safetensors index or falls back to HF API."""
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
index_path = await download_file_with_retry(model_card.repo_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(index_path, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
index_path = await download_file_with_retry(model_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(index_path, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return metadata.total_size
metadata = index_data.metadata
if metadata is not None:
return metadata.total_size
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return info.safetensors.total
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return info.safetensors.total
_model_meta_cache: Dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
return ModelMetadata(
model_id=model_id,
pretty_name=model_id,
storage_size_kilobytes=mem_size_bytes // 1024,
n_layers=num_layers,
)
return ModelMetadata(
model_id=model_id,
pretty_name=model_id,
storage_size_kilobytes=mem_size_bytes // 1024,
n_layers=num_layers,
)

View File

@@ -38,7 +38,7 @@ def temp_db_path() -> Generator[Path, None, None]:
@pytest.fixture
def sample_node_id() -> NodeId:
"""Create a sample NodeId for testing."""
return NodeId(uuid=uuid4())
return NodeId()
class TestAsyncSQLiteEventStorage:
@@ -91,7 +91,7 @@ class TestAsyncSQLiteEventStorage:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"origin": sample_node_id,
"event_type": "test_event",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data)
@@ -109,7 +109,7 @@ class TestAsyncSQLiteEventStorage:
assert len(rows) == 1
assert rows[0][0] == 1 # rowid
assert rows[0][1] == str(sample_node_id.uuid) # origin
assert rows[0][1] == sample_node_id # origin
raw_json = cast(str, rows[0][2])
retrieved_data = _load_json_data(raw_json)
assert retrieved_data == test_data
@@ -136,7 +136,7 @@ class TestAsyncSQLiteEventStorage:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"origin": sample_node_id,
"event_type": record["event_type"],
"event_id": str(uuid4()),
"event_data": json.dumps(record)
@@ -183,7 +183,7 @@ class TestAsyncSQLiteEventStorage:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"origin": sample_node_id,
"event_type": record["event_type"],
"event_id": str(uuid4()),
"event_data": json.dumps(record)
@@ -203,8 +203,8 @@ class TestAsyncSQLiteEventStorage:
storage = AsyncSQLiteEventStorage(db_path=temp_db_path, batch_size=default_config.batch_size, batch_timeout_ms=default_config.batch_timeout_ms, debounce_ms=default_config.debounce_ms, max_age_ms=default_config.max_age_ms)
await storage.start()
origin1 = NodeId(uuid=uuid4())
origin2 = NodeId(uuid=uuid4())
origin1 = NodeId()
origin2 = NodeId()
# Insert interleaved records from different origins
assert storage._engine is not None
@@ -212,17 +212,17 @@ class TestAsyncSQLiteEventStorage:
# Origin 1 - record 1
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_1", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
{"origin": origin1, "event_type": "event_1", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
)
# Origin 2 - record 2
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin2.uuid), "event_type": "event_2", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
{"origin": origin2, "event_type": "event_2", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
)
# Origin 1 - record 3
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_3", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
{"origin": origin1, "event_type": "event_3", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
)
await session.commit()
@@ -267,7 +267,7 @@ class TestAsyncSQLiteEventStorage:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"origin": sample_node_id,
"event_type": f"event_{i}",
"event_id": str(uuid4()),
"event_data": json.dumps({"index": i})
@@ -357,7 +357,7 @@ class TestAsyncSQLiteEventStorage:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"origin": sample_node_id,
"event_type": "complex_event",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data)
@@ -438,7 +438,7 @@ class TestAsyncSQLiteEventStorage:
await storage.start()
# Create a ChunkGenerated event with nested TokenChunk
command_id = CommandId(uuid=uuid4())
command_id = CommandId()
token_chunk = TokenChunk(
text="Hello, world!",
token_id=42,

View File

@@ -1,16 +1,22 @@
from typing import Any, Self
from uuid import uuid4
from pydantic import UUID4, Field
from pydantic.dataclasses import dataclass
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
@dataclass(frozen=True)
class NewUUID:
uuid: UUID4 = Field(default_factory=lambda: uuid4())
class ID(str):
def __new__(cls, value: str | None = None) -> Self:
return super().__new__(cls, value or str(uuid4()))
def __hash__(self) -> int:
return hash(self.uuid)
@classmethod
def __get_pydantic_core_schema__(
cls,
_source: type[Any],
handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Reuse the alreadydefined schema for `str`
return handler.generate_schema(str)
class NodeId(NewUUID):
pass
class NodeId(ID):
pass

View File

@@ -26,12 +26,12 @@ if TYPE_CHECKING:
from pydantic import BaseModel
from shared.types.common import NewUUID
from shared.types.common import ID
class EventId(NewUUID):
class EventId(ID):
"""
Newtype around `NewUUID`
Newtype around `ID`
"""
@@ -43,10 +43,6 @@ class _EventType(str, Enum):
Here are all the unique kinds of events that can be sent over the network.
"""
# Task Saga Events
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
# Task Events
TaskCreated = "TaskCreated"
TaskStateUpdated = "TaskStateUpdated"
@@ -64,6 +60,7 @@ class _EventType(str, Enum):
# Runner Status Events
RunnerStatusUpdated = "RunnerStatusUpdated"
RunnerDeleted = "RunnerDeleted"
# Node Performance Events
NodePerformanceMeasured = "NodePerformanceMeasured"
@@ -136,8 +133,6 @@ class InstanceDeleted(_BaseEvent[_EventType.InstanceDeleted]):
event_type: Literal[_EventType.InstanceDeleted] = _EventType.InstanceDeleted
instance_id: InstanceId
transition: tuple[InstanceId, InstanceId]
class InstanceReplacedAtomically(_BaseEvent[_EventType.InstanceReplacedAtomically]):
event_type: Literal[_EventType.InstanceReplacedAtomically] = _EventType.InstanceReplacedAtomically
@@ -151,16 +146,9 @@ class RunnerStatusUpdated(_BaseEvent[_EventType.RunnerStatusUpdated]):
runner_status: RunnerStatus
class MLXInferenceSagaPrepare(_BaseEvent[_EventType.MLXInferenceSagaPrepare]):
event_type: Literal[_EventType.MLXInferenceSagaPrepare] = _EventType.MLXInferenceSagaPrepare
task_id: TaskId
instance_id: InstanceId
class MLXInferenceSagaStartPrepare(_BaseEvent[_EventType.MLXInferenceSagaStartPrepare]):
event_type: Literal[_EventType.MLXInferenceSagaStartPrepare] = _EventType.MLXInferenceSagaStartPrepare
task_id: TaskId
instance_id: InstanceId
class RunnerDeleted(_BaseEvent[_EventType.RunnerDeleted]):
event_type: Literal[_EventType.RunnerDeleted] = _EventType.RunnerDeleted
runner_id: RunnerId
class NodePerformanceMeasured(_BaseEvent[_EventType.NodePerformanceMeasured]):
@@ -206,14 +194,13 @@ _Event = Union[
InstanceDeleted,
InstanceReplacedAtomically,
RunnerStatusUpdated,
RunnerDeleted,
NodePerformanceMeasured,
WorkerStatusUpdated,
ChunkGenerated,
TopologyEdgeCreated,
TopologyEdgeReplacedAtomically,
TopologyEdgeDeleted,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
]
"""
Un-annotated union of all events. Only used internally to create the registry.

View File

@@ -1,9 +0,0 @@
from . import (
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
)
TaskSagaEvent = (
MLXInferenceSagaPrepare
| MLXInferenceSagaStartPrepare
)

View File

@@ -4,13 +4,13 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter
from shared.openai_compat import FinishReason
from shared.types.common import NewUUID
from shared.types.common import ID
from shared.types.models import ModelId
class CommandId(NewUUID):
class CommandId(ID):
"""
Newtype around `NewUUID` for command IDs
Newtype around `ID` for command IDs
"""
class ChunkType(str, Enum):

View File

@@ -4,11 +4,11 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field
from shared.types.api import ChatCompletionTaskParams
from shared.types.common import NewUUID
from shared.types.common import ID
from shared.types.worker.common import InstanceId
class TaskId(NewUUID):
class TaskId(ID):
pass

View File

@@ -1,13 +1,13 @@
from enum import Enum
from shared.types.common import NewUUID
from shared.types.common import ID
class InstanceId(NewUUID):
class InstanceId(ID):
pass
class RunnerId(NewUUID):
class RunnerId(ID):
pass

View File

@@ -115,7 +115,7 @@ class ResumableShardDownloader(ShardDownloader):
return await download_shard(shard, self.on_progress_wrapper, skip_download=True)
# Kick off download status coroutines concurrently
tasks = [asyncio.create_task(_status_for_model(model_id)) for model_id in MODEL_CARDS]
tasks = [asyncio.create_task(_status_for_model(model_card.model_id)) for model_card in MODEL_CARDS.values()]
for task in asyncio.as_completed(tasks):
try:

View File

@@ -9,11 +9,13 @@ from pydantic import BaseModel, ConfigDict
from shared.apply import apply
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
InstanceId,
RunnerDeleted,
RunnerStatusUpdated,
TaskStateUpdated,
)
@@ -52,6 +54,9 @@ from worker.download.download_utils import build_model_path
from worker.runner.runner_supervisor import RunnerSupervisor
def get_node_id() -> NodeId:
return NodeId() # TODO
class AssignedRunner(BaseModel):
runner_id: RunnerId
instance_id: InstanceId
@@ -78,40 +83,17 @@ class Worker:
def __init__(
self,
node_id: NodeId,
initial_state: State,
logger: Logger,
worker_events: AsyncSQLiteEventStorage | None,
):
self.node_id: NodeId = node_id
self.state: State = initial_state
self.worker_events: AsyncSQLiteEventStorage | None = worker_events
self.state: State = State()
self.worker_events: AsyncSQLiteEventStorage | None = worker_events # worker_events is None in some tests.
self.logger: Logger = logger
self.assigned_runners: dict[RunnerId, AssignedRunner] = {}
self._task: asyncio.Task[None] | None = None
## Worker lifecycle management
@property
def _is_running(self) -> bool:
return self._task is not None and not self._task.done()
@property
def exception(self) -> Exception | None:
if self._task is not None:
self._task.exception()
# We don't start immediately on init - for testing purposes it is useful to have an 'inactive' worker.
async def start(self):
self._task = asyncio.create_task(self._loop())
async def stop(self):
if not self._is_running:
raise RuntimeError("Worker is not running")
assert self._task is not None
self._task.cancel()
## Op Executors
async def _execute_assign_op(
@@ -145,6 +127,7 @@ class Worker:
# This is all we really need:
del self.assigned_runners[op.runner_id]
yield RunnerDeleted(runner_id=op.runner_id)
return
yield
@@ -337,7 +320,12 @@ class Worker:
# First, unassign assigned runners that are no longer in the state.
for runner_id, _ in self.assigned_runners.items():
if runner_id not in state.runners:
runner_ids: list[RunnerId] = [
runner_id
for instance in state.instances.values()
for runner_id in instance.instance_params.shard_assignments.runner_to_shard
]
if runner_id not in runner_ids:
return UnassignRunnerOp(runner_id=runner_id)
# Then spin down active runners
@@ -358,7 +346,8 @@ class Worker:
if self.node_id in instance.instance_params.shard_assignments.node_to_runner:
other_node_in_instance_has_failed = False
for runner_id in instance.instance_params.shard_assignments.runner_to_shard:
if isinstance(state.runners[runner_id], FailedRunnerStatus) and \
if runner_id in state.runners and \
isinstance(state.runners[runner_id], FailedRunnerStatus) and \
runner_id not in self.assigned_runners:
other_node_in_instance_has_failed= True
@@ -369,6 +358,7 @@ class Worker:
# If we are failed - and *all of the other nodes have spun down* - then we can spin down too.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \
instance.instance_params.shard_assignments.node_to_runner[self.node_id] in state.runners and \
isinstance(state.runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
num_spundown_nodes = 0
@@ -468,11 +458,10 @@ class Worker:
await self.worker_events.append_events([event], self.node_id)
# Handle state updates
async def _loop(self):
async def run(self):
assert self.worker_events is not None
while True:
# ToDo: Where do we update state? Do we initialize it from scratch & read all events in, or do we preload the state?
while True:
# 1. get latest events
events = await self.worker_events.get_events_since(self.state.last_event_applied_idx)
if len(events) == 0:
@@ -493,13 +482,18 @@ class Worker:
await asyncio.sleep(0.01)
# TODO: Handle tail event log
# TODO: Handle resource monitoring (write-only)
async def main():
node_id: NodeId = get_node_id()
logger: Logger = Logger('worker_log')
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
print("Hello from worker!")
worker = Worker(node_id, logger, event_log_manager.worker_events)
await worker.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -181,7 +181,7 @@ class RunnerSupervisor:
text=text, token=token, finish_reason=finish_reason
):
yield TokenChunk(
command_id=CommandId(uuid=task.task_id.uuid),
command_id=CommandId(task.task_id),
idx=token,
model=self.model_shard_meta.model_meta.model_id,
text=text,

View File

@@ -1,11 +1,13 @@
import uuid
import asyncio
from logging import Logger, getLogger
from pathlib import Path
from typing import Callable
from typing import Awaitable, Callable
import pytest
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.models.model_meta import get_model_meta
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import NodeId
from shared.types.models import ModelId, ModelMetadata
@@ -28,43 +30,6 @@ from shared.types.worker.shards import PipelineShardMetadata
from worker.main import Worker
@pytest.fixture
def model_meta() -> ModelMetadata:
# return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit') # we can't do this! as it's an async function :(
return ModelMetadata(
model_id='mlx-community/Llama-3.2-1B-Instruct-4bit',
pretty_name='llama3.2',
storage_size_kilobytes=10**6,
n_layers=16
)
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
total_layers = 16
layers_per_node = total_layers // num_nodes
start_layer = device_rank * layers_per_node
end_layer = (
start_layer + layers_per_node
if device_rank < num_nodes - 1
else total_layers
)
return PipelineShardMetadata(
model_meta=model_meta,
device_rank=device_rank,
n_layers=total_layers,
start_layer=start_layer,
end_layer=end_layer,
world_size=num_nodes,
)
return _pipeline_shard_meta
@pytest.fixture
def hosts():
def _hosts(count: int, offset: int = 0) -> list[Host]:
@@ -94,6 +59,35 @@ def user_message():
"""Override this fixture in tests to customize the message"""
return "Hello, how are you?"
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
total_layers = model_meta.n_layers
layers_per_node = total_layers // num_nodes
start_layer = device_rank * layers_per_node
end_layer = (
start_layer + layers_per_node
if device_rank < num_nodes - 1
else total_layers
)
return PipelineShardMetadata(
model_meta=model_meta,
device_rank=device_rank,
n_layers=total_layers,
start_layer=start_layer,
end_layer=end_layer,
world_size=num_nodes,
)
return _pipeline_shard_meta
@pytest.fixture
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
@@ -117,7 +111,7 @@ def chat_completion_task(completion_create_params: ChatCompletionTaskParams) ->
@pytest.fixture
def node_id() -> NodeId:
"""Shared node ID for tests"""
return NodeId(uuid.uuid4())
return NodeId()
@pytest.fixture
def state(node_id: NodeId):
@@ -135,9 +129,8 @@ def logger() -> Logger:
@pytest.fixture
def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]):
def _instance(node_id: NodeId) -> Instance:
model_id = ModelId(uuid.uuid4())
runner_id = RunnerId(uuid.uuid4())
def _instance(node_id: NodeId, runner_id: RunnerId) -> Instance:
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
@@ -153,24 +146,24 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h
)
return Instance(
instance_id=InstanceId(uuid.uuid4()),
instance_id=InstanceId(),
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE
)
return _instance
@pytest.fixture
async def worker(node_id: NodeId, state: State, logger: Logger):
async def worker(node_id: NodeId, logger: Logger):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
return Worker(node_id, state, logger, worker_events=event_log_manager.global_events)
return Worker(node_id, logger, worker_events=event_log_manager.global_events)
@pytest.fixture
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId], Instance]):
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance]):
"""Fixture that provides a worker with an already assigned runner."""
instance_obj: Instance = instance(worker.node_id)
instance_obj: Instance = instance(worker.node_id, RunnerId())
# Extract runner_id from shard assignments
runner_id = next(iter(instance_obj.instance_params.shard_assignments.runner_to_shard))
@@ -203,3 +196,19 @@ async def worker_with_running_runner(worker_with_assigned_runner: tuple[Worker,
assert supervisor.healthy
return worker, runner_id, instance_obj
@pytest.fixture
def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]:
async def _worker_running(node_id: NodeId) -> tuple[Worker, AsyncSQLiteEventStorage]:
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker = Worker(node_id, logger=logger, worker_events=global_events)
asyncio.create_task(worker.run())
return worker, global_events
return _worker_running

View File

@@ -9,6 +9,7 @@ from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
RunnerDeleted,
RunnerStatusUpdated,
TaskStateUpdated,
)
@@ -39,12 +40,9 @@ def user_message():
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
@pytest.mark.asyncio
async def test_assign_op(worker: Worker, instance: Callable[[NodeId], Instance], tmp_path: Path):
instance_obj: Instance = instance(worker.node_id)
runner_id: RunnerId | None = None
for x in instance_obj.instance_params.shard_assignments.runner_to_shard:
runner_id = x
assert runner_id is not None
async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance], tmp_path: Path):
runner_id = RunnerId()
instance_obj: Instance = instance(worker.node_id, runner_id)
assign_op = AssignRunnerOp(
runner_id=runner_id,
@@ -82,7 +80,8 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
# We should have no assigned runners and no events were emitted
assert len(worker.assigned_runners) == 0
assert len(events) == 0
assert len(events) == 1
assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Task, tmp_path: Path):

View File

@@ -1,21 +1,31 @@
import asyncio
from logging import Logger
from typing import Callable, Final
from uuid import UUID
from typing import Awaitable, Callable, Final
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
import pytest
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import InstanceCreated
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerDeleted,
RunnerStatusUpdated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.state import State
from shared.types.tasks import TaskId
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance
from shared.types.worker.instances import Instance, TypeOfInstance
from shared.types.worker.runners import (
LoadedRunnerStatus,
ReadyRunnerStatus,
# RunningRunnerStatus,
)
from worker.main import Worker
MASTER_NODE_ID = NodeId(uuid=UUID("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))
NODE_A: Final[NodeId] = NodeId(uuid=UUID("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))
NODE_B: Final[NodeId] = NodeId(uuid=UUID("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"))
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId()
@@ -26,20 +36,21 @@ MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId()
async def test_runner_spin_up(instance: Callable[[NodeId], Instance]):
# TODO.
return
node_id = NodeId()
logger = Logger('worker_test_logger')
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
@pytest.fixture
def user_message():
return "What is the capital of Japan?"
global_events = event_log_manager.global_events
async def test_runner_assigned(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
):
worker = Worker(node_id, State(), logger=logger, worker_events=global_events)
await worker.start()
worker, global_events = await worker_running(NODE_A)
instance_value = instance(node_id)
print(worker)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.INACTIVE
await global_events.append_events(
[
@@ -54,4 +65,153 @@ async def test_runner_spin_up(instance: Callable[[NodeId], Instance]):
await asyncio.sleep(0.1)
assert worker.assigned_runners
# Ensure the worker has taken the correct action
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, ReadyRunnerStatus)
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
assert len(events) == 2
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, ReadyRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], ReadyRunnerStatus)
async def test_runner_assigned_active(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance],
chat_completion_task: Task
):
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.1)
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
assert len(events) == 3
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, LoadedRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
# Ensure that the runner has been created and it can stream tokens.
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
assert "tokyo" in full_response.lower(), (
f"Expected 'Tokyo' in response, but got: {full_response}"
)
async def test_runner_assigned_wrong_node(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
):
worker, global_events = await worker_running(NODE_A)
instance_value = instance(NODE_B, RUNNER_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.1)
assert len(worker.assigned_runners) == 0
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
assert len(events) == 1
# No RunnerStatusUpdated event should be emitted
# Ensure state is correct
assert len(worker.state.runners) == 0
async def test_runner_unassigns(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
):
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.1)
# already tested by test_runner_assigned_active
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)
# Ensure the correct events have been emitted (creation)
events = await global_events.get_events_since(0)
assert len(events) == 3
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, LoadedRunnerStatus)
# Ensure state is correct
print(worker.state)
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
await global_events.append_events(
[
InstanceDeleted(instance_id=instance_value.instance_id)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
print(worker.state)
assert len(worker.assigned_runners) == 0
# Ensure the correct events have been emitted (deletion)
events = await global_events.get_events_since(0)
assert isinstance(events[-1].event, RunnerDeleted)
# After deletion, runner should be removed from state.runners
assert len(worker.state.runners) == 0

View File

@@ -88,7 +88,22 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
),
expected_op=None,
@@ -854,15 +869,9 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
case = test_cases[case.description]
node_id = NODE_A
initial_state = State(
node_status={node_id: NodeStatus.Idle},
instances={},
runners={},
tasks={},
)
logger = logging.getLogger("test_worker_plan")
worker = Worker(node_id=node_id, initial_state=initial_state, worker_events=None, logger=logger)
worker = Worker(node_id=node_id, worker_events=None, logger=logger)
path_downloaded_map: dict[str, bool] = {}

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Final, List, Optional, override
from uuid import UUID
from shared.models.model_cards import MODEL_CARDS, ModelCard
from shared.types.common import NodeId
@@ -23,13 +22,13 @@ from shared.types.worker.runners import (
from shared.types.worker.shards import PipelineShardMetadata
from worker.main import AssignedRunner
NODE_A: Final[NodeId] = NodeId(uuid=UUID("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))
NODE_B: Final[NodeId] = NodeId(uuid=UUID("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"))
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId(uuid=UUID("cccccccc-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))
RUNNER_1_ID: Final[RunnerId] = RunnerId("cccccccc-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
INSTANCE_1_ID: Final[InstanceId] = InstanceId()
RUNNER_2_ID: Final[RunnerId] = RunnerId(uuid=UUID("dddddddd-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))
RUNNER_2_ID: Final[RunnerId] = RunnerId("dddddddd-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
INSTANCE_2_ID: Final[InstanceId] = InstanceId()
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
@@ -108,12 +107,12 @@ def make_model_meta(
) -> ModelMetadata:
model_card: ModelCard
for card in MODEL_CARDS.values():
if card.repo_id == model_id:
if card.model_id == model_id:
model_card = card
return ModelMetadata(
model_id=model_id,
pretty_name=model_card.id,
pretty_name=model_card.model_id,
storage_size_kilobytes=10**6,
n_layers=16,
)

View File

@@ -1,48 +0,0 @@
## Tests for worker state differentials
## When the worker state changes, this should be reflected by a worker intention.
import asyncio
from typing import Callable
from uuid import uuid4
import pytest
from shared.types.common import NodeId
from shared.types.state import State
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import Instance
from worker.main import Worker
@pytest.mark.asyncio
async def test_worker_runs_and_stops(worker: Worker):
await worker.start()
await asyncio.sleep(0.01)
assert worker._is_running, worker._task.exception() # type: ignore
await worker.stop()
await asyncio.sleep(0.01)
assert not worker._is_running # type: ignore
@pytest.mark.asyncio
async def test_worker_instance_added(worker: Worker, instance: Callable[[NodeId], Instance]):
await worker.start()
await asyncio.sleep(0.01)
worker.state.instances = {InstanceId(uuid4()): instance(worker.node_id)}
print(worker.state.instances)
def test_plan_noop(worker: Worker):
s = State(
node_status={
NodeId(uuid4()): NodeStatus.Idle
}
)
next_op = worker.plan(s)
assert next_op is None