mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 14:17:58 -05:00
Glue
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import List, Sequence, final
|
||||
from typing import Callable, List, Sequence, final
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.models.model_cards import MODEL_CARDS
|
||||
from shared.models.model_meta import get_model_meta
|
||||
from shared.types.api import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceResponse,
|
||||
CreateInstanceTaskParams,
|
||||
DeleteInstanceResponse,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from shared.types.common import CommandId
|
||||
@@ -20,9 +25,14 @@ from shared.types.events.commands import (
|
||||
ChatCompletionCommand,
|
||||
Command,
|
||||
CommandType,
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
)
|
||||
from shared.types.events.components import EventFromEventLog
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import ChatCompletionTaskParams
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import Instance
|
||||
|
||||
|
||||
def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
|
||||
@@ -45,20 +55,21 @@ def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
|
||||
|
||||
@final
|
||||
class API:
|
||||
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage) -> None:
|
||||
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, get_state: Callable[[], State]) -> None:
|
||||
self._app = FastAPI()
|
||||
self._setup_routes()
|
||||
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.get_state = get_state
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
# self._app.get("/topology/control_plane")(self.get_control_plane_topology)
|
||||
# self._app.get("/topology/data_plane")(self.get_data_plane_topology)
|
||||
# self._app.get("/instances/list")(self.list_instances)
|
||||
# self._app.post("/instances/create")(self.create_instance)
|
||||
# self._app.get("/instance/{instance_id}/read")(self.get_instance)
|
||||
# self._app.delete("/instance/{instance_id}/delete")(self.remove_instance)
|
||||
self._app.post("/instances/create")(self.create_instance)
|
||||
self._app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self._app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
# self._app.get("/model/{model_id}/metadata")(self.get_model_data)
|
||||
# self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
|
||||
self._app.post("/v1/chat/completions")(self.chat_completions)
|
||||
@@ -80,11 +91,49 @@ class API:
|
||||
# def list_instances(self):
|
||||
# return {"message": "Hello, World!"}
|
||||
|
||||
# def create_instance(self, model_id: ModelId) -> InstanceId: ...
|
||||
async def create_instance(self, payload: CreateInstanceTaskParams) -> CreateInstanceResponse:
|
||||
if payload.model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[payload.model_id]
|
||||
model_meta = model_card.metadata
|
||||
else:
|
||||
model_meta = await get_model_meta(payload.model_id)
|
||||
|
||||
# def get_instance(self, instance_id: InstanceId) -> Instance: ...
|
||||
command = CreateInstanceCommand(
|
||||
command_id=CommandId(),
|
||||
command_type=CommandType.CREATE_INSTANCE,
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
self.command_buffer.append(command)
|
||||
|
||||
# def remove_instance(self, instance_id: InstanceId) -> None: ...
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=model_meta,
|
||||
instance_id=command.instance_id,
|
||||
)
|
||||
|
||||
def get_instance(self, instance_id: InstanceId) -> Instance:
|
||||
state = self.get_state()
|
||||
if instance_id not in state.instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
return state.instances[instance_id]
|
||||
|
||||
def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
|
||||
if instance_id not in self.get_state().instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
|
||||
command = DeleteInstanceCommand(
|
||||
command_id=CommandId(),
|
||||
command_type=CommandType.DELETE_INSTANCE,
|
||||
instance_id=instance_id,
|
||||
)
|
||||
self.command_buffer.append(command)
|
||||
return DeleteInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
# def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
|
||||
|
||||
@@ -140,9 +189,10 @@ class API:
|
||||
def start_fastapi_server(
|
||||
command_buffer: List[Command],
|
||||
global_events: AsyncSQLiteEventStorage,
|
||||
get_state: Callable[[], State],
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
):
|
||||
api = API(command_buffer, global_events)
|
||||
api = API(command_buffer, global_events, get_state)
|
||||
|
||||
uvicorn.run(api.app, host=host, port=port)
|
||||
@@ -106,8 +106,8 @@ class ForwarderSupervisor:
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
str(self._binary_path),
|
||||
f'{pairs}',
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
stdout=None,
|
||||
stderr=None,
|
||||
)
|
||||
|
||||
self._logger.info(f"Starting forwarder with forwarding pairs: {pairs}")
|
||||
|
||||
238
master/main.py
238
master/main.py
@@ -7,92 +7,40 @@ from typing import List
|
||||
|
||||
from master.api import start_fastapi_server
|
||||
from master.election_callback import ElectionCallbacks
|
||||
from master.forwarder_supervisor import ForwarderSupervisor
|
||||
from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
|
||||
from master.placement import get_instance_placements, get_transition_events
|
||||
from shared.apply import apply
|
||||
from shared.db.sqlite.config import EventLogConfig
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from shared.models.model_cards import MODEL_CARDS
|
||||
from shared.models.model_meta import get_model_meta
|
||||
from shared.types.common import CommandId, NodeId
|
||||
from shared.node_id import get_node_id_keypair
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import (
|
||||
ChunkGenerated,
|
||||
InstanceCreated,
|
||||
Event,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
)
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.events.commands import (
|
||||
ChatCompletionCommand,
|
||||
Command,
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
)
|
||||
from shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import (
|
||||
InstanceParams,
|
||||
ShardAssignments,
|
||||
TypeOfInstance,
|
||||
)
|
||||
from shared.types.worker.runners import RunnerId
|
||||
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
|
||||
from shared.types.worker.instances import Instance
|
||||
|
||||
|
||||
## TODO: Hook this up properly
|
||||
async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: CommandId):
|
||||
model_id = "testmodelabc"
|
||||
|
||||
for i in range(10):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Create the event with proper types and consistent IDs
|
||||
chunk_event = ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
command_id=command_id, # Use the same task_id
|
||||
idx=i,
|
||||
model=model_id, # Use the same model_id
|
||||
text=f'text{i}',
|
||||
token_id=i
|
||||
)
|
||||
)
|
||||
|
||||
# ChunkGenerated needs to be cast to the expected BaseEvent type
|
||||
await events_log.append_events(
|
||||
[chunk_event],
|
||||
origin=NodeId()
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Create the event with proper types and consistent IDs
|
||||
chunk_event = ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
command_id=command_id, # Use the same task_id
|
||||
idx=11,
|
||||
model=model_id, # Use the same model_id
|
||||
text=f'text{11}',
|
||||
token_id=11,
|
||||
finish_reason='stop'
|
||||
)
|
||||
)
|
||||
|
||||
# ChunkGenerated needs to be cast to the expected BaseEvent type
|
||||
await events_log.append_events(
|
||||
[chunk_event],
|
||||
origin=NodeId()
|
||||
)
|
||||
|
||||
def get_node_id() -> NodeId:
|
||||
return NodeId() # TODO
|
||||
|
||||
class Master:
|
||||
def __init__(self, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
|
||||
def __init__(self, node_id: NodeId, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
|
||||
self.node_id = node_id
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.node_id = get_node_id()
|
||||
self.forwarder_supervisor = ForwarderSupervisor(
|
||||
forwarder_binary_path=forwarder_binary_path,
|
||||
logger=logger
|
||||
@@ -104,6 +52,62 @@ class Master:
|
||||
# TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
|
||||
return State()
|
||||
|
||||
async def _run_event_loop_body(self) -> None:
|
||||
if self.forwarder_supervisor.current_role == ForwarderRole.REPLICA:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
next_events: list[Event] = []
|
||||
# 1. process commands
|
||||
if len(self.command_buffer) > 0:
|
||||
# for now we do one command at a time
|
||||
next_command = self.command_buffer.pop(0)
|
||||
self.logger.info(f"got command: {next_command}")
|
||||
# TODO: validate the command
|
||||
match next_command:
|
||||
case ChatCompletionCommand():
|
||||
matching_instance: Instance | None = None
|
||||
for instance in self.state.instances.values():
|
||||
if instance.shard_assignments.model_id == next_command.request_params.model:
|
||||
matching_instance = instance
|
||||
break
|
||||
if not matching_instance:
|
||||
raise ValueError(f"No instance found for model {next_command.request_params.model}")
|
||||
|
||||
task_id = TaskId()
|
||||
next_events.append(TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task_id=task_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
instance_id=matching_instance.instance_id,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=next_command.request_params
|
||||
)
|
||||
))
|
||||
case DeleteInstanceCommand():
|
||||
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
|
||||
transition_events = get_transition_events(self.state.instances, placement)
|
||||
next_events.extend(transition_events)
|
||||
case CreateInstanceCommand():
|
||||
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
|
||||
transition_events = get_transition_events(self.state.instances, placement)
|
||||
next_events.extend(transition_events)
|
||||
|
||||
await self.global_events.append_events(next_events, origin=self.node_id)
|
||||
|
||||
# 2. get latest events
|
||||
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
|
||||
if len(events) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
return
|
||||
|
||||
# 3. for each event, apply it to the state
|
||||
for event_from_log in events:
|
||||
self.state = apply(self.state, event_from_log)
|
||||
|
||||
self.logger.info(f"state: {self.state.model_dump_json()}")
|
||||
|
||||
async def run(self):
|
||||
self.state = await self._get_state_snapshot()
|
||||
|
||||
@@ -115,90 +119,41 @@ class Master:
|
||||
await self.election_callbacks.on_became_master()
|
||||
|
||||
while True:
|
||||
next_event = None
|
||||
# 1. process commands
|
||||
if len(self.command_buffer) > 0:
|
||||
# for now we do one command at a time
|
||||
next_command = self.command_buffer.pop(0)
|
||||
self.logger.info(f"got command: {next_command}")
|
||||
# TODO: validate the command
|
||||
match next_command:
|
||||
case ChatCompletionCommand():
|
||||
# 1. find a valid instance for this request, if none exists ERROR (TODO)
|
||||
instance_id = InstanceId()
|
||||
task_id = TaskId()
|
||||
# 2. publish TaskCreated event (TODO)
|
||||
next_event = TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task_id=task_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
instance_id=instance_id,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=next_command.request_params
|
||||
)
|
||||
)
|
||||
case DeleteInstanceCommand():
|
||||
# TODO
|
||||
pass
|
||||
case CreateInstanceCommand():
|
||||
if next_command.model_meta.model_id not in MODEL_CARDS:
|
||||
raise ValueError(f"Model {next_command.model_meta.model_id} not supported.")
|
||||
|
||||
# TODO: we should also support models that aren't in MODEL_CARDS
|
||||
# if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
|
||||
if next_command.model_meta.model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[next_command.model_meta.model_id]
|
||||
model_meta = model_card.metadata
|
||||
else:
|
||||
model_meta = await get_model_meta(next_command.model_meta.model_id)
|
||||
|
||||
# TODO: how do we actually schedule an instance? TODO: @@@@@@𝕾𝖊𝖙𝖍@@@@@@
|
||||
next_event = InstanceCreated(
|
||||
instance_id=InstanceId(),
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=next_command.model_meta.model_id,
|
||||
runner_to_shard={
|
||||
RunnerId(): PipelineShardMetadata(
|
||||
model_meta=model_meta,
|
||||
partition_strategy=PartitionStrategy.pipeline,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=0,
|
||||
n_layers=0
|
||||
)
|
||||
},
|
||||
node_to_runner={}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
)
|
||||
|
||||
if next_event is not None:
|
||||
await self.global_events.append_events([next_event], origin=self.node_id)
|
||||
|
||||
# 2. get latest events
|
||||
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
|
||||
if len(events) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 3. for each event, apply it to the state
|
||||
for event_from_log in events:
|
||||
self.state = apply(self.state, event_from_log)
|
||||
try:
|
||||
await self._run_event_loop_body()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in _run_event_loop_body: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
logger = Logger(name='master_logger')
|
||||
node_id_keypair = get_node_id_keypair()
|
||||
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
|
||||
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
|
||||
await event_log_manager.initialize()
|
||||
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
|
||||
|
||||
# TODO: this should be the resource monitor that does this
|
||||
await global_events.append_events([NodePerformanceMeasured(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="testmodelabc",
|
||||
chip_id="testchipabc",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=1000,
|
||||
ram_available=1000,
|
||||
swap_total=1000,
|
||||
swap_available=1000
|
||||
),
|
||||
system=SystemPerformanceProfile(
|
||||
flops_fp16=1000
|
||||
)
|
||||
)
|
||||
)], origin=node_id)
|
||||
|
||||
command_buffer: List[Command] = []
|
||||
|
||||
api_thread = threading.Thread(
|
||||
@@ -206,13 +161,14 @@ async def main():
|
||||
args=(
|
||||
command_buffer,
|
||||
global_events,
|
||||
lambda: master.state,
|
||||
),
|
||||
daemon=True
|
||||
)
|
||||
api_thread.start()
|
||||
logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')
|
||||
|
||||
master = Master(command_buffer, global_events, forwarder_binary_path=Path("forwarder"), logger=logger)
|
||||
master = Master(node_id, command_buffer, global_events, forwarder_binary_path=Path("./build/forwarder"), logger=logger)
|
||||
await master.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -13,15 +13,15 @@ from shared.topology import Topology
|
||||
from shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
|
||||
|
||||
@singledispatch
|
||||
def get_instance_placements(
|
||||
command: CreateInstanceCommand,
|
||||
topology: Topology,
|
||||
current_instances: dict[InstanceId, InstanceParams],
|
||||
) -> dict[InstanceId, InstanceParams]:
|
||||
current_instances: dict[InstanceId, Instance],
|
||||
) -> dict[InstanceId, Instance]:
|
||||
available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances]
|
||||
if command.model_meta.model_id in available_models:
|
||||
raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
|
||||
@@ -36,9 +36,11 @@ def get_instance_placements(
|
||||
|
||||
shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)
|
||||
|
||||
instance_id = InstanceId()
|
||||
instance_id = command.instance_id
|
||||
target_instances = deepcopy(current_instances)
|
||||
target_instances[instance_id] = InstanceParams(
|
||||
target_instances[instance_id] = Instance(
|
||||
instance_id=instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[]
|
||||
)
|
||||
@@ -46,7 +48,7 @@ def get_instance_placements(
|
||||
|
||||
|
||||
@get_instance_placements.register
|
||||
def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]:
|
||||
def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, Instance]) -> dict[InstanceId, Instance]:
|
||||
target_instances = deepcopy(current_instances)
|
||||
if command.instance_id in target_instances:
|
||||
del target_instances[command.instance_id]
|
||||
@@ -55,19 +57,17 @@ def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dic
|
||||
|
||||
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, InstanceParams],
|
||||
target_instances: Mapping[InstanceId, InstanceParams],
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
# find instances to create
|
||||
for instance_id, instance_params in target_instances.items():
|
||||
for instance_id, instance in target_instances.items():
|
||||
if instance_id not in current_instances:
|
||||
events.append(
|
||||
InstanceCreated(
|
||||
instance_id=instance_id,
|
||||
instance_params=instance_params,
|
||||
instance_type=TypeOfInstance.ACTIVE
|
||||
instance=instance,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from shared.db.sqlite.config import EventLogConfig
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import TaskCreated
|
||||
from shared.types.events.commands import ChatCompletionCommand, Command, CommandId
|
||||
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
|
||||
@@ -36,7 +37,8 @@ async def test_master():
|
||||
|
||||
forwarder_binary_path = _create_forwarder_dummy_binary()
|
||||
|
||||
master = Master(command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
|
||||
node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
master = Master(node_id, command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
|
||||
asyncio.create_task(master.run())
|
||||
|
||||
command_buffer.append(
|
||||
|
||||
@@ -12,7 +12,7 @@ from shared.types.events.commands import CreateInstanceCommand
|
||||
from shared.types.models import ModelMetadata
|
||||
from shared.types.topology import Connection, Node
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import InstanceParams
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.runners import ShardAssignments
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ def topology() -> Topology:
|
||||
return Topology()
|
||||
|
||||
@pytest.fixture
|
||||
def instance_params() -> InstanceParams:
|
||||
return InstanceParams(
|
||||
def instance() -> Instance:
|
||||
return Instance(
|
||||
instance_id=InstanceId(),
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id="test-model",
|
||||
runner_to_shard={},
|
||||
@@ -43,7 +45,8 @@ def model_meta() -> ModelMetadata:
|
||||
def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
|
||||
return CreateInstanceCommand(
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
|
||||
|
||||
@@ -66,7 +69,8 @@ def test_get_instance_placements_create_instance(
|
||||
|
||||
create_instance_command = CreateInstanceCommand(
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
@@ -84,16 +88,16 @@ def test_get_instance_placements_create_instance(
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance_params = placements[instance_id]
|
||||
assert instance_params.shard_assignments.model_id == model_meta.model_id
|
||||
instance = placements[instance_id]
|
||||
assert instance.shard_assignments.model_id == model_meta.model_id
|
||||
|
||||
runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a]
|
||||
runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b]
|
||||
runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c]
|
||||
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
|
||||
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
|
||||
runner_id_c = instance.shard_assignments.node_to_runner[node_id_c]
|
||||
|
||||
shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a]
|
||||
shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b]
|
||||
shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c]
|
||||
shard_a = instance.shard_assignments.runner_to_shard[runner_id_a]
|
||||
shard_b = instance.shard_assignments.runner_to_shard[runner_id_b]
|
||||
shard_c = instance.shard_assignments.runner_to_shard[runner_id_c]
|
||||
|
||||
assert shard_a.end_layer - shard_a.start_layer == expected_layers[0]
|
||||
assert shard_b.end_layer - shard_b.start_layer == expected_layers[1]
|
||||
@@ -105,14 +109,14 @@ def test_get_instance_placements_create_instance(
|
||||
assert shards_sorted[-1].end_layer == total_layers
|
||||
|
||||
|
||||
def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams):
|
||||
def test_get_transition_events_no_change(topology: Topology, instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances = {
|
||||
instance_id: instance_params
|
||||
instance_id: instance
|
||||
}
|
||||
target_instances = {
|
||||
instance_id: instance_params
|
||||
instance_id: instance
|
||||
}
|
||||
|
||||
# act
|
||||
@@ -122,12 +126,12 @@ def test_get_transition_events_no_change(topology: Topology, instance_params: In
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams):
|
||||
def test_get_transition_events_create_instance(topology: Topology, instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, InstanceParams] = {}
|
||||
target_instances: dict[InstanceId, InstanceParams] = {
|
||||
instance_id: instance_params
|
||||
current_instances: dict[InstanceId, Instance] = {}
|
||||
target_instances: dict[InstanceId, Instance] = {
|
||||
instance_id: instance
|
||||
}
|
||||
|
||||
# act
|
||||
@@ -138,13 +142,13 @@ def test_get_transition_events_create_instance(topology: Topology, instance_para
|
||||
assert events[0].event_type == _EventType.InstanceCreated
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams):
|
||||
def test_get_transition_events_delete_instance(topology: Topology, instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, InstanceParams] = {
|
||||
instance_id: instance_params
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id: instance
|
||||
}
|
||||
target_instances: dict[InstanceId, InstanceParams] = {}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
@@ -281,7 +281,7 @@ func (c *sqliteConnector) getLatestRowIds() (map[SourceKey]int64, error) {
|
||||
}
|
||||
|
||||
selectCols := strings.Join(keyCols, ", ")
|
||||
query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" GROUP BY %s`, selectCols, rowIDCol, c.tableName, selectCols)
|
||||
query := fmt.Sprintf(`SELECT %s, MAX(%s) FROM "%s" WHERE %s IS NOT NULL GROUP BY %s`, selectCols, rowIDCol, c.tableName, rowIDCol, selectCols)
|
||||
|
||||
rows, err := c.db.Query(query)
|
||||
if err != nil {
|
||||
|
||||
@@ -28,7 +28,7 @@ from shared.types.profiling import NodePerformanceProfile
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import NodeStatus, RunnerId
|
||||
from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceId, InstanceStatus
|
||||
from shared.types.worker.runners import RunnerStatus
|
||||
|
||||
S = TypeVar("S", bound=State)
|
||||
@@ -62,8 +62,8 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
|
||||
|
||||
@event_apply.register(InstanceCreated)
|
||||
def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type)
|
||||
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance}
|
||||
instance = event.instance
|
||||
new_instances: Mapping[InstanceId, Instance] = {**state.instances, instance.instance_id: instance}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
@event_apply.register(InstanceActivated)
|
||||
@@ -71,8 +71,8 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
|
||||
if event.instance_id not in state.instances:
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE})
|
||||
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
|
||||
updated_instance = state.instances[event.instance_id].model_copy(update={"type": InstanceStatus.ACTIVE})
|
||||
new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
@event_apply.register(InstanceDeactivated)
|
||||
@@ -80,13 +80,13 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
|
||||
if event.instance_id not in state.instances:
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.INACTIVE})
|
||||
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
|
||||
updated_instance = state.instances[event.instance_id].model_copy(update={"type": InstanceStatus.INACTIVE})
|
||||
new_instances: Mapping[InstanceId, Instance] = {**state.instances, event.instance_id: updated_instance}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
@event_apply.register(InstanceDeleted)
|
||||
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
|
||||
new_instances: Mapping[InstanceId, Instance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
@event_apply.register(InstanceReplacedAtomically)
|
||||
|
||||
@@ -3,6 +3,9 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.openai_compat import FinishReason
|
||||
from shared.types.common import CommandId
|
||||
from shared.types.models import ModelMetadata
|
||||
from shared.types.worker.instances import InstanceId
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
@@ -97,8 +100,20 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
|
||||
class RequestInstanceTaskParams(BaseModel):
|
||||
class CreateInstanceTaskParams(BaseModel):
|
||||
# TODO: in future the user could specify a specific Instance, not just a model_id
|
||||
model_id: str
|
||||
|
||||
class DeleteInstanceTaskParams(BaseModel):
|
||||
instance_id: str
|
||||
|
||||
class CreateInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
model_meta: ModelMetadata
|
||||
instance_id: InstanceId
|
||||
|
||||
class DeleteInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
instance_id: InstanceId
|
||||
|
||||
@@ -18,7 +18,7 @@ from shared.types.common import NodeId
|
||||
from shared.types.events.chunks import CommandId, GenerationChunk
|
||||
from shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
from shared.types.worker.instances import InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -114,9 +114,7 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]):
|
||||
|
||||
class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]):
|
||||
event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated
|
||||
instance_id: InstanceId
|
||||
instance_params: InstanceParams
|
||||
instance_type: TypeOfInstance
|
||||
instance: Instance
|
||||
|
||||
|
||||
class InstanceActivated(_BaseEvent[_EventType.InstanceActivated]):
|
||||
|
||||
@@ -7,7 +7,8 @@ from shared.types.api import ChatCompletionTaskParams
|
||||
from shared.types.common import CommandId
|
||||
from shared.types.events import Event
|
||||
from shared.types.models import ModelMetadata
|
||||
from shared.types.state import InstanceId, State
|
||||
from shared.types.state import State
|
||||
from shared.types.worker.common import InstanceId
|
||||
|
||||
|
||||
# TODO: We need to have a distinction between create instance and spin up instance.
|
||||
@@ -30,6 +31,7 @@ class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]):
|
||||
class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]):
|
||||
command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE
|
||||
model_meta: ModelMetadata
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):
|
||||
|
||||
@@ -2,8 +2,8 @@ from pydantic import BaseModel
|
||||
|
||||
from shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
CreateInstanceTaskParams,
|
||||
DeleteInstanceTaskParams,
|
||||
RequestInstanceTaskParams,
|
||||
)
|
||||
from shared.types.events import CommandId
|
||||
|
||||
@@ -12,12 +12,12 @@ class ChatCompletionCommand(BaseModel):
|
||||
command_id: CommandId
|
||||
command_params: ChatCompletionTaskParams
|
||||
|
||||
class RequestInstanceCommand(BaseModel):
|
||||
class CreateInstanceCommand(BaseModel):
|
||||
command_id: CommandId
|
||||
command_params: RequestInstanceTaskParams
|
||||
command_params: CreateInstanceTaskParams
|
||||
|
||||
class DeleteInstanceCommand(BaseModel):
|
||||
command_id: CommandId
|
||||
command_params: DeleteInstanceTaskParams
|
||||
|
||||
type Command = ChatCompletionCommand | RequestInstanceCommand | DeleteInstanceCommand
|
||||
type Command = ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand
|
||||
|
||||
@@ -7,14 +7,14 @@ from shared.types.common import NodeId
|
||||
from shared.types.profiling import NodePerformanceProfile
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
from shared.types.worker.instances import BaseInstance
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
node_status: Mapping[NodeId, NodeStatus] = {}
|
||||
instances: Mapping[InstanceId, BaseInstance] = {}
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
|
||||
|
||||
@@ -9,20 +9,12 @@ from shared.types.worker.runners import (
|
||||
)
|
||||
|
||||
|
||||
class TypeOfInstance(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
class InstanceStatus(str, Enum):
|
||||
ACTIVE = "ACTIVE"
|
||||
INACTIVE = "INACTIVE"
|
||||
|
||||
|
||||
class InstanceParams(BaseModel):
|
||||
class Instance(BaseModel):
|
||||
instance_id: InstanceId
|
||||
instance_type: InstanceStatus
|
||||
shard_assignments: ShardAssignments
|
||||
hosts: list[Host]
|
||||
|
||||
|
||||
class BaseInstance(BaseModel):
|
||||
instance_params: InstanceParams
|
||||
instance_type: TypeOfInstance
|
||||
|
||||
|
||||
class Instance(BaseInstance):
|
||||
instance_id: InstanceId
|
||||
|
||||
@@ -28,7 +28,7 @@ from shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgressData,
|
||||
)
|
||||
from shared.types.worker.instances import TypeOfInstance
|
||||
from shared.types.worker.instances import InstanceStatus
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
@@ -323,29 +323,29 @@ class Worker:
|
||||
runner_ids: list[RunnerId] = [
|
||||
runner_id
|
||||
for instance in state.instances.values()
|
||||
for runner_id in instance.instance_params.shard_assignments.runner_to_shard
|
||||
for runner_id in instance.shard_assignments.runner_to_shard
|
||||
]
|
||||
if runner_id not in runner_ids:
|
||||
return UnassignRunnerOp(runner_id=runner_id)
|
||||
|
||||
# Then spin down active runners
|
||||
for _instance_id, instance in state.instances.items():
|
||||
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
if node_id != self.node_id:
|
||||
continue
|
||||
|
||||
# We spin down a runner if it's meant to be inactive and it's Loaded.
|
||||
if runner_id in self.assigned_runners and \
|
||||
isinstance(self.assigned_runners[runner_id].status, LoadedRunnerStatus) and \
|
||||
instance.instance_type == TypeOfInstance.INACTIVE:
|
||||
instance.instance_type == InstanceStatus.INACTIVE:
|
||||
return RunnerDownOp(runner_id=runner_id)
|
||||
|
||||
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
|
||||
# TODO: We need to limit number of retries if we keep failing.
|
||||
for _instance_id, instance in state.instances.items():
|
||||
if self.node_id in instance.instance_params.shard_assignments.node_to_runner:
|
||||
if self.node_id in instance.shard_assignments.node_to_runner:
|
||||
other_node_in_instance_has_failed = False
|
||||
for runner_id in instance.instance_params.shard_assignments.runner_to_shard:
|
||||
for runner_id in instance.shard_assignments.runner_to_shard:
|
||||
if runner_id in state.runners and \
|
||||
isinstance(state.runners[runner_id], FailedRunnerStatus) and \
|
||||
runner_id not in self.assigned_runners:
|
||||
@@ -353,28 +353,28 @@ class Worker:
|
||||
|
||||
if other_node_in_instance_has_failed:
|
||||
# Spin down *our* runner
|
||||
return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
|
||||
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
|
||||
|
||||
# If we are failed - and *all of the other nodes have spun down* - then we can spin down too.
|
||||
for _instance_id, instance in state.instances.items():
|
||||
if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \
|
||||
instance.instance_params.shard_assignments.node_to_runner[self.node_id] in state.runners and \
|
||||
isinstance(state.runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
|
||||
if self.node_id in instance.shard_assignments.node_to_runner and \
|
||||
instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
|
||||
isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
|
||||
|
||||
num_spundown_nodes = 0
|
||||
for runner_id in instance.instance_params.shard_assignments.runner_to_shard:
|
||||
for runner_id in instance.shard_assignments.runner_to_shard:
|
||||
if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
|
||||
runner_id not in self.assigned_runners:
|
||||
num_spundown_nodes += 1
|
||||
|
||||
if num_spundown_nodes == next(iter(instance.instance_params.shard_assignments.runner_to_shard.values())).world_size - 1:
|
||||
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
|
||||
# All the other nodes are spun down - so now we can spin down too.
|
||||
# This also catches the case of 1-node. If there's one node in the instance then we should spin down straight away
|
||||
return RunnerDownOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
|
||||
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
|
||||
|
||||
# Then assign runners we do want
|
||||
for instance_id, instance in state.instances.items():
|
||||
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
if node_id != self.node_id:
|
||||
continue
|
||||
|
||||
@@ -382,15 +382,15 @@ class Worker:
|
||||
return AssignRunnerOp(
|
||||
runner_id=runner_id,
|
||||
instance_id=instance_id,
|
||||
shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance.instance_params.hosts
|
||||
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance.hosts
|
||||
)
|
||||
|
||||
# Then make sure things are downloading.
|
||||
for instance_id, instance in state.instances.items():
|
||||
# We should already have asserted that this runner exists
|
||||
# If it didn't exist then we return a assign_runner op.
|
||||
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
if node_id != self.node_id:
|
||||
continue
|
||||
assert runner_id in self.assigned_runners
|
||||
@@ -404,29 +404,29 @@ class Worker:
|
||||
return DownloadOp(
|
||||
runner_id=runner_id,
|
||||
instance_id=instance_id,
|
||||
shard_metadata=instance.instance_params.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance.instance_params.hosts
|
||||
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance.hosts
|
||||
)
|
||||
|
||||
# Then spin up 'ready' runners that should be active
|
||||
for _instance_id, instance in state.instances.items():
|
||||
if self.node_id in instance.instance_params.shard_assignments.node_to_runner and \
|
||||
self.assigned_runners[instance.instance_params.shard_assignments.node_to_runner[self.node_id]].runner is None and \
|
||||
instance.instance_type == TypeOfInstance.ACTIVE:
|
||||
if self.node_id in instance.shard_assignments.node_to_runner and \
|
||||
self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].runner is None and \
|
||||
instance.instance_type == InstanceStatus.ACTIVE:
|
||||
|
||||
# We are part of this instance, we want it up but it hasn't been spun up yet.
|
||||
# Need to assert all other runners are ready before we can spin up.
|
||||
ready_to_spin = True
|
||||
for runner_id in instance.instance_params.shard_assignments.node_to_runner.values():
|
||||
for runner_id in instance.shard_assignments.node_to_runner.values():
|
||||
if state.runners[runner_id].runner_status != RunnerStatusType.Ready:
|
||||
ready_to_spin = False
|
||||
|
||||
if ready_to_spin:
|
||||
return RunnerUpOp(runner_id=instance.instance_params.shard_assignments.node_to_runner[self.node_id])
|
||||
return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
|
||||
|
||||
# Then make sure things are running based on tasks.
|
||||
for instance_id, instance in state.instances.items():
|
||||
for node_id, runner_id in instance.instance_params.shard_assignments.node_to_runner.items():
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
if node_id != self.node_id:
|
||||
continue
|
||||
assert runner_id in self.assigned_runners
|
||||
@@ -443,7 +443,7 @@ class Worker:
|
||||
# so let's check that all the other runners are running - ready for us to fire the prompt.
|
||||
running_runner_count = 0
|
||||
for other_runner_id, other_runner_status in state.runners.items():
|
||||
if other_runner_id in instance.instance_params.shard_assignments.node_to_runner.values() and \
|
||||
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
|
||||
isinstance(other_runner_status, RunningRunnerStatus):
|
||||
running_runner_count += 1
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from shared.types.tasks import (
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
@@ -140,15 +140,11 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h
|
||||
node_to_runner={node_id: runner_id}
|
||||
)
|
||||
|
||||
instance_params = InstanceParams(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts_one
|
||||
)
|
||||
|
||||
return Instance(
|
||||
instance_id=InstanceId(),
|
||||
instance_params=instance_params,
|
||||
instance_type=TypeOfInstance.ACTIVE
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts_one
|
||||
)
|
||||
return _instance
|
||||
|
||||
@@ -166,13 +162,13 @@ async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId
|
||||
instance_obj: Instance = instance(worker.node_id, RunnerId())
|
||||
|
||||
# Extract runner_id from shard assignments
|
||||
runner_id = next(iter(instance_obj.instance_params.shard_assignments.runner_to_shard))
|
||||
runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))
|
||||
|
||||
# Assign the runner
|
||||
assign_op = AssignRunnerOp(
|
||||
runner_id=runner_id,
|
||||
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.instance_params.hosts,
|
||||
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.hosts,
|
||||
instance_id=instance_obj.instance_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@ async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId],
|
||||
|
||||
assign_op = AssignRunnerOp(
|
||||
runner_id=runner_id,
|
||||
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.instance_params.hosts,
|
||||
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.hosts,
|
||||
instance_id=instance_obj.instance_id,
|
||||
)
|
||||
|
||||
@@ -138,8 +138,8 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
download_op = DownloadOp(
|
||||
instance_id=instance_obj.instance_id,
|
||||
runner_id=runner_id,
|
||||
shard_metadata=instance_obj.instance_params.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.instance_params.hosts,
|
||||
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
|
||||
hosts=instance_obj.hosts,
|
||||
)
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -15,7 +15,7 @@ from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import Instance, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.runners import (
|
||||
LoadedRunnerStatus,
|
||||
ReadyRunnerStatus,
|
||||
@@ -50,14 +50,12 @@ async def test_runner_assigned(
|
||||
print(worker)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = TypeOfInstance.INACTIVE
|
||||
instance_value.instance_type = InstanceStatus.INACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance_id=instance_value.instance_id,
|
||||
instance_params=instance_value.instance_params,
|
||||
instance_type=instance_value.instance_type
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
@@ -87,14 +85,12 @@ async def test_runner_assigned_active(
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = TypeOfInstance.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance_id=instance_value.instance_id,
|
||||
instance_params=instance_value.instance_params,
|
||||
instance_type=instance_value.instance_type
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
@@ -141,9 +137,7 @@ async def test_runner_assigned_wrong_node(
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance_id=instance_value.instance_id,
|
||||
instance_params=instance_value.instance_params,
|
||||
instance_type=instance_value.instance_type
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
@@ -168,14 +162,12 @@ async def test_runner_unassigns(
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = TypeOfInstance.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance_id=instance_value.instance_id,
|
||||
instance_params=instance_value.instance_params,
|
||||
instance_type=instance_value.instance_type
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
|
||||
@@ -16,7 +16,7 @@ from shared.types.tasks import (
|
||||
)
|
||||
from shared.types.worker.common import NodeStatus
|
||||
from shared.types.worker.downloads import DownloadPending
|
||||
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
DownloadOp,
|
||||
@@ -90,9 +90,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.INACTIVE,
|
||||
instance_type=InstanceStatus.INACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -101,7 +100,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
|
||||
@@ -124,9 +122,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.INACTIVE,
|
||||
instance_type=InstanceStatus.INACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -135,7 +132,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
|
||||
@@ -158,9 +154,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.INACTIVE,
|
||||
instance_type=InstanceStatus.INACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -169,7 +164,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: ReadyRunnerStatus()},
|
||||
@@ -184,9 +178,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE, # Either active or inactive should yield the same.
|
||||
instance_type=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -195,7 +188,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: AssignedRunnerStatus()},
|
||||
@@ -245,9 +237,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -256,7 +247,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: AssignedRunnerStatus()},
|
||||
@@ -291,9 +281,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -302,7 +291,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: ReadyRunnerStatus()},
|
||||
@@ -337,9 +325,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -349,7 +336,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))},
|
||||
@@ -382,9 +368,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -394,7 +379,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
|
||||
@@ -418,9 +402,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.INACTIVE,
|
||||
instance_type=InstanceStatus.INACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -429,7 +412,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus()},
|
||||
@@ -453,9 +435,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.INACTIVE,
|
||||
instance_type=InstanceStatus.INACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -464,7 +445,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: FailedRunnerStatus()},
|
||||
@@ -488,9 +468,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -499,7 +478,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus()},
|
||||
@@ -542,9 +520,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -554,7 +531,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
|
||||
@@ -587,9 +563,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -599,7 +574,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
|
||||
@@ -644,9 +618,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -656,7 +629,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()},
|
||||
@@ -701,9 +673,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -713,7 +684,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()},
|
||||
@@ -737,9 +707,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -748,7 +717,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: FailedRunnerStatus()},
|
||||
@@ -781,9 +749,8 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
@@ -793,7 +760,6 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
),
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
|
||||
@@ -825,19 +791,17 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
|
||||
instances={
|
||||
INSTANCE_1_ID: Instance(
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_params=InstanceParams(
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
|
||||
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
|
||||
},
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=MODEL_A_ID,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
|
||||
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
|
||||
},
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
|
||||
),
|
||||
hosts=[]
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
|
||||
@@ -884,7 +848,7 @@ def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.Mon
|
||||
if len(case.state.instances) == 1:
|
||||
instance_id = next(iter(case.state.instances))
|
||||
|
||||
shard_assignments = case.state.instances[instance_id].instance_params.shard_assignments
|
||||
shard_assignments = case.state.instances[instance_id].shard_assignments
|
||||
shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
|
||||
|
||||
# Only add this runner if it belongs to our node
|
||||
|
||||
@@ -11,7 +11,7 @@ from shared.types.state import State
|
||||
from shared.types.tasks import TaskId
|
||||
from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
|
||||
from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
|
||||
from shared.types.worker.instances import Instance, InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.ops import RunnerOp
|
||||
from shared.types.worker.runners import (
|
||||
AssignedRunnerStatus,
|
||||
@@ -148,14 +148,11 @@ def create_worker_state(
|
||||
runner_to_shard={runner_id: shard_metadata},
|
||||
node_to_runner={node_id: runner_id},
|
||||
)
|
||||
instance_params = InstanceParams(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[],
|
||||
)
|
||||
instance = Instance(
|
||||
instance_id=instance_id,
|
||||
instance_params=instance_params,
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[],
|
||||
)
|
||||
instances[instance_id] = instance
|
||||
|
||||
@@ -198,14 +195,11 @@ def make_instance(
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
instance_params = InstanceParams(
|
||||
return Instance(
|
||||
instance_id=instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[],
|
||||
)
|
||||
return Instance(
|
||||
instance_id=instance_id,
|
||||
instance_params=instance_params,
|
||||
instance_type=TypeOfInstance.ACTIVE,
|
||||
)
|
||||
|
||||
### For worker plan tests
|
||||
Reference in New Issue
Block a user