Alex Cheema
2025-07-25 13:10:29 +01:00
committed by GitHub
parent 6f8e3419d5
commit a241c92dd1
20 changed files with 324 additions and 359 deletions

View File

@@ -1,16 +1,21 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import List, Sequence, final
from typing import Callable, List, Sequence, final
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.api import (
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceResponse,
CreateInstanceTaskParams,
DeleteInstanceResponse,
StreamingChoiceResponse,
)
from shared.types.common import CommandId
@@ -20,9 +25,14 @@ from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CommandType,
CreateInstanceCommand,
DeleteInstanceCommand,
)
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.tasks import ChatCompletionTaskParams
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance
def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
@@ -45,20 +55,21 @@ def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
@final
class API:
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage) -> None:
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, get_state: Callable[[], State]) -> None:
self._app = FastAPI()
self._setup_routes()
self.command_buffer = command_buffer
self.global_events = global_events
self.get_state = get_state
def _setup_routes(self) -> None:
# self._app.get("/topology/control_plane")(self.get_control_plane_topology)
# self._app.get("/topology/data_plane")(self.get_data_plane_topology)
# self._app.get("/instances/list")(self.list_instances)
# self._app.post("/instances/create")(self.create_instance)
# self._app.get("/instance/{instance_id}/read")(self.get_instance)
# self._app.delete("/instance/{instance_id}/delete")(self.remove_instance)
self._app.post("/instances/create")(self.create_instance)
self._app.get("/instance/{instance_id}")(self.get_instance)
self._app.delete("/instance/{instance_id}")(self.delete_instance)
# self._app.get("/model/{model_id}/metadata")(self.get_model_data)
# self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
self._app.post("/v1/chat/completions")(self.chat_completions)
@@ -80,11 +91,49 @@ class API:
# def list_instances(self):
# return {"message": "Hello, World!"}
# def create_instance(self, model_id: ModelId) -> InstanceId: ...
async def create_instance(self, payload: CreateInstanceTaskParams) -> CreateInstanceResponse:
if payload.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[payload.model_id]
model_meta = model_card.metadata
else:
model_meta = await get_model_meta(payload.model_id)
# def get_instance(self, instance_id: InstanceId) -> Instance: ...
command = CreateInstanceCommand(
command_id=CommandId(),
command_type=CommandType.CREATE_INSTANCE,
model_meta=model_meta,
instance_id=InstanceId(),
)
self.command_buffer.append(command)
# def remove_instance(self, instance_id: InstanceId) -> None: ...
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
instance_id=command.instance_id,
)
def get_instance(self, instance_id: InstanceId) -> Instance:
state = self.get_state()
if instance_id not in state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
return state.instances[instance_id]
def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
if instance_id not in self.get_state().instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = DeleteInstanceCommand(
command_id=CommandId(),
command_type=CommandType.DELETE_INSTANCE,
instance_id=instance_id,
)
self.command_buffer.append(command)
return DeleteInstanceResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
# def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
@@ -140,9 +189,10 @@ class API:
def start_fastapi_server(
command_buffer: List[Command],
global_events: AsyncSQLiteEventStorage,
get_state: Callable[[], State],
host: str = "0.0.0.0",
port: int = 8000,
):
api = API(command_buffer, global_events)
api = API(command_buffer, global_events, get_state)
uvicorn.run(api.app, host=host, port=port)
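
The previously commented-out instance routes are now live endpoints. A minimal client-side sketch, assuming the server runs on localhost:8000 and using a hypothetical model id (requires the requests package):

import requests

BASE = "http://localhost:8000"

# Enqueue a CreateInstanceCommand; the response echoes the generated ids.
created = requests.post(f"{BASE}/instances/create", json={"model_id": "test-model"}).json()
instance_id = created["instance_id"]

# Creation is asynchronous, so this 404s until the master applies InstanceCreated.
resp = requests.get(f"{BASE}/instance/{instance_id}")
print(resp.status_code, resp.json())

# Tearing down likewise just enqueues a DeleteInstanceCommand.
print(requests.delete(f"{BASE}/instance/{instance_id}").json())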

View File

@@ -106,8 +106,8 @@ class ForwarderSupervisor:
self._process = await asyncio.create_subprocess_exec(
str(self._binary_path),
f'{pairs}',
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
stdout=None,
stderr=None,
)
self._logger.info(f"Starting forwarder with forwarding pairs: {pairs}")

View File

@@ -7,92 +7,40 @@ from typing import List
from master.api import start_fastapi_server
from master.election_callback import ElectionCallbacks
from master.forwarder_supervisor import ForwarderSupervisor
from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
from master.placement import get_instance_placements, get_transition_events
from shared.apply import apply
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.common import CommandId, NodeId
from shared.node_id import get_node_id_keypair
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
InstanceCreated,
Event,
NodePerformanceMeasured,
TaskCreated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CreateInstanceCommand,
DeleteInstanceCommand,
)
from shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import (
InstanceParams,
ShardAssignments,
TypeOfInstance,
)
from shared.types.worker.runners import RunnerId
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
from shared.types.worker.instances import Instance
## TODO: Hook this up properly
async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: CommandId):
model_id = "testmodelabc"
for i in range(10):
await asyncio.sleep(0.1)
# Create the event with proper types and consistent IDs
chunk_event = ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
command_id=command_id, # Use the same command_id as the parent command
idx=i,
model=model_id, # Use the same model_id
text=f'text{i}',
token_id=i
)
)
# ChunkGenerated needs to be cast to the expected BaseEvent type
await events_log.append_events(
[chunk_event],
origin=NodeId()
)
await asyncio.sleep(0.1)
# Create the event with proper types and consistent IDs
chunk_event = ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
command_id=command_id, # Use the same command_id as the parent command
idx=11,
model=model_id, # Use the same model_id
text=f'text{11}',
token_id=11,
finish_reason='stop'
)
)
# ChunkGenerated needs to be cast to the expected BaseEvent type
await events_log.append_events(
[chunk_event],
origin=NodeId()
)
def get_node_id() -> NodeId:
return NodeId() # TODO
class Master:
def __init__(self, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
def __init__(self, node_id: NodeId, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
self.node_id = node_id
self.command_buffer = command_buffer
self.global_events = global_events
self.node_id = get_node_id()
self.forwarder_supervisor = ForwarderSupervisor(
forwarder_binary_path=forwarder_binary_path,
logger=logger
@@ -104,6 +52,62 @@ class Master:
# TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
return State()
async def _run_event_loop_body(self) -> None:
if self.forwarder_supervisor.current_role == ForwarderRole.REPLICA:
await asyncio.sleep(0.1)
return
next_events: list[Event] = []
# 1. process commands
if len(self.command_buffer) > 0:
# for now we do one command at a time
next_command = self.command_buffer.pop(0)
self.logger.info(f"got command: {next_command}")
# TODO: validate the command
match next_command:
case ChatCompletionCommand():
matching_instance: Instance | None = None
for instance in self.state.instances.values():
if instance.shard_assignments.model_id == next_command.request_params.model:
matching_instance = instance
break
if not matching_instance:
raise ValueError(f"No instance found for model {next_command.request_params.model}")
task_id = TaskId()
next_events.append(TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=matching_instance.instance_id,
task_status=TaskStatus.PENDING,
task_params=next_command.request_params
)
))
case DeleteInstanceCommand():
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
transition_events = get_transition_events(self.state.instances, placement)
next_events.extend(transition_events)
case CreateInstanceCommand():
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
transition_events = get_transition_events(self.state.instances, placement)
next_events.extend(transition_events)
await self.global_events.append_events(next_events, origin=self.node_id)
# 2. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
if len(events) == 0:
await asyncio.sleep(0.01)
return
# 3. for each event, apply it to the state
for event_from_log in events:
self.state = apply(self.state, event_from_log)
self.logger.info(f"state: {self.state.model_dump_json()}")
async def run(self):
self.state = await self._get_state_snapshot()
@@ -115,90 +119,41 @@ class Master:
await self.election_callbacks.on_became_master()
while True:
next_event = None
# 1. process commands
if len(self.command_buffer) > 0:
# for now we do one command at a time
next_command = self.command_buffer.pop(0)
self.logger.info(f"got command: {next_command}")
# TODO: validate the command
match next_command:
case ChatCompletionCommand():
# 1. find a valid instance for this request, if none exists ERROR (TODO)
instance_id = InstanceId()
task_id = TaskId()
# 2. publish TaskCreated event (TODO)
next_event = TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=instance_id,
task_status=TaskStatus.PENDING,
task_params=next_command.request_params
)
)
case DeleteInstanceCommand():
# TODO
pass
case CreateInstanceCommand():
if next_command.model_meta.model_id not in MODEL_CARDS:
raise ValueError(f"Model {next_command.model_meta.model_id} not supported.")
# TODO: we should also support models that aren't in MODEL_CARDS
# if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
if next_command.model_meta.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[next_command.model_meta.model_id]
model_meta = model_card.metadata
else:
model_meta = await get_model_meta(next_command.model_meta.model_id)
# TODO: how do we actually schedule an instance? TODO: @@@@@@Seth@@@@@@
next_event = InstanceCreated(
instance_id=InstanceId(),
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=next_command.model_meta.model_id,
runner_to_shard={
RunnerId(): PipelineShardMetadata(
model_meta=model_meta,
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=0,
n_layers=0
)
},
node_to_runner={}
),
hosts=[]
),
instance_type=TypeOfInstance.ACTIVE,
)
if next_event is not None:
await self.global_events.append_events([next_event], origin=self.node_id)
# 2. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
if len(events) == 0:
await asyncio.sleep(0.01)
continue
# 3. for each event, apply it to the state
for event_from_log in events:
self.state = apply(self.state, event_from_log)
try:
await self._run_event_loop_body()
except Exception as e:
self.logger.error(f"Error in _run_event_loop_body: {e}")
await asyncio.sleep(0.1)
async def main():
logger = Logger(name='master_logger')
node_id_keypair = get_node_id_keypair()
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
# TODO: this should be the resource monitor that does this
await global_events.append_events([NodePerformanceMeasured(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="testmodelabc",
chip_id="testchipabc",
memory=MemoryPerformanceProfile(
ram_total=1000,
ram_available=1000,
swap_total=1000,
swap_available=1000
),
system=SystemPerformanceProfile(
flops_fp16=1000
)
)
)], origin=node_id)
command_buffer: List[Command] = []
api_thread = threading.Thread(
@@ -206,13 +161,14 @@ async def main():
args=(
command_buffer,
global_events,
lambda: master.state,
),
daemon=True
)
api_thread.start()
logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')
master = Master(command_buffer, global_events, forwarder_binary_path=Path("forwarder"), logger=logger)
master = Master(node_id, command_buffer, global_events, forwarder_binary_path=Path("./build/forwarder"), logger=logger)
await master.run()
if __name__ == "__main__":

View File

@@ -13,15 +13,15 @@ from shared.topology import Topology
from shared.types.events import Event, InstanceCreated, InstanceDeleted
from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceStatus
@singledispatch
def get_instance_placements(
command: CreateInstanceCommand,
topology: Topology,
current_instances: dict[InstanceId, InstanceParams],
) -> dict[InstanceId, InstanceParams]:
current_instances: dict[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances]
if command.model_meta.model_id in available_models:
raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
@@ -36,9 +36,11 @@ def get_instance_placements(
shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)
instance_id = InstanceId()
instance_id = command.instance_id
target_instances = deepcopy(current_instances)
target_instances[instance_id] = InstanceParams(
target_instances[instance_id] = Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=[]
)
@@ -46,7 +48,7 @@ def get_instance_placements(
@get_instance_placements.register
def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]:
def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, Instance]) -> dict[InstanceId, Instance]:
target_instances = deepcopy(current_instances)
if command.instance_id in target_instances:
del target_instances[command.instance_id]
@@ -55,19 +57,17 @@ def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dic
def get_transition_events(
current_instances: Mapping[InstanceId, InstanceParams],
target_instances: Mapping[InstanceId, InstanceParams],
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
) -> Sequence[Event]:
events: list[Event] = []
# find instances to create
for instance_id, instance_params in target_instances.items():
for instance_id, instance in target_instances.items():
if instance_id not in current_instances:
events.append(
InstanceCreated(
instance_id=instance_id,
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE
instance=instance,
)
)
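
get_transition_events is a set difference over instance ids: ids present only in target become create events, and, by the same logic in the part of the hunk not shown (InstanceDeleted is imported above, so this is an inference), ids present only in current become delete events. A toy sketch of the diffing:

def transition_events(current: dict[str, str], target: dict[str, str]) -> list[str]:
    events: list[str] = []
    for iid in target:
        if iid not in current:
            events.append(f"InstanceCreated({iid})")
    for iid in current:
        if iid not in target:
            events.append(f"InstanceDeleted({iid})")
    return events

print(transition_events({"a": "m1"}, {"a": "m1", "b": "m2"}))  # ['InstanceCreated(b)']
print(transition_events({"a": "m1"}, {}))                      # ['InstanceDeleted(a)']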

View File

@@ -11,6 +11,7 @@ from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import NodeId
from shared.types.events import TaskCreated
from shared.types.events.commands import ChatCompletionCommand, Command, CommandId
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
@@ -36,7 +37,8 @@ async def test_master():
forwarder_binary_path = _create_forwarder_dummy_binary()
master = Master(command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
master = Master(node_id, command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
asyncio.create_task(master.run())
command_buffer.append(

View File

@@ -12,7 +12,7 @@ from shared.types.events.commands import CreateInstanceCommand
from shared.types.models import ModelMetadata
from shared.types.topology import Connection, Node
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import InstanceParams
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.runners import ShardAssignments
@@ -21,8 +21,10 @@ def topology() -> Topology:
return Topology()
@pytest.fixture
def instance_params() -> InstanceParams:
return InstanceParams(
def instance() -> Instance:
return Instance(
instance_id=InstanceId(),
instance_type=InstanceStatus.ACTIVE,
shard_assignments=ShardAssignments(
model_id="test-model",
runner_to_shard={},
@@ -43,7 +45,8 @@ def model_meta() -> ModelMetadata:
def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
return CreateInstanceCommand(
command_id=CommandId(),
model_meta=model_meta
model_meta=model_meta,
instance_id=InstanceId(),
)
@@ -66,7 +69,8 @@ def test_get_instance_placements_create_instance(
create_instance_command = CreateInstanceCommand(
command_id=CommandId(),
model_meta=model_meta
model_meta=model_meta,
instance_id=InstanceId(),
)
node_id_a = NodeId()
node_id_b = NodeId()
@@ -84,16 +88,16 @@ def test_get_instance_placements_create_instance(
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance_params = placements[instance_id]
assert instance_params.shard_assignments.model_id == model_meta.model_id
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b]
runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c]
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
runner_id_c = instance.shard_assignments.node_to_runner[node_id_c]
shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a]
shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b]
shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c]
shard_a = instance.shard_assignments.runner_to_shard[runner_id_a]
shard_b = instance.shard_assignments.runner_to_shard[runner_id_b]
shard_c = instance.shard_assignments.runner_to_shard[runner_id_c]
assert shard_a.end_layer - shard_a.start_layer == expected_layers[0]
assert shard_b.end_layer - shard_b.start_layer == expected_layers[1]
@@ -105,14 +109,14 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams):
def test_get_transition_events_no_change(topology: Topology, instance: Instance):
# arrange
instance_id = InstanceId()
current_instances = {
instance_id: instance_params
instance_id: instance
}
target_instances = {
instance_id: instance_params
instance_id: instance
}
# act
@@ -122,12 +126,12 @@ def test_get_transition_events_no_change(topology: Topology, instance_params: In
assert len(events) == 0
def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams):
def test_get_transition_events_create_instance(topology: Topology, instance: Instance):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, InstanceParams] = {}
target_instances: dict[InstanceId, InstanceParams] = {
instance_id: instance_params
current_instances: dict[InstanceId, Instance] = {}
target_instances: dict[InstanceId, Instance] = {
instance_id: instance
}
# act
@@ -138,13 +142,13 @@ def test_get_transition_events_create_instance(topology: Topology, instance_para
assert events[0].event_type == _EventType.InstanceCreated
def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams):
def test_get_transition_events_delete_instance(topology: Topology, instance: Instance):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, InstanceParams] = {
instance_id: instance_params
current_instances: dict[InstanceId, Instance] = {
instance_id: instance
}
target_instances: dict[InstanceId, InstanceParams] = {}
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances)

View File

@@ -281,7 +281,7 @@ func (c *sqliteConnector) getLatestRowIds() (map[SourceKey]int64, error) {
}
selectCols := strings.Join(keyCols, ", ")
query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" GROUP BY %s`, selectCols, rowIDCol, c.tableName, selectCols)
query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" WHERE %s IS NOT NULL GROUP BY %s`, selectCols, rowIDCol, c.tableName, rowIDCol, selectCols)
rows, err := c.db.Query(query)
if err != nil {
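
The added WHERE guard matters because MAX() over a group whose row-id values are all NULL returns NULL, which would put a bogus entry in the latest-row-id map. A quick sqlite3 illustration with a toy schema (not the real table layout):

import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE events (source TEXT, row_id INTEGER)")
db.executemany("INSERT INTO events VALUES (?, ?)",
               [("a", 1), ("a", 2), ("b", None)])  # "b" has no usable row id yet

# Without the guard, source "b" comes back with a NULL max.
print(db.execute("SELECT source, MAX(row_id) FROM events GROUP BY source").fetchall())
# typically [('a', 2), ('b', None)]

# With the guard, NULL-only groups drop out entirely.
print(db.execute(
    "SELECT source, MAX(row_id) FROM events WHERE row_id IS NOT NULL GROUP BY source"
).fetchall())
# [('a', 2)]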

View File

@@ -28,7 +28,7 @@ from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import NodeStatus, RunnerId
from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceId, InstanceStatus
from shared.types.worker.runners import RunnerStatus
S = TypeVar("S", bound=State)
@@ -62,8 +62,8 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
@event_apply.register(InstanceCreated)
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type)
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance}
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {**state.instances, instance.instance_id: instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceActivated)
@@ -71,8 +71,8 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE})
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
updated_instance = state.instances[event.instance_id].model_copy(update={"type": InstanceStatus.ACTIVE})
new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceDeactivated)
@@ -80,13 +80,13 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.INACTIVE})
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
updated_instance = state.instances[event.instance_id].model_copy(update={"type": InstanceStatus.INACTIVE})
new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceDeleted)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
new_instances: Mapping[InstanceId, Instance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceReplacedAtomically)
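
Every handler here copies the state rather than mutating it, so earlier snapshots stay valid while events replay. A minimal pydantic-v2 sketch of that copy-on-write shape, with toy models standing in for the real State and Instance:

from pydantic import BaseModel

class ToyInstance(BaseModel):
    instance_id: str
    instance_type: str

class ToyState(BaseModel):
    instances: dict[str, ToyInstance] = {}

def apply_created(state: ToyState, inst: ToyInstance) -> ToyState:
    # build a fresh mapping; the old state object is left untouched
    return state.model_copy(update={"instances": {**state.instances, inst.instance_id: inst}})

s0 = ToyState()
s1 = apply_created(s0, ToyInstance(instance_id="i1", instance_type="ACTIVE"))
print(len(s0.instances), len(s1.instances))  # 0 1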

View File

@@ -3,6 +3,9 @@ from typing import Any, Literal
from pydantic import BaseModel
from shared.openai_compat import FinishReason
from shared.types.common import CommandId
from shared.types.models import ModelMetadata
from shared.types.worker.instances import InstanceId
class ChatCompletionMessage(BaseModel):
@@ -97,8 +100,20 @@ class ChatCompletionTaskParams(BaseModel):
parallel_tool_calls: bool | None = None
user: str | None = None
class RequestInstanceTaskParams(BaseModel):
class CreateInstanceTaskParams(BaseModel):
# TODO: in future the user could specify a specific Instance, not just a model_id
model_id: str
class DeleteInstanceTaskParams(BaseModel):
instance_id: str
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
instance_id: InstanceId
class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId

View File

@@ -18,7 +18,7 @@ from shared.types.common import NodeId
from shared.types.events.chunks import CommandId, GenerationChunk
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.instances import Instance
from shared.types.worker.runners import RunnerId, RunnerStatus
if TYPE_CHECKING:
@@ -114,9 +114,7 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]):
class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]):
event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated
instance_id: InstanceId
instance_params: InstanceParams
instance_type: TypeOfInstance
instance: Instance
class InstanceActivated(_BaseEvent[_EventType.InstanceActivated]):

View File

@@ -7,7 +7,8 @@ from shared.types.api import ChatCompletionTaskParams
from shared.types.common import CommandId
from shared.types.events import Event
from shared.types.models import ModelMetadata
from shared.types.state import InstanceId, State
from shared.types.state import State
from shared.types.worker.common import InstanceId
# TODO: We need to have a distinction between create instance and spin up instance.
@@ -30,6 +31,7 @@ class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]):
class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]):
command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE
model_meta: ModelMetadata
instance_id: InstanceId
class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):

View File

@@ -2,8 +2,8 @@ from pydantic import BaseModel
from shared.types.api import (
ChatCompletionTaskParams,
CreateInstanceTaskParams,
DeleteInstanceTaskParams,
RequestInstanceTaskParams,
)
from shared.types.events import CommandId
@@ -12,12 +12,12 @@ class ChatCompletionCommand(BaseModel):
command_id: CommandId
command_params: ChatCompletionTaskParams
class RequestInstanceCommand(BaseModel):
class CreateInstanceCommand(BaseModel):
command_id: CommandId
command_params: RequestInstanceTaskParams
command_params: CreateInstanceTaskParams
class DeleteInstanceCommand(BaseModel):
command_id: CommandId
command_params: DeleteInstanceTaskParams
type Command = ChatCompletionCommand | RequestInstanceCommand | DeleteInstanceCommand
type Command = ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand

View File

@@ -7,14 +7,14 @@ from shared.types.common import NodeId
from shared.types.profiling import NodePerformanceProfile
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import BaseInstance
from shared.types.worker.instances import Instance
from shared.types.worker.runners import RunnerId, RunnerStatus
class State(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
node_status: Mapping[NodeId, NodeStatus] = {}
instances: Mapping[InstanceId, BaseInstance] = {}
instances: Mapping[InstanceId, Instance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}

View File

@@ -9,20 +9,12 @@ from shared.types.worker.runners import (
)
class TypeOfInstance(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"
class InstanceStatus(str, Enum):
ACTIVE = "ACTIVE"
INACTIVE = "INACTIVE"
class InstanceParams(BaseModel):
class Instance(BaseModel):
instance_id: InstanceId
instance_type: InstanceStatus
shard_assignments: ShardAssignments
hosts: list[Host]
class BaseInstance(BaseModel):
instance_params: InstanceParams
instance_type: TypeOfInstance
class Instance(BaseInstance):
instance_id: InstanceId
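
The old two-level shape (BaseInstance wrapping InstanceParams) collapses into one flat Instance, which is why every call site in this commit drops the .instance_params hop. A toy sketch of the new shape, with stand-in field types for ShardAssignments and Host:

from enum import Enum
from pydantic import BaseModel

class InstanceStatus(str, Enum):
    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"

class Instance(BaseModel):        # the new flat shape
    instance_id: str              # real code uses InstanceId
    instance_type: InstanceStatus
    shard_assignments: dict       # stand-in for ShardAssignments
    hosts: list                   # stand-in for list[Host]

inst = Instance(instance_id="i1", instance_type=InstanceStatus.ACTIVE,
                shard_assignments={}, hosts=[])
# old: instance.instance_params.shard_assignments -> new: instance.shard_assignments
print(inst.shard_assignments)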

View File

@@ -28,7 +28,7 @@ from shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgressData,
)
from shared.types.worker.instances import TypeOfInstance
from shared.types.worker.instances import InstanceStatus
from shared.types.worker.mlx import Host
from shared.types.worker.ops import (
AssignRunnerOp,
@@ -323,29 +323,29 @@ class Worker:
runner_ids: list[RunnerId] = [
runner_id
for instance in state.instances.values()
for runner_id in instance.instance_params.shard_assignments.runner_to_shard
for runner_id in instance.shard_assignments.runner_to_shard
]
if runner_id not in runner_ids:
return UnassignRunnerOp(runner_id=runner_id)
# Then spin down active runners
for _instance_id, instance in state.instances.items():
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
# We spin down a runner if it's meant to be inactive and it's Loaded.
if runner_id in self.assigned_runners and \
isinstance(self.assigned_runners[runner_id].status, LoadedRunnerStatus) and \
instance.instance_type == TypeOfInstance.INACTIVE:
instance.instance_type == InstanceStatus.INACTIVE:
return RunnerDownOp(runner_id=runner_id)
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
# TODO: We need to limit number of retries if we keep failing.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.instance_params.shard_assignments.node_to_runner:
if self.node_id in instance.shard_assignments.node_to_runner:
other_node_in_instance_has_failed = False
for runner_id in instance.instance_params.shard_assignments.runner_to_shard:
for runner_id in instance.shard_assignments.runner_to_shard:
if runner_id in state.runners and \
isinstance(state.runners[runner_id], FailedRunnerStatus) and \
runner_id not in self.assigned_runners:
@@ -353,28 +353,28 @@ class Worker:
if other_node_in_instance_has_failed:
# Spin down *our* runner
return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# If we are failed - and *all of the other nodes have spun down* - then we can spin down too.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \
instance.instance_params.shard_assignments.node_to_runner[self.node_id] in state.runners and \
isinstance(state.runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
if self.node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
num_spundown_nodes = 0
for runner_id in instance.instance_params.shard_assignments.runner_to_shard:
for runner_id in instance.shard_assignments.runner_to_shard:
if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
runner_id not in self.assigned_runners:
num_spundown_nodes += 1
if num_spundown_nodes == next(iter(instance.instance_params.shard_assignments.runner_to_shard.values())).world_size - 1:
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
# All the other nodes are spun down - so now we can spin down too.
# This also catches the case of 1-node. If there's one node in the instance then we should spin down straight away
return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# Then assign runners we do want
for instance_id, instance in state.instances.items():
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
@@ -382,15 +382,15 @@ class Worker:
return AssignRunnerOp(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id],
hosts=instance.instance_params.hosts
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
hosts=instance.hosts
)
# Then make sure things are downloading.
for instance_id, instance in state.instances.items():
# We should already have asserted that this runner exists
# If it didn't exist then we return an assign_runner op.
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
assert runner_id in self.assigned_runners
@@ -404,29 +404,29 @@ class Worker:
return DownloadOp(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id],
hosts=instance.instance_params.hosts
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
hosts=instance.hosts
)
# Then spin up 'ready' runners that should be active
for _instance_id, instance in state.instances.items():
if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \
self.assigned_runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]].runner is None and \
instance.instance_type == TypeOfInstance.ACTIVE:
if self.node_id in instance.shard_assignments.node_to_runner and \
self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].runner is None and \
instance.instance_type == InstanceStatus.ACTIVE:
# We are part of this instance, we want it up but it hasn't been spun up yet.
# Need to assert all other runners are ready before we can spin up.
ready_to_spin = True
for runner_id in instance.instance_params.shard_assignments.node_to_runner.values():
for runner_id in instance.shard_assignments.node_to_runner.values():
if state.runners[runner_id].runner_status != RunnerStatusType.Ready:
ready_to_spin = False
if ready_to_spin:
return RunnerUpOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# Then make sure things are running based on tasks.
for instance_id, instance in state.instances.items():
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
assert runner_id in self.assigned_runners
@@ -443,7 +443,7 @@ class Worker:
# so let's check that all the other runners are running - ready for us to fire the prompt.
running_runner_count = 0
for other_runner_id, other_runner_status in state.runners.items():
if other_runner_id in instance.instance_params.shard_assignments.node_to_runner.values() and \
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
isinstance(other_runner_status, RunningRunnerStatus):
running_runner_count += 1
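
The plan logic above is an ordered chain of guards: the first condition that fires returns an op and ends the pass, and falling out the bottom means steady state. A compressed sketch of that first-match structure with toy ops:

from dataclasses import dataclass

@dataclass
class Op:
    kind: str
    runner_id: str

def plan(assigned: set[str], wanted: set[str], failed: set[str]) -> Op | None:
    # priority order mirrors the worker: unassign strays, spin down around
    # failures, then assign whatever we are missing
    for rid in assigned - wanted:
        return Op("unassign", rid)
    for rid in assigned & failed:
        return Op("runner_down", rid)
    for rid in wanted - assigned:
        return Op("assign", rid)
    return None  # steady state: nothing to do

print(plan({"r1", "r2"}, {"r1"}, set()))  # Op(kind='unassign', runner_id='r2')
print(plan({"r1"}, {"r1"}, set()))        # None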

View File

@@ -19,7 +19,7 @@ from shared.types.tasks import (
TaskType,
)
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.mlx import Host
from shared.types.worker.ops import (
AssignRunnerOp,
@@ -140,15 +140,11 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h
node_to_runner={node_id: runner_id}
)
instance_params = InstanceParams(
shard_assignments=shard_assignments,
hosts=hosts_one
)
return Instance(
instance_id=InstanceId(),
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts_one
)
return _instance
@@ -166,13 +162,13 @@ async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId
instance_obj: Instance = instance(worker.node_id, RunnerId())
# Extract runner_id from shard assignments
runner_id = next(iter(instance_obj.instance_params.shard_assignments.runner_to_shard))
runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))
# Assign the runner
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.instance_params.hosts,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)

View File

@@ -46,8 +46,8 @@ async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId],
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.instance_params.hosts,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
@@ -138,8 +138,8 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
download_op = DownloadOp(
instance_id=instance_obj.instance_id,
runner_id=runner_id,
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.instance_params.hosts,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
)
events: list[Event] = []

View File

@@ -15,7 +15,7 @@ from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.runners import (
LoadedRunnerStatus,
ReadyRunnerStatus,
@@ -50,14 +50,12 @@ async def test_runner_assigned(
print(worker)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.INACTIVE
instance_value.instance_type = InstanceStatus.INACTIVE
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
instance=instance_value
)
],
origin=MASTER_NODE_ID
@@ -87,14 +85,12 @@ async def test_runner_assigned_active(
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.ACTIVE
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
instance=instance_value
)
],
origin=MASTER_NODE_ID
@@ -141,9 +137,7 @@ async def test_runner_assigned_wrong_node(
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
instance=instance_value
)
],
origin=MASTER_NODE_ID
@@ -168,14 +162,12 @@ async def test_runner_unassigns(
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value.instance_type = TypeOfInstance.ACTIVE
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance_id=instance_value.instance_id,
instance_params=instance_value.instance_params,
instance_type=instance_value.instance_type
instance=instance_value
)
],
origin=MASTER_NODE_ID

View File

@@ -16,7 +16,7 @@ from shared.types.tasks import (
)
from shared.types.worker.common import NodeStatus
from shared.types.worker.downloads import DownloadPending
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
@@ -90,9 +90,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -101,7 +100,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
@@ -124,9 +122,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -135,7 +132,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
@@ -158,9 +154,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -169,7 +164,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus()},
@@ -184,9 +178,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE, # Either active or inactive should yield the same.
instance_type=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -195,7 +188,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: AssignedRunnerStatus()},
@@ -245,9 +237,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -256,7 +247,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: AssignedRunnerStatus()},
@@ -291,9 +281,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -302,7 +291,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus()},
@@ -337,9 +325,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -349,7 +336,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))},
@@ -382,9 +368,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -394,7 +379,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
@@ -418,9 +402,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -429,7 +412,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus()},
@@ -453,9 +435,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.INACTIVE,
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -464,7 +445,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: FailedRunnerStatus()},
@@ -488,9 +468,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -499,7 +478,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus()},
@@ -542,9 +520,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -554,7 +531,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
@@ -587,9 +563,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -599,7 +574,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
@@ -644,9 +618,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -656,7 +629,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()},
@@ -701,9 +673,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -713,7 +684,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()},
@@ -737,9 +707,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -748,7 +717,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: FailedRunnerStatus()},
@@ -781,9 +749,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
@@ -793,7 +760,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
),
)
},
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
@@ -825,19 +791,17 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
@@ -884,7 +848,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
if len(case.state.instances) == 1:
instance_id = next(iter(case.state.instances))
shard_assignments = case.state.instances[instance_id].instance_params.shard_assignments
shard_assignments = case.state.instances[instance_id].shard_assignments
shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
# Only add this runner if it belongs to our node

View File

@@ -11,7 +11,7 @@ from shared.types.state import State
from shared.types.tasks import TaskId
from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import RunnerOp
from shared.types.worker.runners import (
AssignedRunnerStatus,
@@ -148,14 +148,11 @@ def create_worker_state(
runner_to_shard={runner_id: shard_metadata},
node_to_runner={node_id: runner_id},
)
instance_params = InstanceParams(
shard_assignments=shard_assignments,
hosts=[],
)
instance = Instance(
instance_id=instance_id,
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=[],
)
instances[instance_id] = instance
@@ -198,14 +195,11 @@ def make_instance(
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
instance_params = InstanceParams(
return Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=[],
)
return Instance(
instance_id=instance_id,
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE,
)
### For worker plan tests