Placement strategy
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
@@ -13,13 +13,13 @@ from shared.types.api import (
     ChatCompletionResponse,
     StreamingChoiceResponse,
 )
+from shared.types.common import CommandId
 from shared.types.events import ChunkGenerated, Event
 from shared.types.events.chunks import TokenChunk
 from shared.types.events.commands import (
     ChatCompletionCommand,
     Command,
-    CommandId,
-    CommandTypes,
+    CommandType,
 )
 from shared.types.events.components import EventFromEventLog
 from shared.types.tasks import ChatCompletionTaskParams
@@ -101,7 +101,7 @@ class API:
 
         request = ChatCompletionCommand(
             command_id=command_id,
-            command_type=CommandTypes.CHAT_COMPLETION,
+            command_type=CommandType.CHAT_COMPLETION,
             request_params=payload,
         )
         self.command_buffer.append(request)
@@ -14,10 +14,9 @@ from shared.db.sqlite.connector import AsyncSQLiteEventStorage
 from shared.db.sqlite.event_log_manager import EventLogManager
 from shared.models.model_cards import MODEL_CARDS
 from shared.models.model_meta import get_model_meta
-from shared.types.common import NodeId
+from shared.types.common import CommandId, NodeId
 from shared.types.events import (
     ChunkGenerated,
-    CommandId,
     InstanceCreated,
     TaskCreated,
 )
@@ -143,23 +142,23 @@ class Master:
                 # TODO
                 pass
             case CreateInstanceCommand():
-                if next_command.model_id not in MODEL_CARDS:
-                    raise ValueError(f"Model {next_command.model_id} not supported.")
+                if next_command.model_meta.model_id not in MODEL_CARDS:
+                    raise ValueError(f"Model {next_command.model_meta.model_id} not supported.")
 
                 # TODO: we should also support models that aren't in MODEL_CARDS
                 # if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
-                if next_command.model_id in MODEL_CARDS:
-                    model_card = MODEL_CARDS[next_command.model_id]
+                if next_command.model_meta.model_id in MODEL_CARDS:
+                    model_card = MODEL_CARDS[next_command.model_meta.model_id]
                     model_meta = model_card.metadata
                 else:
-                    model_meta = await get_model_meta(next_command.model_id)
+                    model_meta = await get_model_meta(next_command.model_meta.model_id)
 
                 # TODO: how do we actually schedule an instance? TODO: @@@@@@𝕾𝖊𝖙𝖍@@@@@@
                 next_event = InstanceCreated(
                     instance_id=InstanceId(),
                     instance_params=InstanceParams(
                         shard_assignments=ShardAssignments(
-                            model_id=next_command.model_id,
+                            model_id=next_command.model_meta.model_id,
                             runner_to_shard={
                                 RunnerId(): PipelineShardMetadata(
                                     model_meta=model_meta,
@@ -1,24 +1,83 @@
-from queue import Queue
-from typing import Mapping, Sequence
+from collections.abc import Mapping
+from copy import deepcopy
+from functools import singledispatch
+from typing import Sequence
 
+from master.utils.placement_utils import (
+    filter_cycles_by_memory,
+    get_shard_assignments,
+    get_smallest_cycles,
+)
 from shared.topology import Topology
-from shared.types.events import Event
-from shared.types.state import CachePolicy
-from shared.types.tasks import Task
-from shared.types.worker.instances import InstanceId, InstanceParams
+from shared.types.events import Event, InstanceCreated, InstanceDeleted
+from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand
+from shared.types.worker.common import InstanceId
+from shared.types.worker.instances import InstanceParams, TypeOfInstance
 
 
+@singledispatch
 def get_instance_placements(
-    inbox: Queue[Task],
-    outbox: Queue[Task],
     command: CreateInstanceCommand,
     topology: Topology,
-    current_instances: Mapping[InstanceId, InstanceParams],
-    cache_policy: CachePolicy,
-) -> Mapping[InstanceId, InstanceParams]: ...
+    current_instances: dict[InstanceId, InstanceParams],
+) -> dict[InstanceId, InstanceParams]:
+    available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances]
+    if command.model_meta.model_id in available_models:
+        raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
+
+    candidate_cycles = topology.get_cycles()
+    cycles = filter_cycles_by_memory(candidate_cycles, command.model_meta.storage_size_kilobytes)
+    if not cycles:
+        raise ValueError("No cycles found with sufficient memory")
+
+    smallest_cycles = get_smallest_cycles(cycles)
+    selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))
+
+    shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)
+
+    instance_id = InstanceId()
+    target_instances = deepcopy(current_instances)
+    target_instances[instance_id] = InstanceParams(
+        shard_assignments=shard_assignments,
+        hosts=[]
+    )
+    return target_instances
+
+
+@get_instance_placements.register
+def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]:
+    target_instances = deepcopy(current_instances)
+    if command.instance_id in target_instances:
+        del target_instances[command.instance_id]
+        return target_instances
+    raise ValueError(f"Instance {command.instance_id} not found")
 
 
 def get_transition_events(
     current_instances: Mapping[InstanceId, InstanceParams],
     target_instances: Mapping[InstanceId, InstanceParams],
-) -> Sequence[Event]: ...
+) -> Sequence[Event]:
+    events: list[Event] = []
+
+    # find instances to create
+    for instance_id, instance_params in target_instances.items():
+        if instance_id not in current_instances:
+            events.append(
+                InstanceCreated(
+                    instance_id=instance_id,
+                    instance_params=instance_params,
+                    instance_type=TypeOfInstance.ACTIVE
+                )
+            )
+
+    # find instances to delete
+    for instance_id in current_instances:
+        if instance_id not in target_instances:
+            events.append(
+                InstanceDeleted(
+                    instance_id=instance_id,
+                )
+            )
+
+    return events
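A minimal usage sketch of the new placement API (my own illustration, not part of the diff; the module path master/placement.py follows from the test imports below):

    # Compute a target placement for a CreateInstanceCommand, then diff it
    # against the current placement to get the events to apply.
    from master.placement import get_instance_placements, get_transition_events

    def plan_create(command, topology, current_instances):
        # singledispatch selects the CreateInstanceCommand overload here
        target = get_instance_placements(command, topology, current_instances)
        # yields InstanceCreated for new IDs, InstanceDeleted for dropped ones
        return get_transition_events(current_instances, target)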
master/tests/conftest.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import pytest
+
+from shared.types.common import NodeId
+from shared.types.profiling import (
+    MemoryPerformanceProfile,
+    NodePerformanceProfile,
+    SystemPerformanceProfile,
+)
+from shared.types.topology import Connection, ConnectionProfile, Node
+
+
+@pytest.fixture
+def create_node():
+    def _create_node(memory: int, node_id: NodeId | None = None) -> Node:
+        if node_id is None:
+            node_id = NodeId()
+        return Node(
+            node_id=node_id,
+            node_profile=NodePerformanceProfile(
+                model_id="test",
+                chip_id="test",
+                memory=MemoryPerformanceProfile(
+                    ram_total=1000,
+                    ram_available=memory,
+                    swap_total=1000,
+                    swap_available=1000
+                ),
+                network_interfaces=[],
+                system=SystemPerformanceProfile(flops_fp16=1000)
+            )
+        )
+
+    return _create_node
+
+
+@pytest.fixture
+def create_connection():
+    def _create_connection(source_node_id: NodeId, sink_node_id: NodeId) -> Connection:
+        return Connection(
+            source_node_id=source_node_id,
+            sink_node_id=sink_node_id,
+            source_multiaddr="/ip4/127.0.0.1/tcp/1234",
+            sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
+            connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
+        )
+    return _create_connection
master/tests/test_placement.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+from typing import Callable
+
+import pytest
+
+from master.placement import get_instance_placements, get_transition_events
+from shared.topology import Topology
+from shared.types.common import CommandId, NodeId
+from shared.types.events._events import (
+    _EventType,  # pyright: ignore[reportPrivateUsage]
+)
+from shared.types.events.commands import CreateInstanceCommand
+from shared.types.models import ModelMetadata
+from shared.types.topology import Connection, Node
+from shared.types.worker.common import InstanceId
+from shared.types.worker.instances import InstanceParams
+from shared.types.worker.runners import ShardAssignments
+
+
+@pytest.fixture
+def topology() -> Topology:
+    return Topology()
+
+
+@pytest.fixture
+def instance_params() -> InstanceParams:
+    return InstanceParams(
+        shard_assignments=ShardAssignments(
+            model_id="test-model",
+            runner_to_shard={},
+            node_to_runner={}
+        ),
+        hosts=[]
+    )
+
+
+@pytest.fixture
+def model_meta() -> ModelMetadata:
+    return ModelMetadata(
+        model_id="test-model",
+        storage_size_kilobytes=1000,
+        pretty_name="Test Model",
+        n_layers=10
+    )
+
+
+def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
+    return CreateInstanceCommand(
+        command_id=CommandId(),
+        model_meta=model_meta
+    )
+
+
+@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [
+    ((500, 500, 1000), 12, (3, 3, 6)),
+    ((500, 500, 500), 12, (4, 4, 4)),
+    ((312, 518, 1024), 12, (2, 3, 7))
+])
+def test_get_instance_placements_create_instance(
+    available_memory: tuple[int, int, int],
+    total_layers: int,
+    expected_layers: tuple[int, int, int],
+    topology: Topology,
+    model_meta: ModelMetadata,
+    create_node: Callable[[int, NodeId | None], Node],
+    create_connection: Callable[[NodeId, NodeId], Connection]
+):
+    # arrange
+    model_meta.n_layers = total_layers
+
+    create_instance_command = CreateInstanceCommand(
+        command_id=CommandId(),
+        model_meta=model_meta
+    )
+    node_id_a = NodeId()
+    node_id_b = NodeId()
+    node_id_c = NodeId()
+    topology.add_node(create_node(available_memory[0], node_id_a), node_id_a)
+    topology.add_node(create_node(available_memory[1], node_id_b), node_id_b)
+    topology.add_node(create_node(available_memory[2], node_id_c), node_id_c)
+    topology.add_connection(create_connection(node_id_a, node_id_b))
+    topology.add_connection(create_connection(node_id_b, node_id_c))
+    topology.add_connection(create_connection(node_id_c, node_id_a))
+
+    # act
+    placements = get_instance_placements(create_instance_command, topology, {})
+
+    # assert
+    assert len(placements) == 1
+    instance_id = list(placements.keys())[0]
+    instance_params = placements[instance_id]
+    assert instance_params.shard_assignments.model_id == model_meta.model_id
+
+    runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a]
+    runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b]
+    runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c]
+
+    shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a]
+    shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b]
+    shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c]
+
+    assert shard_a.end_layer - shard_a.start_layer == expected_layers[0]
+    assert shard_b.end_layer - shard_b.start_layer == expected_layers[1]
+    assert shard_c.end_layer - shard_c.start_layer == expected_layers[2]
+
+    shards = [shard_a, shard_b, shard_c]
+    shards_sorted = sorted(shards, key=lambda s: s.start_layer)
+    assert shards_sorted[0].start_layer == 0
+    assert shards_sorted[-1].end_layer == total_layers
+
+
+def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams):
+    # arrange
+    instance_id = InstanceId()
+    current_instances = {
+        instance_id: instance_params
+    }
+    target_instances = {
+        instance_id: instance_params
+    }
+
+    # act
+    events = get_transition_events(current_instances, target_instances)
+
+    # assert
+    assert len(events) == 0
+
+
+def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams):
+    # arrange
+    instance_id = InstanceId()
+    current_instances: dict[InstanceId, InstanceParams] = {}
+    target_instances: dict[InstanceId, InstanceParams] = {
+        instance_id: instance_params
+    }
+
+    # act
+    events = get_transition_events(current_instances, target_instances)
+
+    # assert
+    assert len(events) == 1
+    assert events[0].event_type == _EventType.InstanceCreated
+
+
+def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams):
+    # arrange
+    instance_id = InstanceId()
+    current_instances: dict[InstanceId, InstanceParams] = {
+        instance_id: instance_params
+    }
+    target_instances: dict[InstanceId, InstanceParams] = {}
+
+    # act
+    events = get_transition_events(current_instances, target_instances)
+
+    # assert
+    assert len(events) == 1
+    assert events[0].event_type == _EventType.InstanceDeleted
+    assert events[0].instance_id == instance_id
master/tests/test_placement_utils.py (new file, 173 lines)
@@ -0,0 +1,173 @@
+from typing import Callable
+
+import pytest
+
+from master.utils.placement_utils import (
+    filter_cycles_by_memory,
+    get_shard_assignments,
+    get_smallest_cycles,
+)
+from shared.topology import Topology
+from shared.types.common import NodeId
+from shared.types.models import ModelMetadata
+from shared.types.topology import Connection, Node
+
+
+@pytest.fixture
+def topology() -> Topology:
+    topology = Topology()
+    return topology
+
+
+def test_filter_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
+    # arrange
+    node1_id = NodeId()
+    node2_id = NodeId()
+
+    node1 = create_node(1000, node1_id)
+    node2 = create_node(1000, node2_id)
+
+    topology.add_node(node1, node1_id)
+    topology.add_node(node2, node2_id)
+
+    connection1 = create_connection(node1_id, node2_id)
+    connection2 = create_connection(node2_id, node1_id)
+
+    topology.add_connection(connection1)
+    topology.add_connection(connection2)
+
+    cycles = topology.get_cycles()
+
+    # act
+    filtered_cycles = filter_cycles_by_memory(cycles, 1)
+
+    # assert
+    assert len(filtered_cycles) == 1
+    assert len(filtered_cycles[0]) == 2
+    assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
+
+
+def test_filter_cycles_by_insufficient_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
+    # arrange
+    node1_id = NodeId()
+    node2_id = NodeId()
+
+    node1 = create_node(1000, node1_id)
+    node2 = create_node(1000, node2_id)
+
+    topology.add_node(node1, node1_id)
+    topology.add_node(node2, node2_id)
+
+    connection1 = create_connection(node1_id, node2_id)
+    connection2 = create_connection(node2_id, node1_id)
+
+    topology.add_connection(connection1)
+    topology.add_connection(connection2)
+
+    # act
+    filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001)
+
+    # assert
+    assert len(filtered_cycles) == 0
+
+
+def test_filter_multiple_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+
+    node_a = create_node(500, node_a_id)
+    node_b = create_node(500, node_b_id)
+    node_c = create_node(1000, node_c_id)
+
+    topology.add_node(node_a, node_a_id)
+    topology.add_node(node_b, node_b_id)
+    topology.add_node(node_c, node_c_id)
+
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))
+
+    topology.add_connection(create_connection(node_a_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_b_id))
+
+    cycles = topology.get_cycles()
+
+    # act
+    filtered_cycles = filter_cycles_by_memory(cycles, 1500)
+
+    # assert
+    assert len(filtered_cycles) == 1
+    assert len(filtered_cycles[0]) == 3
+    assert set(n.node_id for n in filtered_cycles[0]) == {node_a_id, node_b_id, node_c_id}
+
+
+def test_get_smallest_cycles(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+
+    node_a = create_node(500, node_a_id)
+    node_b = create_node(500, node_b_id)
+    node_c = create_node(1000, node_c_id)
+
+    topology.add_node(node_a, node_a_id)
+    topology.add_node(node_b, node_b_id)
+    topology.add_node(node_c, node_c_id)
+
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_a_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))
+
+    # act
+    smallest_cycles = get_smallest_cycles(topology.get_cycles())
+
+    # assert
+    assert len(smallest_cycles) == 1
+    assert len(smallest_cycles[0]) == 2
+    assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
+
+
+@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [
+    ((500, 500, 1000), 12, (3, 3, 6)),
+    ((500, 500, 500), 12, (4, 4, 4)),
+    ((312, 518, 1024), 12, (2, 3, 7))
+])
+def test_get_shard_assignments(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection], available_memory: tuple[int, int, int], total_layers: int, expected_layers: tuple[int, int, int]):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+
+    node_a = create_node(available_memory[0], node_a_id)
+    node_b = create_node(available_memory[1], node_b_id)
+    node_c = create_node(available_memory[2], node_c_id)
+
+    topology.add_node(node_a, node_a_id)
+    topology.add_node(node_b, node_b_id)
+    topology.add_node(node_c, node_c_id)
+
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_a_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))
+
+    model_meta = ModelMetadata(
+        model_id="test-model",
+        pretty_name="Test Model",
+        n_layers=total_layers,
+        storage_size_kilobytes=1000
+    )
+    cycles = topology.get_cycles()
+    selected_cycle = cycles[0]
+
+    # act
+    shard_assignments = get_shard_assignments(model_meta, selected_cycle)
+
+    # assert
+    runner_id_a = shard_assignments.node_to_runner[node_a_id]
+    runner_id_b = shard_assignments.node_to_runner[node_b_id]
+    runner_id_c = shard_assignments.node_to_runner[node_c_id]
+    assert shard_assignments.runner_to_shard[runner_id_a].end_layer - shard_assignments.runner_to_shard[runner_id_a].start_layer == expected_layers[0]
+    assert shard_assignments.runner_to_shard[runner_id_b].end_layer - shard_assignments.runner_to_shard[runner_id_b].start_layer == expected_layers[1]
+    assert shard_assignments.runner_to_shard[runner_id_c].end_layer - shard_assignments.runner_to_shard[runner_id_c].start_layer == expected_layers[2]
@@ -19,7 +19,7 @@ def connection() -> Connection:
 
 @pytest.fixture
 def node_profile() -> NodePerformanceProfile:
-    memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0)
+    memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000)
     system_profile = SystemPerformanceProfile(flops_fp16=1000)
     return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[], system=system_profile)
 
@@ -57,7 +57,7 @@ def test_update_node_profile(topology: Topology, node_profile: NodePerformancePr
     topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile), node_id=connection.sink_node_id)
     topology.add_connection(connection)
 
-    new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
+    new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
 
     # act
     topology.update_node_profile(connection.source_node_id, node_profile=new_node_profile)
master/utils/placement_utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+from typing import TypeGuard, cast
+
+from pydantic import BaseModel
+
+from shared.types.common import NodeId
+from shared.types.models import ModelMetadata
+from shared.types.profiling import NodePerformanceProfile
+from shared.types.topology import Node
+from shared.types.worker.common import RunnerId
+from shared.types.worker.runners import ShardAssignments
+from shared.types.worker.shards import PipelineShardMetadata
+
+
+class NodeWithProfile(BaseModel):
+    node_id: NodeId
+    node_profile: NodePerformanceProfile
+
+def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]:
+    return all(node.node_profile is not None for node in nodes)
+
+def filter_cycles_by_memory(cycles: list[list[Node]], required_memory: int) -> list[list[Node]]:
+    filtered_cycles: list[list[Node]] = []
+    for cycle in cycles:
+        if not narrow_all_nodes(cycle):
+            continue
+
+        total_mem = sum(node.node_profile.memory.ram_available for node in cycle)
+        if total_mem >= required_memory:
+            filtered_cycles.append(cast(list[Node], cycle))
+    return filtered_cycles
+
+
+def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]:
+    min_nodes = min(len(cycle) for cycle in cycles)
+    return [cycle for cycle in cycles if len(cycle) == min_nodes]
+
+def get_shard_assignments(
+    model_meta: ModelMetadata,
+    selected_cycle: list[Node],
+) -> ShardAssignments:
+    if not narrow_all_nodes(selected_cycle):
+        raise ValueError("All nodes must have profiles to create shard assignments")
+
+    cycle_memory = sum(node.node_profile.memory.ram_available for node in selected_cycle)
+    total_layers = model_meta.n_layers
+    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
+    node_to_runner: dict[NodeId, RunnerId] = {}
+
+    layers_assigned = 0
+    for i, node in enumerate(selected_cycle):
+        if i == len(selected_cycle) - 1:
+            node_layers = total_layers - layers_assigned
+        else:
+            node_layers = round(total_layers * (node.node_profile.memory.ram_available / cycle_memory))
+            node_layers = max(1, node_layers)
+
+        runner_id = RunnerId()
+        shard = PipelineShardMetadata(
+            model_meta=model_meta,
+            device_rank=i,
+            world_size=len(selected_cycle),
+            start_layer=layers_assigned,
+            end_layer=layers_assigned + node_layers,
+            n_layers=total_layers
+        )
+
+        runner_to_shard[runner_id] = shard
+        node_to_runner[node.node_id] = runner_id
+        layers_assigned += node_layers
+
+    shard_assignments = ShardAssignments(
+        model_id=model_meta.model_id,
+        runner_to_shard=runner_to_shard,
+        node_to_runner=node_to_runner
+    )
+
+    return shard_assignments
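A worked example of the proportional split in get_shard_assignments (my arithmetic, matching the first parametrized case in the tests above):

    # ram_available = (500, 500, 1000), total_layers = 12, cycle_memory = 2000
    # node 0: round(12 * 500 / 2000) = 3   -> layers [0, 3)
    # node 1: round(12 * 500 / 2000) = 3   -> layers [3, 6)
    # node 2 (last node takes the remainder): 12 - 6 = 6 -> layers [6, 12)
    # Each non-final node gets at least one layer via max(1, node_layers);
    # the final node absorbs rounding error, so shard sizes always sum to n_layers.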
@@ -10,11 +10,8 @@ from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
-from shared.types.common import NodeId
-from shared.types.events import (
-    ChunkGenerated,
-    CommandId,
-)
+from shared.types.common import CommandId, NodeId
+from shared.types.events import ChunkGenerated
 from shared.types.events.chunks import ChunkType, TokenChunk
 
 # Type ignore comment for all protected member access in this test file
@@ -84,6 +84,15 @@ class Topology(TopologyProto):
         del self._edge_id_to_rx_id_map[connection]
         del self._rx_id_to_node_id_map[rx_idx]
 
+    def get_cycles(self) -> list[list[Node]]:
+        cycle_idxs = rx.simple_cycles(self._graph)
+        cycles: list[list[Node]] = []
+        for cycle_idx in cycle_idxs:
+            cycle = [self._graph[idx] for idx in cycle_idx]
+            cycles.append(cycle)
+
+        return cycles
+
     def _is_bridge(self, connection: Connection) -> bool:
         edge_idx = self._edge_id_to_rx_id_map[connection]
         graph_copy = self._graph.copy().to_undirected()
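For reference, a standalone sketch of the rustworkx primitive get_cycles builds on (illustrative payloads, not part of the diff; assumes a rustworkx version that provides simple_cycles, i.e. 0.14+, which yields each elementary cycle as a list of node indices):

    import rustworkx as rx

    graph = rx.PyDiGraph()
    a, b = graph.add_node("a"), graph.add_node("b")
    graph.add_edge(a, b, None)
    graph.add_edge(b, a, None)

    # Indexing the graph with each returned index recovers the stored payload,
    # mirroring `self._graph[idx]` in get_cycles above.
    print([[graph[idx] for idx in cycle] for cycle in rx.simple_cycles(graph)])
    # e.g. [['a', 'b']]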
@@ -19,4 +19,7 @@ class ID(str):
         return handler.generate_schema(str)
 
 class NodeId(ID):
     pass
+
+class CommandId(ID):
+    pass
@@ -4,15 +4,10 @@ from typing import Annotated, Literal
 from pydantic import BaseModel, Field, TypeAdapter
 
 from shared.openai_compat import FinishReason
-from shared.types.common import ID
+from shared.types.common import CommandId
 from shared.types.models import ModelId
 
 
-class CommandId(ID):
-    """
-    Newtype around `ID` for command IDs
-    """
-
 class ChunkType(str, Enum):
     token = "token"
     image = "image"
@@ -4,35 +4,36 @@ from typing import Annotated, Callable, Literal, Sequence
 from pydantic import BaseModel, Field, TypeAdapter
 
 from shared.types.api import ChatCompletionTaskParams
+from shared.types.common import CommandId
 from shared.types.events import Event
-from shared.types.events.chunks import CommandId
+from shared.types.models import ModelMetadata
 from shared.types.state import InstanceId, State
 
 
 # TODO: We need to have a distinction between create instance and spin up instance.
-class CommandTypes(str, Enum):
+class CommandType(str, Enum):
     CHAT_COMPLETION = "CHAT_COMPLETION"
     CREATE_INSTANCE = "CREATE_INSTANCE"
     DELETE_INSTANCE = "DELETE_INSTANCE"
 
 
-class _BaseCommand[T: CommandTypes](BaseModel):
+class _BaseCommand[T: CommandType](BaseModel):
     command_id: CommandId
     command_type: T
 
 
-class ChatCompletionCommand(_BaseCommand[CommandTypes.CHAT_COMPLETION]):
-    command_type: Literal[CommandTypes.CHAT_COMPLETION] = CommandTypes.CHAT_COMPLETION
+class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]):
+    command_type: Literal[CommandType.CHAT_COMPLETION] = CommandType.CHAT_COMPLETION
     request_params: ChatCompletionTaskParams
 
 
-class CreateInstanceCommand(_BaseCommand[CommandTypes.CREATE_INSTANCE]):
-    command_type: Literal[CommandTypes.CREATE_INSTANCE] = CommandTypes.CREATE_INSTANCE
-    model_id: str
+class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]):
+    command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE
+    model_meta: ModelMetadata
 
 
-class DeleteInstanceCommand(_BaseCommand[CommandTypes.DELETE_INSTANCE]):
-    command_type: Literal[CommandTypes.DELETE_INSTANCE] = CommandTypes.DELETE_INSTANCE
+class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):
+    command_type: Literal[CommandType.DELETE_INSTANCE] = CommandType.DELETE_INSTANCE
     instance_id: InstanceId
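Since every subclass pins command_type with a Literal of the renamed enum, the enum can act as a pydantic discriminator. A sketch of how the Command union imported elsewhere in this commit could be parsed (the Annotated/Field/TypeAdapter imports suggest this shape, but the exact definition is my assumption; the dict values are illustrative):

    from typing import Annotated, Union
    from pydantic import Field, TypeAdapter

    # Assumed discriminated union over the three command classes.
    Command = Annotated[
        Union[ChatCompletionCommand, CreateInstanceCommand, DeleteInstanceCommand],
        Field(discriminator="command_type"),
    ]

    cmd = TypeAdapter(Command).validate_python({
        "command_id": "cmd-1",
        "command_type": "DELETE_INSTANCE",
        "instance_id": "inst-1",
    })
    assert isinstance(cmd, DeleteInstanceCommand)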
@@ -3,9 +3,9 @@ from pydantic import BaseModel, Field
 
 class MemoryPerformanceProfile(BaseModel):
     ram_total: int
-    ram_used: int
+    ram_available: int
     swap_total: int
-    swap_used: int
+    swap_available: int
 
 
 class SystemPerformanceProfile(BaseModel):
@@ -1,5 +1,4 @@
 from collections.abc import Mapping, Sequence
-from enum import Enum
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -12,10 +11,6 @@ from shared.types.worker.instances import BaseInstance
 from shared.types.worker.runners import RunnerId, RunnerStatus
 
 
-class CachePolicy(str, Enum):
-    KEEP_ALL = "KEEP_ALL"
-
-
 class State(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     node_status: Mapping[NodeId, NodeStatus] = {}
@@ -25,5 +20,4 @@ class State(BaseModel):
     node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
     topology: Topology = Topology()
     history: Sequence[Topology] = []
-    cache_policy: CachePolicy = CachePolicy.KEEP_ALL
     last_event_applied_idx: int = Field(default=0, ge=0)
@@ -63,3 +63,5 @@ class TopologyProto(Protocol):
     def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: ...
 
     def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None: ...
+
+    def get_cycles(self) -> list[list[Node]]: ...
@@ -29,6 +29,9 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
 class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
     """
     Pipeline parallelism shard meta.
+
+    Layers are represented as a half-open interval [start_layer, end_layer),
+    where start_layer is inclusive and end_layer is exclusive.
     """
 
     partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
@@ -44,7 +47,7 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
 
     @property
     def is_last_layer(self) -> bool:
-        return self.end_layer == self.n_layers - 1
+        return self.end_layer == self.n_layers
 
     def __hash__(self) -> int:
         return hash((self.model_meta.model_id, self.start_layer, self.end_layer, self.n_layers))
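A small illustration of the half-open convention the docstring and the corrected is_last_layer now agree on (my numbers, not from the diff):

    # n_layers = 12, split across two shards:
    #   shard 0: start_layer=0, end_layer=6   -> layers 0..5  (6 = end - start)
    #   shard 1: start_layer=6, end_layer=12  -> layers 6..11 (6 = end - start)
    # shard 1 is last: end_layer == n_layers. Under the old check
    # (end_layer == n_layers - 1), a shard covering the full range was never
    # flagged as last; the build_*_shard fixes below follow the same convention.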
@@ -25,7 +25,7 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
         device_rank=0,
         world_size=1,
         start_layer=0,
-        end_layer=model_meta.n_layers - 1,
+        end_layer=model_meta.n_layers,
         n_layers=model_meta.n_layers,
     )
@@ -39,7 +39,7 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
         device_rank=base_shard.device_rank,
         world_size=base_shard.world_size,
         start_layer=base_shard.start_layer,
-        end_layer=base_shard.n_layers - 1,
+        end_layer=base_shard.n_layers,
         n_layers=base_shard.n_layers,
     )
@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
 from types import CoroutineType
 from typing import Any, Callable
 
-from shared.types.events import CommandId
+from shared.types.common import CommandId
 from shared.types.events.chunks import GenerationChunk, TokenChunk
 from shared.types.tasks import ChatCompletionTaskParams, Task
 from shared.types.worker.commands_runner import (