From 6f8e3419d502ffa28c96ce65247973efc6f13e25 Mon Sep 17 00:00:00 2001 From: Seth Howes <71157822+sethhowes@users.noreply.github.com> Date: Thu, 24 Jul 2025 20:22:40 +0100 Subject: [PATCH] Placement strategy Co-authored-by: Alex Cheema --- master/api.py | 6 +- master/main.py | 15 +- master/placement.py | 83 +++++++++-- master/tests/conftest.py | 46 ++++++ master/tests/test_placement.py | 155 ++++++++++++++++++++ master/tests/test_placement_utils.py | 173 +++++++++++++++++++++++ master/tests/test_topology.py | 4 +- master/utils/placement_utils.py | 77 ++++++++++ shared/tests/test_sqlite_connector.py | 7 +- shared/topology.py | 9 ++ shared/types/common.py | 5 +- shared/types/events/chunks.py | 7 +- shared/types/events/commands.py | 21 +-- shared/types/profiling.py | 4 +- shared/types/state.py | 6 - shared/types/topology.py | 2 + shared/types/worker/shards.py | 5 +- worker/download/impl_shard_downloader.py | 4 +- worker/runner/runner_supervisor.py | 2 +- 19 files changed, 572 insertions(+), 59 deletions(-) create mode 100644 master/tests/conftest.py create mode 100644 master/tests/test_placement.py create mode 100644 master/tests/test_placement_utils.py create mode 100644 master/utils/placement_utils.py diff --git a/master/api.py b/master/api.py index dd99a5cf..e2a8428d 100644 --- a/master/api.py +++ b/master/api.py @@ -13,13 +13,13 @@ from shared.types.api import ( ChatCompletionResponse, StreamingChoiceResponse, ) +from shared.types.common import CommandId from shared.types.events import ChunkGenerated, Event from shared.types.events.chunks import TokenChunk from shared.types.events.commands import ( ChatCompletionCommand, Command, - CommandId, - CommandTypes, + CommandType, ) from shared.types.events.components import EventFromEventLog from shared.types.tasks import ChatCompletionTaskParams @@ -101,7 +101,7 @@ class API: request = ChatCompletionCommand( command_id=command_id, - command_type=CommandTypes.CHAT_COMPLETION, + command_type=CommandType.CHAT_COMPLETION, 
request_params=payload, ) self.command_buffer.append(request) diff --git a/master/main.py b/master/main.py index e9baf241..a253927d 100644 --- a/master/main.py +++ b/master/main.py @@ -14,10 +14,9 @@ from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogManager from shared.models.model_cards import MODEL_CARDS from shared.models.model_meta import get_model_meta -from shared.types.common import NodeId +from shared.types.common import CommandId, NodeId from shared.types.events import ( ChunkGenerated, - CommandId, InstanceCreated, TaskCreated, ) @@ -143,23 +142,23 @@ class Master: # TODO pass case CreateInstanceCommand(): - if next_command.model_id not in MODEL_CARDS: - raise ValueError(f"Model {next_command.model_id} not supported.") + if next_command.model_meta.model_id not in MODEL_CARDS: + raise ValueError(f"Model {next_command.model_meta.model_id} not supported.") # TODO: we should also support models that aren't in MODEL_CARDS # if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface - if next_command.model_id in MODEL_CARDS: - model_card = MODEL_CARDS[next_command.model_id] + if next_command.model_meta.model_id in MODEL_CARDS: + model_card = MODEL_CARDS[next_command.model_meta.model_id] model_meta = model_card.metadata else: - model_meta = await get_model_meta(next_command.model_id) + model_meta = await get_model_meta(next_command.model_meta.model_id) # TODO: how do we actually schedule an instance? 
TODO: @@@@@@𝕾𝖊𝖙𝖍@@@@@@ next_event = InstanceCreated( instance_id=InstanceId(), instance_params=InstanceParams( shard_assignments=ShardAssignments( - model_id=next_command.model_id, + model_id=next_command.model_meta.model_id, runner_to_shard={ RunnerId(): PipelineShardMetadata( model_meta=model_meta, diff --git a/master/placement.py b/master/placement.py index b9eb7d70..87d12c6e 100644 --- a/master/placement.py +++ b/master/placement.py @@ -1,24 +1,83 @@ -from queue import Queue -from typing import Mapping, Sequence +from collections.abc import Mapping +from copy import deepcopy +from functools import singledispatch +from typing import Sequence + +from master.utils.placement_utils import ( + filter_cycles_by_memory, + get_shard_assignments, + get_smallest_cycles, +) from shared.topology import Topology -from shared.types.events import Event -from shared.types.state import CachePolicy -from shared.types.tasks import Task -from shared.types.worker.instances import InstanceId, InstanceParams +from shared.types.events import Event, InstanceCreated, InstanceDeleted +from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand +from shared.types.worker.common import InstanceId +from shared.types.worker.instances import InstanceParams, TypeOfInstance +@singledispatch def get_instance_placements( - inbox: Queue[Task], - outbox: Queue[Task], + command: CreateInstanceCommand, topology: Topology, - current_instances: Mapping[InstanceId, InstanceParams], - cache_policy: CachePolicy, -) -> Mapping[InstanceId, InstanceParams]: ... 
+ current_instances: dict[InstanceId, InstanceParams], +) -> dict[InstanceId, InstanceParams]: + available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances] + if command.model_meta.model_id in available_models: + raise ValueError(f"Instance for {command.model_meta.model_id} already exists") + + candidate_cycles = topology.get_cycles() + cycles = filter_cycles_by_memory(candidate_cycles, command.model_meta.storage_size_kilobytes) + if not cycles: + raise ValueError("No cycles found with sufficient memory") + smallest_cycles = get_smallest_cycles(cycles) + selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None)) + + shard_assignments = get_shard_assignments(command.model_meta, selected_cycle) + + instance_id = InstanceId() + target_instances = deepcopy(current_instances) + target_instances[instance_id] = InstanceParams( + shard_assignments=shard_assignments, + hosts=[] + ) + return target_instances + + +@get_instance_placements.register +def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]: + target_instances = deepcopy(current_instances) + if command.instance_id in target_instances: + del target_instances[command.instance_id] + return target_instances + raise ValueError(f"Instance {command.instance_id} not found") def get_transition_events( current_instances: Mapping[InstanceId, InstanceParams], target_instances: Mapping[InstanceId, InstanceParams], -) -> Sequence[Event]: ... 
+) -> Sequence[Event]: + events: list[Event] = [] + + # find instances to create + for instance_id, instance_params in target_instances.items(): + if instance_id not in current_instances: + events.append( + InstanceCreated( + instance_id=instance_id, + instance_params=instance_params, + instance_type=TypeOfInstance.ACTIVE + ) + ) + + # find instances to delete + for instance_id in current_instances: + if instance_id not in target_instances: + events.append( + InstanceDeleted( + instance_id=instance_id, + ) + ) + + return events diff --git a/master/tests/conftest.py b/master/tests/conftest.py new file mode 100644 index 00000000..6ab6bd92 --- /dev/null +++ b/master/tests/conftest.py @@ -0,0 +1,46 @@ +import pytest + +from shared.types.common import NodeId +from shared.types.profiling import ( + MemoryPerformanceProfile, + NodePerformanceProfile, + SystemPerformanceProfile, +) +from shared.types.topology import Connection, ConnectionProfile, Node + + +@pytest.fixture +def create_node(): + def _create_node(memory: int, node_id: NodeId | None = None) -> Node: + if node_id is None: + node_id = NodeId() + return Node( + node_id=node_id, + node_profile=NodePerformanceProfile( + model_id="test", + chip_id="test", + memory=MemoryPerformanceProfile( + ram_total=1000, + ram_available=memory, + swap_total=1000, + swap_available=1000 + ), + network_interfaces=[], + system=SystemPerformanceProfile(flops_fp16=1000) + ) + ) + + return _create_node + + +@pytest.fixture +def create_connection(): + def _create_connection(source_node_id: NodeId, sink_node_id: NodeId) -> Connection: + return Connection( + source_node_id=source_node_id, + sink_node_id=sink_node_id, + source_multiaddr="/ip4/127.0.0.1/tcp/1234", + sink_multiaddr="/ip4/127.0.0.1/tcp/1235", + connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000) + ) + return _create_connection \ No newline at end of file diff --git a/master/tests/test_placement.py b/master/tests/test_placement.py new file mode 
100644 index 00000000..cf105b97 --- /dev/null +++ b/master/tests/test_placement.py @@ -0,0 +1,155 @@ +from typing import Callable + +import pytest + +from master.placement import get_instance_placements, get_transition_events +from shared.topology import Topology +from shared.types.common import CommandId, NodeId +from shared.types.events._events import ( + _EventType, # pyright: ignore[reportPrivateUsage] +) +from shared.types.events.commands import CreateInstanceCommand +from shared.types.models import ModelMetadata +from shared.types.topology import Connection, Node +from shared.types.worker.common import InstanceId +from shared.types.worker.instances import InstanceParams +from shared.types.worker.runners import ShardAssignments + + +@pytest.fixture +def topology() -> Topology: + return Topology() + +@pytest.fixture +def instance_params() -> InstanceParams: + return InstanceParams( + shard_assignments=ShardAssignments( + model_id="test-model", + runner_to_shard={}, + node_to_runner={} + ), + hosts=[] + ) + +@pytest.fixture +def model_meta() -> ModelMetadata: + return ModelMetadata( + model_id="test-model", + storage_size_kilobytes=1000, + pretty_name="Test Model", + n_layers=10 + ) + +def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand: + return CreateInstanceCommand( + command_id=CommandId(), + model_meta=model_meta + ) + + +@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [ + ((500, 500, 1000), 12, (3, 3, 6)), + ((500, 500, 500), 12, (4, 4, 4)), + ((312, 518, 1024), 12, (2, 3, 7)) +]) +def test_get_instance_placements_create_instance( + available_memory: tuple[int, int, int], + total_layers: int, + expected_layers: tuple[int, int, int], + topology: Topology, + model_meta: ModelMetadata, + create_node: Callable[[int, NodeId | None], Node], + create_connection: Callable[[NodeId, NodeId], Connection] +): + # arrange + model_meta.n_layers = total_layers + + create_instance_command = CreateInstanceCommand( + 
command_id=CommandId(), + model_meta=model_meta + ) + node_id_a = NodeId() + node_id_b = NodeId() + node_id_c = NodeId() + topology.add_node(create_node(available_memory[0], node_id_a), node_id_a) + topology.add_node(create_node(available_memory[1], node_id_b), node_id_b) + topology.add_node(create_node(available_memory[2], node_id_c), node_id_c) + topology.add_connection(create_connection(node_id_a, node_id_b)) + topology.add_connection(create_connection(node_id_b, node_id_c)) + topology.add_connection(create_connection(node_id_c, node_id_a)) + + # act + placements = get_instance_placements(create_instance_command, topology, {}) + + # assert + assert len(placements) == 1 + instance_id = list(placements.keys())[0] + instance_params = placements[instance_id] + assert instance_params.shard_assignments.model_id == model_meta.model_id + + runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a] + runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b] + runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c] + + shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a] + shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b] + shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c] + + assert shard_a.end_layer - shard_a.start_layer == expected_layers[0] + assert shard_b.end_layer - shard_b.start_layer == expected_layers[1] + assert shard_c.end_layer - shard_c.start_layer == expected_layers[2] + + shards = [shard_a, shard_b, shard_c] + shards_sorted = sorted(shards, key=lambda s: s.start_layer) + assert shards_sorted[0].start_layer == 0 + assert shards_sorted[-1].end_layer == total_layers + + +def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams): + # arrange + instance_id = InstanceId() + current_instances = { + instance_id: instance_params + } + target_instances = { + instance_id: instance_params + } + + # act + events = 
get_transition_events(current_instances, target_instances) + + # assert + assert len(events) == 0 + + +def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams): + # arrange + instance_id = InstanceId() + current_instances: dict[InstanceId, InstanceParams] = {} + target_instances: dict[InstanceId, InstanceParams] = { + instance_id: instance_params + } + + # act + events = get_transition_events(current_instances, target_instances) + + # assert + assert len(events) == 1 + assert events[0].event_type == _EventType.InstanceCreated + + +def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams): + # arrange + instance_id = InstanceId() + current_instances: dict[InstanceId, InstanceParams] = { + instance_id: instance_params + } + target_instances: dict[InstanceId, InstanceParams] = {} + + # act + events = get_transition_events(current_instances, target_instances) + + # assert + assert len(events) == 1 + assert events[0].event_type == _EventType.InstanceDeleted + assert events[0].instance_id == instance_id diff --git a/master/tests/test_placement_utils.py b/master/tests/test_placement_utils.py new file mode 100644 index 00000000..7dce222f --- /dev/null +++ b/master/tests/test_placement_utils.py @@ -0,0 +1,173 @@ +from typing import Callable + +import pytest + +from master.utils.placement_utils import ( + filter_cycles_by_memory, + get_shard_assignments, + get_smallest_cycles, +) +from shared.topology import Topology +from shared.types.common import NodeId +from shared.types.models import ModelMetadata +from shared.types.topology import Connection, Node + + +@pytest.fixture +def topology() -> Topology: + topology = Topology() + return topology + + +def test_filter_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]): + # arrange + node1_id = NodeId() + node2_id = NodeId() + + node1 = 
create_node(1000, node1_id) + node2 = create_node(1000, node2_id) + + topology.add_node(node1, node1_id) + topology.add_node(node2, node2_id) + + connection1 = create_connection(node1_id, node2_id) + connection2 = create_connection(node2_id, node1_id) + + topology.add_connection(connection1) + topology.add_connection(connection2) + + cycles = topology.get_cycles() + + # act + filtered_cycles = filter_cycles_by_memory(cycles, 1) + + # assert + assert len(filtered_cycles) == 1 + assert len(filtered_cycles[0]) == 2 + assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id} + + +def test_filter_cycles_by_insufficient_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]): + # arrange + node1_id = NodeId() + node2_id = NodeId() + + node1 = create_node(1000, node1_id) + node2 = create_node(1000, node2_id) + + topology.add_node(node1, node1_id) + topology.add_node(node2, node2_id) + + connection1 = create_connection(node1_id, node2_id) + connection2 = create_connection(node2_id, node1_id) + + topology.add_connection(connection1) + topology.add_connection(connection2) + + # act + filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001) + + # assert + assert len(filtered_cycles) == 0 + + +def test_filter_multiple_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]): + # arrange + node_a_id = NodeId() + node_b_id = NodeId() + node_c_id = NodeId() + + node_a = create_node(500, node_a_id) + node_b = create_node(500, node_b_id) + node_c = create_node(1000, node_c_id) + + topology.add_node(node_a, node_a_id) + topology.add_node(node_b, node_b_id) + topology.add_node(node_c, node_c_id) + + topology.add_connection(create_connection(node_a_id, node_b_id)) + topology.add_connection(create_connection(node_b_id, node_a_id)) + + topology.add_connection(create_connection(node_a_id, 
node_c_id)) + topology.add_connection(create_connection(node_c_id, node_b_id)) + + cycles = topology.get_cycles() + + # act + filtered_cycles = filter_cycles_by_memory(cycles, 1500) + + # assert + assert len(filtered_cycles) == 1 + assert len(filtered_cycles[0]) == 3 + assert set(n.node_id for n in filtered_cycles[0]) == {node_a_id, node_b_id, node_c_id} + +def test_get_smallest_cycles(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]): + # arrange + node_a_id = NodeId() + node_b_id = NodeId() + node_c_id = NodeId() + + node_a = create_node(500, node_a_id) + node_b = create_node(500, node_b_id) + node_c = create_node(1000, node_c_id) + + topology.add_node(node_a, node_a_id) + topology.add_node(node_b, node_b_id) + topology.add_node(node_c, node_c_id) + + topology.add_connection(create_connection(node_a_id, node_b_id)) + topology.add_connection(create_connection(node_b_id, node_c_id)) + topology.add_connection(create_connection(node_c_id, node_a_id)) + topology.add_connection(create_connection(node_b_id, node_a_id)) + + # act + smallest_cycles = get_smallest_cycles(topology.get_cycles()) + + # assert + assert len(smallest_cycles) == 1 + assert len(smallest_cycles[0]) == 2 + assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id} + +@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [ + ((500, 500, 1000), 12, (3, 3, 6)), + ((500, 500, 500), 12, (4, 4, 4)), + ((312, 518, 1024), 12, (2, 3, 7)) +]) +def test_get_shard_assignments(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection], available_memory: tuple[int, int, int], total_layers: int, expected_layers: tuple[int, int, int]): + # arrange + node_a_id = NodeId() + node_b_id = NodeId() + node_c_id = NodeId() + + node_a = create_node(available_memory[0], node_a_id) + node_b = create_node(available_memory[1], node_b_id) + 
node_c = create_node(available_memory[2], node_c_id) + + topology.add_node(node_a, node_a_id) + topology.add_node(node_b, node_b_id) + topology.add_node(node_c, node_c_id) + + topology.add_connection(create_connection(node_a_id, node_b_id)) + topology.add_connection(create_connection(node_b_id, node_c_id)) + topology.add_connection(create_connection(node_c_id, node_a_id)) + topology.add_connection(create_connection(node_b_id, node_a_id)) + + model_meta = ModelMetadata( + model_id="test-model", + pretty_name="Test Model", + n_layers=total_layers, + storage_size_kilobytes=1000 + ) + cycles = topology.get_cycles() + selected_cycle = cycles[0] + + # act + shard_assignments = get_shard_assignments(model_meta, selected_cycle) + + # assert + runner_id_a = shard_assignments.node_to_runner[node_a_id] + runner_id_b = shard_assignments.node_to_runner[node_b_id] + runner_id_c = shard_assignments.node_to_runner[node_c_id] + assert shard_assignments.runner_to_shard[runner_id_c].end_layer - shard_assignments.runner_to_shard[runner_id_c].start_layer == expected_layers[2] + assert shard_assignments.runner_to_shard[runner_id_a].end_layer - shard_assignments.runner_to_shard[runner_id_a].start_layer == expected_layers[0] + assert shard_assignments.runner_to_shard[runner_id_b].end_layer - shard_assignments.runner_to_shard[runner_id_b].start_layer == expected_layers[1] diff --git a/master/tests/test_topology.py b/master/tests/test_topology.py index 5eaca934..1e395d2e 100644 --- a/master/tests/test_topology.py +++ b/master/tests/test_topology.py @@ -19,7 +19,7 @@ def connection() -> Connection: @pytest.fixture def node_profile() -> NodePerformanceProfile: - memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0) + memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000) system_profile = SystemPerformanceProfile(flops_fp16=1000) return NodePerformanceProfile(model_id="test", 
chip_id="test", memory=memory_profile, network_interfaces=[], system=system_profile) @@ -57,7 +57,7 @@ def test_update_node_profile(topology: Topology, node_profile: NodePerformancePr topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile), node_id=connection.sink_node_id) topology.add_connection(connection) - new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000)) + new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000)) # act topology.update_node_profile(connection.source_node_id, node_profile=new_node_profile) diff --git a/master/utils/placement_utils.py b/master/utils/placement_utils.py new file mode 100644 index 00000000..30d96725 --- /dev/null +++ b/master/utils/placement_utils.py @@ -0,0 +1,77 @@ +from typing import TypeGuard, cast + +from pydantic import BaseModel + +from shared.types.common import NodeId +from shared.types.models import ModelMetadata +from shared.types.profiling import NodePerformanceProfile +from shared.types.topology import Node +from shared.types.worker.common import RunnerId +from shared.types.worker.runners import ShardAssignments +from shared.types.worker.shards import PipelineShardMetadata + + +class NodeWithProfile(BaseModel): + node_id: NodeId + node_profile: NodePerformanceProfile + +def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]: + return all(node.node_profile is not None for node in nodes) + +def filter_cycles_by_memory(cycles: list[list[Node]], required_memory: int) -> list[list[Node]]: + filtered_cycles: list[list[Node]] = [] + for cycle in cycles: + if not narrow_all_nodes(cycle): + continue + + 
total_mem = sum(node.node_profile.memory.ram_available for node in cycle) + if total_mem >= required_memory: + filtered_cycles.append(cast(list[Node], cycle)) + return filtered_cycles + + +def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]: + min_nodes = min(len(cycle) for cycle in cycles) + return [cycle for cycle in cycles if len(cycle) == min_nodes] + +def get_shard_assignments( + model_meta: ModelMetadata, + selected_cycle: list[Node], +) -> ShardAssignments: + if not narrow_all_nodes(selected_cycle): + raise ValueError("All nodes must have profiles to create shard assignments") + + cycle_memory = sum(node.node_profile.memory.ram_available for node in selected_cycle) + total_layers = model_meta.n_layers + runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {} + node_to_runner: dict[NodeId, RunnerId] = {} + + layers_assigned = 0 + for i, node in enumerate(selected_cycle): + if i == len(selected_cycle) - 1: + node_layers = total_layers - layers_assigned + else: + node_layers = round(total_layers * (node.node_profile.memory.ram_available / cycle_memory)) + node_layers = max(1, node_layers) + + runner_id = RunnerId() + shard = PipelineShardMetadata( + model_meta=model_meta, + device_rank=i, + world_size=len(selected_cycle), + start_layer=layers_assigned, + end_layer=layers_assigned + node_layers, + n_layers=total_layers + ) + + runner_to_shard[runner_id] = shard + node_to_runner[node.node_id] = runner_id + layers_assigned += node_layers + + shard_assignments = ShardAssignments( + model_id=model_meta.model_id, + runner_to_shard=runner_to_shard, + node_to_runner=node_to_runner + ) + + return shard_assignments diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py index 687ee230..5963cc8e 100644 --- a/shared/tests/test_sqlite_connector.py +++ b/shared/tests/test_sqlite_connector.py @@ -10,11 +10,8 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from shared.db.sqlite import 
AsyncSQLiteEventStorage, EventLogConfig -from shared.types.common import NodeId -from shared.types.events import ( - ChunkGenerated, - CommandId, -) +from shared.types.common import CommandId, NodeId +from shared.types.events import ChunkGenerated from shared.types.events.chunks import ChunkType, TokenChunk # Type ignore comment for all protected member access in this test file diff --git a/shared/topology.py b/shared/topology.py index 289912f3..c44c717e 100644 --- a/shared/topology.py +++ b/shared/topology.py @@ -84,6 +84,15 @@ class Topology(TopologyProto): del self._edge_id_to_rx_id_map[connection] del self._rx_id_to_node_id_map[rx_idx] + def get_cycles(self) -> list[list[Node]]: + cycle_idxs = rx.simple_cycles(self._graph) + cycles: list[list[Node]] = [] + for cycle_idx in cycle_idxs: + cycle = [self._graph[idx] for idx in cycle_idx] + cycles.append(cycle) + + return cycles + def _is_bridge(self, connection: Connection) -> bool: edge_idx = self._edge_id_to_rx_id_map[connection] graph_copy = self._graph.copy().to_undirected() diff --git a/shared/types/common.py b/shared/types/common.py index 347e7864..58051656 100644 --- a/shared/types/common.py +++ b/shared/types/common.py @@ -19,4 +19,7 @@ class ID(str): return handler.generate_schema(str) class NodeId(ID): - pass \ No newline at end of file + pass + +class CommandId(ID): + pass diff --git a/shared/types/events/chunks.py b/shared/types/events/chunks.py index 67e0587d..f060075c 100644 --- a/shared/types/events/chunks.py +++ b/shared/types/events/chunks.py @@ -4,15 +4,10 @@ from typing import Annotated, Literal from pydantic import BaseModel, Field, TypeAdapter from shared.openai_compat import FinishReason -from shared.types.common import ID +from shared.types.common import CommandId from shared.types.models import ModelId -class CommandId(ID): - """ - Newtype around `ID` for command IDs - """ - class ChunkType(str, Enum): token = "token" image = "image" diff --git a/shared/types/events/commands.py 
b/shared/types/events/commands.py index ae96f6d2..ae17100d 100644 --- a/shared/types/events/commands.py +++ b/shared/types/events/commands.py @@ -4,35 +4,36 @@ from typing import Annotated, Callable, Literal, Sequence from pydantic import BaseModel, Field, TypeAdapter from shared.types.api import ChatCompletionTaskParams +from shared.types.common import CommandId from shared.types.events import Event -from shared.types.events.chunks import CommandId +from shared.types.models import ModelMetadata from shared.types.state import InstanceId, State # TODO: We need to have a distinction between create instance and spin up instance. -class CommandTypes(str, Enum): +class CommandType(str, Enum): CHAT_COMPLETION = "CHAT_COMPLETION" CREATE_INSTANCE = "CREATE_INSTANCE" DELETE_INSTANCE = "DELETE_INSTANCE" -class _BaseCommand[T: CommandTypes](BaseModel): +class _BaseCommand[T: CommandType](BaseModel): command_id: CommandId command_type: T -class ChatCompletionCommand(_BaseCommand[CommandTypes.CHAT_COMPLETION]): - command_type: Literal[CommandTypes.CHAT_COMPLETION] = CommandTypes.CHAT_COMPLETION +class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]): + command_type: Literal[CommandType.CHAT_COMPLETION] = CommandType.CHAT_COMPLETION request_params: ChatCompletionTaskParams -class CreateInstanceCommand(_BaseCommand[CommandTypes.CREATE_INSTANCE]): - command_type: Literal[CommandTypes.CREATE_INSTANCE] = CommandTypes.CREATE_INSTANCE - model_id: str +class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]): + command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE + model_meta: ModelMetadata -class DeleteInstanceCommand(_BaseCommand[CommandTypes.DELETE_INSTANCE]): - command_type: Literal[CommandTypes.DELETE_INSTANCE] = CommandTypes.DELETE_INSTANCE +class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]): + command_type: Literal[CommandType.DELETE_INSTANCE] = CommandType.DELETE_INSTANCE instance_id: InstanceId diff 
--git a/shared/types/profiling.py b/shared/types/profiling.py index ff1af45d..841d68ee 100644 --- a/shared/types/profiling.py +++ b/shared/types/profiling.py @@ -3,9 +3,9 @@ from pydantic import BaseModel, Field class MemoryPerformanceProfile(BaseModel): ram_total: int - ram_used: int + ram_available: int swap_total: int - swap_used: int + swap_available: int class SystemPerformanceProfile(BaseModel): diff --git a/shared/types/state.py b/shared/types/state.py index 0129d925..769ad319 100644 --- a/shared/types/state.py +++ b/shared/types/state.py @@ -1,5 +1,4 @@ from collections.abc import Mapping, Sequence -from enum import Enum from pydantic import BaseModel, ConfigDict, Field @@ -12,10 +11,6 @@ from shared.types.worker.instances import BaseInstance from shared.types.worker.runners import RunnerId, RunnerStatus -class CachePolicy(str, Enum): - KEEP_ALL = "KEEP_ALL" - - class State(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) node_status: Mapping[NodeId, NodeStatus] = {} @@ -25,5 +20,4 @@ class State(BaseModel): node_profiles: Mapping[NodeId, NodePerformanceProfile] = {} topology: Topology = Topology() history: Sequence[Topology] = [] - cache_policy: CachePolicy = CachePolicy.KEEP_ALL last_event_applied_idx: int = Field(default=0, ge=0) diff --git a/shared/types/topology.py b/shared/types/topology.py index c41907ec..0dac5c08 100644 --- a/shared/types/topology.py +++ b/shared/types/topology.py @@ -63,3 +63,5 @@ class TopologyProto(Protocol): def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: ... def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None: ... + + def get_cycles(self) -> list[list[Node]]: ... 
diff --git a/shared/types/worker/shards.py b/shared/types/worker/shards.py index 3bc8b16d..2ef7c8ae 100644 --- a/shared/types/worker/shards.py +++ b/shared/types/worker/shards.py @@ -29,6 +29,9 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]): class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]): """ Pipeline parallelism shard meta. + + Layers are represented as a half-open interval [start_layer, end_layer), + where start_layer is inclusive and end_layer is exclusive. """ partition_strategy: Literal[PartitionStrategy.pipeline] = Field( @@ -44,7 +47,7 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline @property def is_last_layer(self) -> bool: - return self.end_layer == self.n_layers - 1 + return self.end_layer == self.n_layers def __hash__(self) -> int: return hash((self.model_meta.model_id, self.start_layer, self.end_layer, self.n_layers)) diff --git a/worker/download/impl_shard_downloader.py b/worker/download/impl_shard_downloader.py index 1ff6d081..3843107e 100644 --- a/worker/download/impl_shard_downloader.py +++ b/worker/download/impl_shard_downloader.py @@ -25,7 +25,7 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]: device_rank=0, world_size=1, start_layer=0, - end_layer=model_meta.n_layers - 1, + end_layer=model_meta.n_layers, n_layers=model_meta.n_layers, ) @@ -39,7 +39,7 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]: device_rank=base_shard.device_rank, world_size=base_shard.world_size, start_layer=base_shard.start_layer, - end_layer=base_shard.n_layers - 1, + end_layer=base_shard.n_layers, n_layers=base_shard.n_layers, ) diff --git a/worker/runner/runner_supervisor.py b/worker/runner/runner_supervisor.py index d2b556d4..3d1b0553 100644 --- a/worker/runner/runner_supervisor.py +++ b/worker/runner/runner_supervisor.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator from types import CoroutineType from typing 
import Any, Callable -from shared.types.events import CommandId +from shared.types.common import CommandId from shared.types.events.chunks import GenerationChunk, TokenChunk from shared.types.tasks import ChatCompletionTaskParams, Task from shared.types.worker.commands_runner import (