diff --git a/master/api.py b/master/api.py index e2a8428d..387f2e5d 100644 --- a/master/api.py +++ b/master/api.py @@ -1,16 +1,21 @@ import asyncio import time from collections.abc import AsyncGenerator -from typing import List, Sequence, final +from typing import Callable, List, Sequence, final import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from shared.db.sqlite.connector import AsyncSQLiteEventStorage +from shared.models.model_cards import MODEL_CARDS +from shared.models.model_meta import get_model_meta from shared.types.api import ( ChatCompletionMessage, ChatCompletionResponse, + CreateInstanceResponse, + CreateInstanceTaskParams, + DeleteInstanceResponse, StreamingChoiceResponse, ) from shared.types.common import CommandId @@ -20,9 +25,14 @@ from shared.types.events.commands import ( ChatCompletionCommand, Command, CommandType, + CreateInstanceCommand, + DeleteInstanceCommand, ) from shared.types.events.components import EventFromEventLog +from shared.types.state import State from shared.types.tasks import ChatCompletionTaskParams +from shared.types.worker.common import InstanceId +from shared.types.worker.instances import Instance def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse: @@ -45,20 +55,21 @@ def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse: @final class API: - def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage) -> None: + def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, get_state: Callable[[], State]) -> None: self._app = FastAPI() self._setup_routes() self.command_buffer = command_buffer self.global_events = global_events + self.get_state = get_state def _setup_routes(self) -> None: # self._app.get("/topology/control_plane")(self.get_control_plane_topology) # self._app.get("/topology/data_plane")(self.get_data_plane_topology) # 
self._app.get("/instances/list")(self.list_instances) - # self._app.post("/instances/create")(self.create_instance) - # self._app.get("/instance/{instance_id}/read")(self.get_instance) - # self._app.delete("/instance/{instance_id}/delete")(self.remove_instance) + self._app.post("/instances/create")(self.create_instance) + self._app.get("/instance/{instance_id}")(self.get_instance) + self._app.delete("/instance/{instance_id}")(self.delete_instance) # self._app.get("/model/{model_id}/metadata")(self.get_model_data) # self._app.post("/model/{model_id}/instances")(self.get_instances_by_model) self._app.post("/v1/chat/completions")(self.chat_completions) @@ -80,11 +91,49 @@ class API: # def list_instances(self): # return {"message": "Hello, World!"} - # def create_instance(self, model_id: ModelId) -> InstanceId: ... + async def create_instance(self, payload: CreateInstanceTaskParams) -> CreateInstanceResponse: + if payload.model_id in MODEL_CARDS: + model_card = MODEL_CARDS[payload.model_id] + model_meta = model_card.metadata + else: + model_meta = await get_model_meta(payload.model_id) - # def get_instance(self, instance_id: InstanceId) -> Instance: ... + command = CreateInstanceCommand( + command_id=CommandId(), + command_type=CommandType.CREATE_INSTANCE, + model_meta=model_meta, + instance_id=InstanceId(), + ) + self.command_buffer.append(command) - # def remove_instance(self, instance_id: InstanceId) -> None: ... 
+ return CreateInstanceResponse( + message="Command received.", + command_id=command.command_id, + model_meta=model_meta, + instance_id=command.instance_id, + ) + + def get_instance(self, instance_id: InstanceId) -> Instance: + state = self.get_state() + if instance_id not in state.instances: + raise HTTPException(status_code=404, detail="Instance not found") + return state.instances[instance_id] + + def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse: + if instance_id not in self.get_state().instances: + raise HTTPException(status_code=404, detail="Instance not found") + + command = DeleteInstanceCommand( + command_id=CommandId(), + command_type=CommandType.DELETE_INSTANCE, + instance_id=instance_id, + ) + self.command_buffer.append(command) + return DeleteInstanceResponse( + message="Command received.", + command_id=command.command_id, + instance_id=instance_id, + ) # def get_model_data(self, model_id: ModelId) -> ModelInfo: ... @@ -140,9 +189,10 @@ class API: def start_fastapi_server( command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, + get_state: Callable[[], State], host: str = "0.0.0.0", port: int = 8000, ): - api = API(command_buffer, global_events) + api = API(command_buffer, global_events, get_state) uvicorn.run(api.app, host=host, port=port) \ No newline at end of file diff --git a/master/forwarder_supervisor.py b/master/forwarder_supervisor.py index bdec1f7e..93a0bab0 100644 --- a/master/forwarder_supervisor.py +++ b/master/forwarder_supervisor.py @@ -106,8 +106,8 @@ class ForwarderSupervisor: self._process = await asyncio.create_subprocess_exec( str(self._binary_path), f'{pairs}', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stdout=None, + stderr=None, ) self._logger.info(f"Starting forwarder with forwarding pairs: {pairs}") diff --git a/master/main.py b/master/main.py index a253927d..acc1b122 100644 --- a/master/main.py +++ b/master/main.py @@ -7,92 +7,40 @@ from typing import List 
from master.api import start_fastapi_server from master.election_callback import ElectionCallbacks -from master.forwarder_supervisor import ForwarderSupervisor +from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor +from master.placement import get_instance_placements, get_transition_events from shared.apply import apply from shared.db.sqlite.config import EventLogConfig from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogManager -from shared.models.model_cards import MODEL_CARDS -from shared.models.model_meta import get_model_meta -from shared.types.common import CommandId, NodeId +from shared.node_id import get_node_id_keypair +from shared.types.common import NodeId from shared.types.events import ( - ChunkGenerated, - InstanceCreated, + Event, + NodePerformanceMeasured, TaskCreated, ) -from shared.types.events.chunks import TokenChunk from shared.types.events.commands import ( ChatCompletionCommand, Command, CreateInstanceCommand, DeleteInstanceCommand, ) +from shared.types.profiling import ( + MemoryPerformanceProfile, + NodePerformanceProfile, + SystemPerformanceProfile, +) from shared.types.state import State from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType -from shared.types.worker.common import InstanceId -from shared.types.worker.instances import ( - InstanceParams, - ShardAssignments, - TypeOfInstance, -) -from shared.types.worker.runners import RunnerId -from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata +from shared.types.worker.instances import Instance -## TODO: Hook this up properly -async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: CommandId): - model_id = "testmodelabc" - - for i in range(10): - await asyncio.sleep(0.1) - - # Create the event with proper types and consistent IDs - chunk_event = ChunkGenerated( - command_id=command_id, - chunk=TokenChunk( - command_id=command_id, 
# Use the same task_id - idx=i, - model=model_id, # Use the same model_id - text=f'text{i}', - token_id=i - ) - ) - - # ChunkGenerated needs to be cast to the expected BaseEvent type - await events_log.append_events( - [chunk_event], - origin=NodeId() - ) - - await asyncio.sleep(0.1) - - # Create the event with proper types and consistent IDs - chunk_event = ChunkGenerated( - command_id=command_id, - chunk=TokenChunk( - command_id=command_id, # Use the same task_id - idx=11, - model=model_id, # Use the same model_id - text=f'text{11}', - token_id=11, - finish_reason='stop' - ) - ) - - # ChunkGenerated needs to be cast to the expected BaseEvent type - await events_log.append_events( - [chunk_event], - origin=NodeId() - ) - -def get_node_id() -> NodeId: - return NodeId() # TODO - class Master: - def __init__(self, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger): + def __init__(self, node_id: NodeId, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger): + self.node_id = node_id self.command_buffer = command_buffer self.global_events = global_events - self.node_id = get_node_id() self.forwarder_supervisor = ForwarderSupervisor( forwarder_binary_path=forwarder_binary_path, logger=logger @@ -104,6 +52,62 @@ class Master: # TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events return State() + async def _run_event_loop_body(self) -> None: + if self.forwarder_supervisor.current_role == ForwarderRole.REPLICA: + await asyncio.sleep(0.1) + return + + next_events: list[Event] = [] + # 1. 
process commands + if len(self.command_buffer) > 0: + # for now we do one command at a time + next_command = self.command_buffer.pop(0) + self.logger.info(f"got command: {next_command}") + # TODO: validate the command + match next_command: + case ChatCompletionCommand(): + matching_instance: Instance | None = None + for instance in self.state.instances.values(): + if instance.shard_assignments.model_id == next_command.request_params.model: + matching_instance = instance + break + if not matching_instance: + raise ValueError(f"No instance found for model {next_command.request_params.model}") + + task_id = TaskId() + next_events.append(TaskCreated( + task_id=task_id, + task=ChatCompletionTask( + task_id=task_id, + task_type=TaskType.CHAT_COMPLETION, + instance_id=matching_instance.instance_id, + task_status=TaskStatus.PENDING, + task_params=next_command.request_params + ) + )) + case DeleteInstanceCommand(): + placement = get_instance_placements(next_command, self.state.topology, self.state.instances) + transition_events = get_transition_events(self.state.instances, placement) + next_events.extend(transition_events) + case CreateInstanceCommand(): + placement = get_instance_placements(next_command, self.state.topology, self.state.instances) + transition_events = get_transition_events(self.state.instances, placement) + next_events.extend(transition_events) + + await self.global_events.append_events(next_events, origin=self.node_id) + + # 2. get latest events + events = await self.global_events.get_events_since(self.state.last_event_applied_idx) + if len(events) == 0: + await asyncio.sleep(0.01) + return + + # 3. 
for each event, apply it to the state + for event_from_log in events: + self.state = apply(self.state, event_from_log) + + self.logger.info(f"state: {self.state.model_dump_json()}") + async def run(self): self.state = await self._get_state_snapshot() @@ -115,90 +119,41 @@ class Master: await self.election_callbacks.on_became_master() while True: - next_event = None - # 1. process commands - if len(self.command_buffer) > 0: - # for now we do one command at a time - next_command = self.command_buffer.pop(0) - self.logger.info(f"got command: {next_command}") - # TODO: validate the command - match next_command: - case ChatCompletionCommand(): - # 1. find a valid instance for this request, if none exists ERROR (TODO) - instance_id = InstanceId() - task_id = TaskId() - # 2. publish TaskCreated event (TODO) - next_event = TaskCreated( - task_id=task_id, - task=ChatCompletionTask( - task_id=task_id, - task_type=TaskType.CHAT_COMPLETION, - instance_id=instance_id, - task_status=TaskStatus.PENDING, - task_params=next_command.request_params - ) - ) - case DeleteInstanceCommand(): - # TODO - pass - case CreateInstanceCommand(): - if next_command.model_meta.model_id not in MODEL_CARDS: - raise ValueError(f"Model {next_command.model_meta.model_id} not supported.") - - # TODO: we should also support models that aren't in MODEL_CARDS - # if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface - if next_command.model_meta.model_id in MODEL_CARDS: - model_card = MODEL_CARDS[next_command.model_meta.model_id] - model_meta = model_card.metadata - else: - model_meta = await get_model_meta(next_command.model_meta.model_id) - - # TODO: how do we actually schedule an instance? 
- next_event = InstanceCreated( - instance_id=InstanceId(), - instance_params=InstanceParams( - shard_assignments=ShardAssignments( - model_id=next_command.model_meta.model_id, - runner_to_shard={ - RunnerId(): PipelineShardMetadata( - model_meta=model_meta, - partition_strategy=PartitionStrategy.pipeline, - device_rank=0, - world_size=1, - start_layer=0, - end_layer=0, - n_layers=0 - ) - }, - node_to_runner={} - ), - hosts=[] - ), - instance_type=TypeOfInstance.ACTIVE, - ) - - if next_event is not None: - await self.global_events.append_events([next_event], origin=self.node_id) - - # 2. get latest events - events = await self.global_events.get_events_since(self.state.last_event_applied_idx) - if len(events) == 0: - await asyncio.sleep(0.01) - continue - - # 3. for each event, apply it to the state - for event_from_log in events: - self.state = apply(self.state, event_from_log) + try: + await self._run_event_loop_body() + except Exception as e: + self.logger.error(f"Error in _run_event_loop_body: {e}") + await asyncio.sleep(0.1) async def main(): logger = Logger(name='master_logger') + node_id_keypair = get_node_id_keypair() + node_id = NodeId(node_id_keypair.to_peer_id().to_base58()) event_log_manager = EventLogManager(EventLogConfig(), logger=logger) await event_log_manager.initialize() global_events: AsyncSQLiteEventStorage = event_log_manager.global_events + # TODO: this should be the resource monitor that does this + await global_events.append_events([NodePerformanceMeasured( + node_id=node_id, + node_profile=NodePerformanceProfile( + model_id="testmodelabc", + chip_id="testchipabc", + memory=MemoryPerformanceProfile( + ram_total=1000, + ram_available=1000, + swap_total=1000, + swap_available=1000 + ), + system=SystemPerformanceProfile( + flops_fp16=1000 + ) + ) + )], origin=node_id) + command_buffer: List[Command] = [] api_thread = threading.Thread( @@ -206,13 +161,14 @@ async def main(): args=( command_buffer, 
global_events, + lambda: master.state, ), daemon=True ) api_thread.start() logger.info('Running FastAPI server in a separate thread. Listening on port 8000.') - master = Master(command_buffer, global_events, forwarder_binary_path=Path("forwarder"), logger=logger) + master = Master(node_id, command_buffer, global_events, forwarder_binary_path=Path("./build/forwarder"), logger=logger) await master.run() if __name__ == "__main__": diff --git a/master/placement.py b/master/placement.py index 87d12c6e..82730472 100644 --- a/master/placement.py +++ b/master/placement.py @@ -13,15 +13,15 @@ from shared.topology import Topology from shared.types.events import Event, InstanceCreated, InstanceDeleted from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand from shared.types.worker.common import InstanceId -from shared.types.worker.instances import InstanceParams, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceStatus @singledispatch def get_instance_placements( command: CreateInstanceCommand, topology: Topology, - current_instances: dict[InstanceId, InstanceParams], -) -> dict[InstanceId, InstanceParams]: + current_instances: dict[InstanceId, Instance], +) -> dict[InstanceId, Instance]: available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances] if command.model_meta.model_id in available_models: raise ValueError(f"Instance for {command.model_meta.model_id} already exists") @@ -36,9 +36,11 @@ def get_instance_placements( shard_assignments = get_shard_assignments(command.model_meta, selected_cycle) - instance_id = InstanceId() + instance_id = command.instance_id target_instances = deepcopy(current_instances) - target_instances[instance_id] = InstanceParams( + target_instances[instance_id] = Instance( + instance_id=instance_id, + instance_type=InstanceStatus.ACTIVE, shard_assignments=shard_assignments, hosts=[] ) @@ -46,7 +48,7 @@ def get_instance_placements( 
@get_instance_placements.register -def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]: +def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, Instance]) -> dict[InstanceId, Instance]: target_instances = deepcopy(current_instances) if command.instance_id in target_instances: del target_instances[command.instance_id] @@ -55,19 +57,17 @@ def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dic def get_transition_events( - current_instances: Mapping[InstanceId, InstanceParams], - target_instances: Mapping[InstanceId, InstanceParams], + current_instances: Mapping[InstanceId, Instance], + target_instances: Mapping[InstanceId, Instance], ) -> Sequence[Event]: events: list[Event] = [] # find instances to create - for instance_id, instance_params in target_instances.items(): + for instance_id, instance in target_instances.items(): if instance_id not in current_instances: events.append( InstanceCreated( - instance_id=instance_id, - instance_params=instance_params, - instance_type=TypeOfInstance.ACTIVE + instance=instance, ) ) diff --git a/master/tests/test_master.py b/master/tests/test_master.py index 6a295652..f8fc6558 100644 --- a/master/tests/test_master.py +++ b/master/tests/test_master.py @@ -11,6 +11,7 @@ from shared.db.sqlite.config import EventLogConfig from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogManager from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams +from shared.types.common import NodeId from shared.types.events import TaskCreated from shared.types.events.commands import ChatCompletionCommand, Command, CommandId from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType @@ -36,7 +37,8 @@ async def test_master(): forwarder_binary_path = _create_forwarder_dummy_binary() - master = 
Master(command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger) + node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") + master = Master(node_id, command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger) asyncio.create_task(master.run()) command_buffer.append( diff --git a/master/tests/test_placement.py b/master/tests/test_placement.py index cf105b97..3218297e 100644 --- a/master/tests/test_placement.py +++ b/master/tests/test_placement.py @@ -12,7 +12,7 @@ from shared.types.events.commands import CreateInstanceCommand from shared.types.models import ModelMetadata from shared.types.topology import Connection, Node from shared.types.worker.common import InstanceId -from shared.types.worker.instances import InstanceParams +from shared.types.worker.instances import Instance, InstanceStatus from shared.types.worker.runners import ShardAssignments @@ -21,8 +21,10 @@ def topology() -> Topology: return Topology() @pytest.fixture -def instance_params() -> InstanceParams: - return InstanceParams( +def instance() -> Instance: + return Instance( + instance_id=InstanceId(), + instance_type=InstanceStatus.ACTIVE, shard_assignments=ShardAssignments( model_id="test-model", runner_to_shard={}, @@ -43,7 +45,8 @@ def model_meta() -> ModelMetadata: def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand: return CreateInstanceCommand( command_id=CommandId(), - model_meta=model_meta + model_meta=model_meta, + instance_id=InstanceId(), ) @@ -66,7 +69,8 @@ def test_get_instance_placements_create_instance( create_instance_command = CreateInstanceCommand( command_id=CommandId(), - model_meta=model_meta + model_meta=model_meta, + instance_id=InstanceId(), ) node_id_a = NodeId() node_id_b = NodeId() @@ -84,16 +88,16 @@ def test_get_instance_placements_create_instance( # assert assert len(placements) == 1 instance_id = 
list(placements.keys())[0] - instance_params = placements[instance_id] - assert instance_params.shard_assignments.model_id == model_meta.model_id + instance = placements[instance_id] + assert instance.shard_assignments.model_id == model_meta.model_id - runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a] - runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b] - runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c] + runner_id_a = instance.shard_assignments.node_to_runner[node_id_a] + runner_id_b = instance.shard_assignments.node_to_runner[node_id_b] + runner_id_c = instance.shard_assignments.node_to_runner[node_id_c] - shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a] - shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b] - shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c] + shard_a = instance.shard_assignments.runner_to_shard[runner_id_a] + shard_b = instance.shard_assignments.runner_to_shard[runner_id_b] + shard_c = instance.shard_assignments.runner_to_shard[runner_id_c] assert shard_a.end_layer - shard_a.start_layer == expected_layers[0] assert shard_b.end_layer - shard_b.start_layer == expected_layers[1] @@ -105,14 +109,14 @@ def test_get_instance_placements_create_instance( assert shards_sorted[-1].end_layer == total_layers -def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams): +def test_get_transition_events_no_change(topology: Topology, instance: Instance): # arrange instance_id = InstanceId() current_instances = { - instance_id: instance_params + instance_id: instance } target_instances = { - instance_id: instance_params + instance_id: instance } # act @@ -122,12 +126,12 @@ def test_get_transition_events_no_change(topology: Topology, instance_params: In assert len(events) == 0 -def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams): +def 
test_get_transition_events_create_instance(topology: Topology, instance: Instance): # arrange instance_id = InstanceId() - current_instances: dict[InstanceId, InstanceParams] = {} - target_instances: dict[InstanceId, InstanceParams] = { - instance_id: instance_params + current_instances: dict[InstanceId, Instance] = {} + target_instances: dict[InstanceId, Instance] = { + instance_id: instance } # act @@ -138,13 +142,13 @@ def test_get_transition_events_create_instance(topology: Topology, instance_para assert events[0].event_type == _EventType.InstanceCreated -def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams): +def test_get_transition_events_delete_instance(topology: Topology, instance: Instance): # arrange instance_id = InstanceId() - current_instances: dict[InstanceId, InstanceParams] = { - instance_id: instance_params + current_instances: dict[InstanceId, Instance] = { + instance_id: instance } - target_instances: dict[InstanceId, InstanceParams] = {} + target_instances: dict[InstanceId, Instance] = {} # act events = get_transition_events(current_instances, target_instances) diff --git a/networking/forwarder/src/sqlite.go b/networking/forwarder/src/sqlite.go index 7a449f61..2f52d693 100644 --- a/networking/forwarder/src/sqlite.go +++ b/networking/forwarder/src/sqlite.go @@ -281,7 +281,7 @@ func (c *sqliteConnector) getLatestRowIds() (map[SourceKey]int64, error) { } selectCols := strings.Join(keyCols, ", ") - query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" GROUP BY %s`, selectCols, rowIDCol, c.tableName, selectCols) + query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" WHERE %s IS NOT NULL GROUP BY %s`, selectCols, rowIDCol, c.tableName, rowIDCol, selectCols) rows, err := c.db.Query(query) if err != nil { diff --git a/shared/apply/apply.py b/shared/apply/apply.py index fcd8e400..8a333aba 100644 --- a/shared/apply/apply.py +++ b/shared/apply/apply.py @@ -28,7 +28,7 @@ from shared.types.profiling import 
NodePerformanceProfile from shared.types.state import State from shared.types.tasks import Task, TaskId from shared.types.worker.common import NodeStatus, RunnerId -from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceId, InstanceStatus from shared.types.worker.runners import RunnerStatus S = TypeVar("S", bound=State) @@ -62,8 +62,8 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State: @event_apply.register(InstanceCreated) def apply_instance_created(event: InstanceCreated, state: State) -> State: - instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type) - new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance} + instance = event.instance + new_instances: Mapping[InstanceId, Instance] = {**state.instances, instance.instance_id: instance} return state.model_copy(update={"instances": new_instances}) @event_apply.register(InstanceActivated) @@ -71,8 +71,8 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State: if event.instance_id not in state.instances: return state - updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE}) - new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance} + updated_instance = state.instances[event.instance_id].model_copy(update={"instance_type": InstanceStatus.ACTIVE}) + new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance} return state.model_copy(update={"instances": new_instances}) @event_apply.register(InstanceDeactivated) @@ -80,13 +80,13 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat if event.instance_id not in state.instances: return state - updated_instance = state.instances[event.instance_id].model_copy(update={"type": 
TypeOfInstance.INACTIVE}) - new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance} + updated_instance = state.instances[event.instance_id].model_copy(update={"instance_type": InstanceStatus.INACTIVE}) + new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance} return state.model_copy(update={"instances": new_instances}) @event_apply.register(InstanceDeleted) def apply_instance_deleted(event: InstanceDeleted, state: State) -> State: - new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id} + new_instances: Mapping[InstanceId, Instance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id} return state.model_copy(update={"instances": new_instances}) @event_apply.register(InstanceReplacedAtomically) diff --git a/shared/types/api.py b/shared/types/api.py index 6b235c16..98d99468 100644 --- a/shared/types/api.py +++ b/shared/types/api.py @@ -3,6 +3,9 @@ from typing import Any, Literal from pydantic import BaseModel from shared.openai_compat import FinishReason +from shared.types.common import CommandId +from shared.types.models import ModelMetadata +from shared.types.worker.instances import InstanceId class ChatCompletionMessage(BaseModel): @@ -97,8 +100,20 @@ class ChatCompletionTaskParams(BaseModel): parallel_tool_calls: bool | None = None user: str | None = None -class RequestInstanceTaskParams(BaseModel): +class CreateInstanceTaskParams(BaseModel): + # TODO: in future the user could specify a specific Instance, not just a model_id model_id: str class DeleteInstanceTaskParams(BaseModel): instance_id: str + +class CreateInstanceResponse(BaseModel): + message: str + command_id: CommandId + model_meta: ModelMetadata + instance_id: InstanceId + +class DeleteInstanceResponse(BaseModel): + message: str + command_id: CommandId + instance_id: InstanceId diff --git 
a/shared/types/events/_events.py b/shared/types/events/_events.py index 5fe7bd12..e28f55c3 100644 --- a/shared/types/events/_events.py +++ b/shared/types/events/_events.py @@ -18,7 +18,7 @@ from shared.types.common import NodeId from shared.types.events.chunks import CommandId, GenerationChunk from shared.types.tasks import Task, TaskId, TaskStatus from shared.types.worker.common import InstanceId, NodeStatus -from shared.types.worker.instances import InstanceParams, TypeOfInstance +from shared.types.worker.instances import Instance from shared.types.worker.runners import RunnerId, RunnerStatus if TYPE_CHECKING: @@ -114,9 +114,7 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]): class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]): event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated - instance_id: InstanceId - instance_params: InstanceParams - instance_type: TypeOfInstance + instance: Instance class InstanceActivated(_BaseEvent[_EventType.InstanceActivated]): diff --git a/shared/types/events/commands.py b/shared/types/events/commands.py index ae17100d..6f2b98eb 100644 --- a/shared/types/events/commands.py +++ b/shared/types/events/commands.py @@ -7,7 +7,8 @@ from shared.types.api import ChatCompletionTaskParams from shared.types.common import CommandId from shared.types.events import Event from shared.types.models import ModelMetadata -from shared.types.state import InstanceId, State +from shared.types.state import State +from shared.types.worker.common import InstanceId # TODO: We need to have a distinction between create instance and spin up instance. 
@@ -30,6 +31,7 @@ class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]): class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]): command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE model_meta: ModelMetadata + instance_id: InstanceId class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]): diff --git a/shared/types/request.py b/shared/types/request.py index 915e9ce5..49cbbf31 100644 --- a/shared/types/request.py +++ b/shared/types/request.py @@ -2,8 +2,8 @@ from pydantic import BaseModel from shared.types.api import ( ChatCompletionTaskParams, + CreateInstanceTaskParams, DeleteInstanceTaskParams, - RequestInstanceTaskParams, ) from shared.types.events import CommandId @@ -12,12 +12,12 @@ class ChatCompletionCommand(BaseModel): command_id: CommandId command_params: ChatCompletionTaskParams -class RequestInstanceCommand(BaseModel): +class CreateInstanceCommand(BaseModel): command_id: CommandId - command_params: RequestInstanceTaskParams + command_params: CreateInstanceTaskParams class DeleteInstanceCommand(BaseModel): command_id: CommandId command_params: DeleteInstanceTaskParams -type Command = ChatCompletionCommand | RequestInstanceCommand | DeleteInstanceCommand +type Command = ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand diff --git a/shared/types/state.py b/shared/types/state.py index 769ad319..7736b838 100644 --- a/shared/types/state.py +++ b/shared/types/state.py @@ -7,14 +7,14 @@ from shared.types.common import NodeId from shared.types.profiling import NodePerformanceProfile from shared.types.tasks import Task, TaskId from shared.types.worker.common import InstanceId, NodeStatus -from shared.types.worker.instances import BaseInstance +from shared.types.worker.instances import Instance from shared.types.worker.runners import RunnerId, RunnerStatus class State(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) node_status: Mapping[NodeId, 
NodeStatus] = {} - instances: Mapping[InstanceId, BaseInstance] = {} + instances: Mapping[InstanceId, Instance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} tasks: Mapping[TaskId, Task] = {} node_profiles: Mapping[NodeId, NodePerformanceProfile] = {} diff --git a/shared/types/worker/instances.py b/shared/types/worker/instances.py index 50047adc..4bfa92af 100644 --- a/shared/types/worker/instances.py +++ b/shared/types/worker/instances.py @@ -9,20 +9,12 @@ from shared.types.worker.runners import ( ) -class TypeOfInstance(str, Enum): - ACTIVE = "active" - INACTIVE = "inactive" +class InstanceStatus(str, Enum): + ACTIVE = "ACTIVE" + INACTIVE = "INACTIVE" - -class InstanceParams(BaseModel): +class Instance(BaseModel): + instance_id: InstanceId + instance_type: InstanceStatus shard_assignments: ShardAssignments hosts: list[Host] - - -class BaseInstance(BaseModel): - instance_params: InstanceParams - instance_type: TypeOfInstance - - -class Instance(BaseInstance): - instance_id: InstanceId diff --git a/worker/main.py b/worker/main.py index 16efa7ec..8a078c6a 100644 --- a/worker/main.py +++ b/worker/main.py @@ -28,7 +28,7 @@ from shared.types.worker.downloads import ( DownloadOngoing, DownloadProgressData, ) -from shared.types.worker.instances import TypeOfInstance +from shared.types.worker.instances import InstanceStatus from shared.types.worker.mlx import Host from shared.types.worker.ops import ( AssignRunnerOp, @@ -323,29 +323,29 @@ class Worker: runner_ids: list[RunnerId] = [ runner_id for instance in state.instances.values() - for runner_id in instance.instance_params.shard_assignments.runner_to_shard + for runner_id in instance.shard_assignments.runner_to_shard ] if runner_id not in runner_ids: return UnassignRunnerOp(runner_id=runner_id) # Then spin down active runners for _instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items(): + for node_id, runner_id in 
instance.shard_assignments.node_to_runner.items(): if node_id != self.node_id: continue # We spin down a runner if it's meant to be inactive and it's Loaded. if runner_id in self.assigned_runners and \ isinstance(self.assigned_runners[runner_id].status, LoadedRunnerStatus) and \ - instance.instance_type == TypeOfInstance.INACTIVE: + instance.instance_type == InstanceStatus.INACTIVE: return RunnerDownOp(runner_id=runner_id) # If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down # TODO: We need to limit number of retries if we keep failing. for _instance_id, instance in state.instances.items(): - if self.node_id in instance.instance_params.shard_assignments.node_to_runner: + if self.node_id in instance.shard_assignments.node_to_runner: other_node_in_instance_has_failed = False - for runner_id in instance.instance_params.shard_assignments.runner_to_shard: + for runner_id in instance.shard_assignments.runner_to_shard: if runner_id in state.runners and \ isinstance(state.runners[runner_id], FailedRunnerStatus) and \ runner_id not in self.assigned_runners: @@ -353,28 +353,28 @@ class Worker: if other_node_in_instance_has_failed: # Spin down *our* runner - return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id]) + return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) # If we are failed - and *all of the other nodes have spun down* - then we can spin down too. 
for _instance_id, instance in state.instances.items(): - if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \ - instance.instance_params.shard_assignments.node_to_runner[self.node_id] in state.runners and \ - isinstance(state.runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus): + if self.node_id in instance.shard_assignments.node_to_runner and \ + instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \ + isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus): num_spundown_nodes = 0 - for runner_id in instance.instance_params.shard_assignments.runner_to_shard: + for runner_id in instance.shard_assignments.runner_to_shard: if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \ runner_id not in self.assigned_runners: num_spundown_nodes += 1 - if num_spundown_nodes == next(iter(instance.instance_params.shard_assignments.runner_to_shard.values())).world_size - 1: + if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1: # All the other nodes are spun down - so now we can spin down too. # This also catches the case of 1-node. 
If there's one node in the instance then we should spin down straight away - return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id]) + return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) # Then assign runners we do want for instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): if node_id != self.node_id: continue @@ -382,15 +382,15 @@ class Worker: return AssignRunnerOp( runner_id=runner_id, instance_id=instance_id, - shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id], - hosts=instance.instance_params.hosts + shard_metadata=instance.shard_assignments.runner_to_shard[runner_id], + hosts=instance.hosts ) # Then make sure things are downloading. for instance_id, instance in state.instances.items(): # We should already have asserted that this runner exists # If it didn't exist then we return a assign_runner op. 
- for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): if node_id != self.node_id: continue assert runner_id in self.assigned_runners @@ -404,29 +404,29 @@ class Worker: return DownloadOp( runner_id=runner_id, instance_id=instance_id, - shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id], - hosts=instance.instance_params.hosts + shard_metadata=instance.shard_assignments.runner_to_shard[runner_id], + hosts=instance.hosts ) # Then spin up 'ready' runners that should be active for _instance_id, instance in state.instances.items(): - if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \ - self.assigned_runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]].runner is None and \ - instance.instance_type == TypeOfInstance.ACTIVE: + if self.node_id in instance.shard_assignments.node_to_runner and \ + self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].runner is None and \ + instance.instance_type == InstanceStatus.ACTIVE: # We are part of this instance, we want it up but it hasn't been spun up yet. # Need to assert all other runners are ready before we can spin up. ready_to_spin = True - for runner_id in instance.instance_params.shard_assignments.node_to_runner.values(): + for runner_id in instance.shard_assignments.node_to_runner.values(): if state.runners[runner_id].runner_status != RunnerStatusType.Ready: ready_to_spin = False if ready_to_spin: - return RunnerUpOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id]) + return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) # Then make sure things are running based on tasks. 
for instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): if node_id != self.node_id: continue assert runner_id in self.assigned_runners @@ -443,7 +443,7 @@ class Worker: # so let's check that all the other runners are running - ready for us to fire the prompt. running_runner_count = 0 for other_runner_id, other_runner_status in state.runners.items(): - if other_runner_id in instance.instance_params.shard_assignments.node_to_runner.values() and \ + if other_runner_id in instance.shard_assignments.node_to_runner.values() and \ isinstance(other_runner_status, RunningRunnerStatus): running_runner_count += 1 diff --git a/worker/tests/conftest.py b/worker/tests/conftest.py index de79fd87..70f230b2 100644 --- a/worker/tests/conftest.py +++ b/worker/tests/conftest.py @@ -19,7 +19,7 @@ from shared.types.tasks import ( TaskType, ) from shared.types.worker.common import InstanceId, NodeStatus -from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceStatus from shared.types.worker.mlx import Host from shared.types.worker.ops import ( AssignRunnerOp, @@ -140,15 +140,11 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h node_to_runner={node_id: runner_id} ) - instance_params = InstanceParams( - shard_assignments=shard_assignments, - hosts=hosts_one - ) - return Instance( instance_id=InstanceId(), - instance_params=instance_params, - instance_type=TypeOfInstance.ACTIVE + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts_one ) return _instance @@ -166,13 +162,13 @@ async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId instance_obj: Instance = instance(worker.node_id, RunnerId()) # Extract runner_id from shard assignments - 
runner_id = next(iter(instance_obj.instance_params.shard_assignments.runner_to_shard)) + runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard)) # Assign the runner assign_op = AssignRunnerOp( runner_id=runner_id, - shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.instance_params.hosts, + shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], + hosts=instance_obj.hosts, instance_id=instance_obj.instance_id, ) diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py index 593ee920..eb791f2d 100644 --- a/worker/tests/test_worker_handlers.py +++ b/worker/tests/test_worker_handlers.py @@ -46,8 +46,8 @@ async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], assign_op = AssignRunnerOp( runner_id=runner_id, - shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.instance_params.hosts, + shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], + hosts=instance_obj.hosts, instance_id=instance_obj.instance_id, ) @@ -138,8 +138,8 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId, download_op = DownloadOp( instance_id=instance_obj.instance_id, runner_id=runner_id, - shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.instance_params.hosts, + shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], + hosts=instance_obj.hosts, ) events: list[Event] = [] diff --git a/worker/tests/test_worker_integration.py b/worker/tests/test_worker_integration.py index fa9b49b4..f83b1013 100644 --- a/worker/tests/test_worker_integration.py +++ b/worker/tests/test_worker_integration.py @@ -15,7 +15,7 @@ from shared.types.events.chunks import TokenChunk from shared.types.models import ModelId from shared.types.tasks import Task, TaskId from shared.types.worker.common 
import InstanceId, RunnerId -from shared.types.worker.instances import Instance, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceStatus from shared.types.worker.runners import ( LoadedRunnerStatus, ReadyRunnerStatus, @@ -50,14 +50,12 @@ async def test_runner_assigned( print(worker) instance_value: Instance = instance(NODE_A, RUNNER_1_ID) - instance_value.instance_type = TypeOfInstance.INACTIVE + instance_value.instance_type = InstanceStatus.INACTIVE await global_events.append_events( [ InstanceCreated( - instance_id=instance_value.instance_id, - instance_params=instance_value.instance_params, - instance_type=instance_value.instance_type + instance=instance_value ) ], origin=MASTER_NODE_ID @@ -87,14 +85,12 @@ async def test_runner_assigned_active( worker, global_events = await worker_running(NODE_A) instance_value: Instance = instance(NODE_A, RUNNER_1_ID) - instance_value.instance_type = TypeOfInstance.ACTIVE + instance_value.instance_type = InstanceStatus.ACTIVE await global_events.append_events( [ InstanceCreated( - instance_id=instance_value.instance_id, - instance_params=instance_value.instance_params, - instance_type=instance_value.instance_type + instance=instance_value ) ], origin=MASTER_NODE_ID @@ -141,9 +137,7 @@ async def test_runner_assigned_wrong_node( await global_events.append_events( [ InstanceCreated( - instance_id=instance_value.instance_id, - instance_params=instance_value.instance_params, - instance_type=instance_value.instance_type + instance=instance_value ) ], origin=MASTER_NODE_ID @@ -168,14 +162,12 @@ async def test_runner_unassigns( worker, global_events = await worker_running(NODE_A) instance_value: Instance = instance(NODE_A, RUNNER_1_ID) - instance_value.instance_type = TypeOfInstance.ACTIVE + instance_value.instance_type = InstanceStatus.ACTIVE await global_events.append_events( [ InstanceCreated( - instance_id=instance_value.instance_id, - instance_params=instance_value.instance_params, - 
instance_type=instance_value.instance_type + instance=instance_value ) ], origin=MASTER_NODE_ID diff --git a/worker/tests/test_worker_plan.py b/worker/tests/test_worker_plan.py index 4db3f85d..3da7c8c8 100644 --- a/worker/tests/test_worker_plan.py +++ b/worker/tests/test_worker_plan.py @@ -16,7 +16,7 @@ from shared.types.tasks import ( ) from shared.types.worker.common import NodeStatus from shared.types.worker.downloads import DownloadPending -from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceStatus from shared.types.worker.ops import ( AssignRunnerOp, DownloadOp, @@ -90,9 +90,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.INACTIVE, + instance_type=InstanceStatus.INACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -101,7 +100,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: make_downloading_status(NODE_A)}, @@ -124,9 +122,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.INACTIVE, + instance_type=InstanceStatus.INACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -135,7 +132,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: make_downloading_status(NODE_A)}, @@ -158,9 +154,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.INACTIVE, + 
instance_type=InstanceStatus.INACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -169,7 +164,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: ReadyRunnerStatus()}, @@ -184,9 +178,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, # Either active or inactive should yield the same. + instance_type=InstanceStatus.ACTIVE, # Either active or inactive should yield the same. instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -195,7 +188,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: AssignedRunnerStatus()}, @@ -245,9 +237,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -256,7 +247,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: AssignedRunnerStatus()}, @@ -291,9 +281,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -302,7 +291,6 @@ def _get_test_cases(tmp_path: Path) -> 
list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: ReadyRunnerStatus()}, @@ -337,9 +325,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -349,7 +336,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))}, @@ -382,9 +368,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -394,7 +379,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()}, @@ -418,9 +402,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.INACTIVE, + instance_type=InstanceStatus.INACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -429,7 +412,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus()}, @@ 
-453,9 +435,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.INACTIVE, + instance_type=InstanceStatus.INACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -464,7 +445,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: FailedRunnerStatus()}, @@ -488,9 +468,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -499,7 +478,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus()}, @@ -542,9 +520,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -554,7 +531,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, @@ -587,9 +563,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, 
instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -599,7 +574,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, @@ -644,9 +618,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -656,7 +629,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()}, @@ -701,9 +673,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -713,7 +684,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()}, @@ -737,9 +707,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( 
model_id=MODEL_A_ID, runner_to_shard={ @@ -748,7 +717,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: FailedRunnerStatus()}, @@ -781,9 +749,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( shard_assignments=ShardAssignments( model_id=MODEL_A_ID, runner_to_shard={ @@ -793,7 +760,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), hosts=[] - ), ) }, runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, @@ -825,19 +791,17 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, instances={ INSTANCE_1_ID: Instance( - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, instance_id=INSTANCE_1_ID, - instance_params=InstanceParams( - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] + shard_assignments=ShardAssignments( + model_id=MODEL_A_ID, + runner_to_shard={ + RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), + RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) + }, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} ), + hosts=[] ) }, runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()}, @@ -884,7 +848,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon if len(case.state.instances) == 1: instance_id = 
next(iter(case.state.instances)) - shard_assignments = case.state.instances[instance_id].instance_params.shard_assignments + shard_assignments = case.state.instances[instance_id].shard_assignments shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id] # Only add this runner if it belongs to our node diff --git a/worker/tests/test_worker_plan_utils.py b/worker/tests/test_worker_plan_utils.py index 71b90867..b0c81fad 100644 --- a/worker/tests/test_worker_plan_utils.py +++ b/worker/tests/test_worker_plan_utils.py @@ -11,7 +11,7 @@ from shared.types.state import State from shared.types.tasks import TaskId from shared.types.worker.common import InstanceId, NodeStatus, RunnerId from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData -from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance +from shared.types.worker.instances import Instance, InstanceStatus from shared.types.worker.ops import RunnerOp from shared.types.worker.runners import ( AssignedRunnerStatus, @@ -148,14 +148,11 @@ def create_worker_state( runner_to_shard={runner_id: shard_metadata}, node_to_runner={node_id: runner_id}, ) - instance_params = InstanceParams( - shard_assignments=shard_assignments, - hosts=[], - ) instance = Instance( instance_id=instance_id, - instance_params=instance_params, - instance_type=TypeOfInstance.ACTIVE, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=[], ) instances[instance_id] = instance @@ -198,14 +195,11 @@ def make_instance( runner_to_shard=runner_to_shard, node_to_runner=node_to_runner, ) - instance_params = InstanceParams( + return Instance( + instance_id=instance_id, + instance_type=InstanceStatus.ACTIVE, shard_assignments=shard_assignments, hosts=[], ) - return Instance( - instance_id=instance_id, - instance_params=instance_params, - instance_type=TypeOfInstance.ACTIVE, - ) ### For worker plan tests \ No newline at end of file