Placement strategy

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Seth Howes committed 2025-07-24 20:22:40 +01:00 (committed via GitHub)
parent 4c0e4ef853
commit 6f8e3419d5
19 changed files with 572 additions and 59 deletions


@@ -13,13 +13,13 @@ from shared.types.api import (
ChatCompletionResponse,
StreamingChoiceResponse,
)
from shared.types.common import CommandId
from shared.types.events import ChunkGenerated, Event
from shared.types.events.chunks import TokenChunk
from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CommandId,
CommandTypes,
CommandType,
)
from shared.types.events.components import EventFromEventLog
from shared.types.tasks import ChatCompletionTaskParams
@@ -101,7 +101,7 @@ class API:
request = ChatCompletionCommand(
command_id=command_id,
command_type=CommandTypes.CHAT_COMPLETION,
command_type=CommandType.CHAT_COMPLETION,
request_params=payload,
)
self.command_buffer.append(request)


@@ -14,10 +14,9 @@ from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.common import NodeId
from shared.types.common import CommandId, NodeId
from shared.types.events import (
ChunkGenerated,
CommandId,
InstanceCreated,
TaskCreated,
)
@@ -143,23 +142,23 @@ class Master:
# TODO
pass
case CreateInstanceCommand():
if next_command.model_id not in MODEL_CARDS:
raise ValueError(f"Model {next_command.model_id} not supported.")
if next_command.model_meta.model_id not in MODEL_CARDS:
raise ValueError(f"Model {next_command.model_meta.model_id} not supported.")
# TODO: we should also support models that aren't in MODEL_CARDS
# if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
if next_command.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[next_command.model_id]
if next_command.model_meta.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[next_command.model_meta.model_id]
model_meta = model_card.metadata
else:
model_meta = await get_model_meta(next_command.model_id)
model_meta = await get_model_meta(next_command.model_meta.model_id)
# TODO: how do we actually schedule an instance? TODO: @Seth
next_event = InstanceCreated(
instance_id=InstanceId(),
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=next_command.model_id,
model_id=next_command.model_meta.model_id,
runner_to_shard={
RunnerId(): PipelineShardMetadata(
model_meta=model_meta,
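The MODEL_CARDS-or-Hugging-Face fallback described by the TODOs above is already in place in this branch; factored into a helper it would look roughly like this (resolve_model_meta is an illustrative name, not part of this commit):

    async def resolve_model_meta(model_id: str) -> ModelMetadata:
        # Curated card available: use its prebuilt metadata.
        if model_id in MODEL_CARDS:
            return MODEL_CARDS[model_id].metadata
        # Otherwise treat model_id as a Hugging Face repo_id and fetch metadata.
        return await get_model_meta(model_id)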


@@ -1,24 +1,83 @@
from queue import Queue
from typing import Mapping, Sequence
from collections.abc import Mapping
from copy import deepcopy
from functools import singledispatch
from typing import Sequence
from master.utils.placement_utils import (
filter_cycles_by_memory,
get_shard_assignments,
get_smallest_cycles,
)
from shared.topology import Topology
from shared.types.events import Event
from shared.types.state import CachePolicy
from shared.types.tasks import Task
from shared.types.worker.instances import InstanceId, InstanceParams
from shared.types.events import Event, InstanceCreated, InstanceDeleted
from shared.types.events.commands import CreateInstanceCommand, DeleteInstanceCommand
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import InstanceParams, TypeOfInstance
@singledispatch
def get_instance_placements(
inbox: Queue[Task],
outbox: Queue[Task],
command: CreateInstanceCommand,
topology: Topology,
current_instances: Mapping[InstanceId, InstanceParams],
cache_policy: CachePolicy,
) -> Mapping[InstanceId, InstanceParams]: ...
current_instances: dict[InstanceId, InstanceParams],
) -> dict[InstanceId, InstanceParams]:
available_models = [current_instances[instance].shard_assignments.model_id for instance in current_instances]
if command.model_meta.model_id in available_models:
raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
candidate_cycles = topology.get_cycles()
cycles = filter_cycles_by_memory(candidate_cycles, command.model_meta.storage_size_kilobytes)
if not cycles:
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles)
selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))
shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)
instance_id = InstanceId()
target_instances = deepcopy(current_instances)
target_instances[instance_id] = InstanceParams(
shard_assignments=shard_assignments,
hosts=[]
)
return target_instances
@get_instance_placements.register
def _(command: DeleteInstanceCommand, topology: Topology, current_instances: dict[InstanceId, InstanceParams]) -> dict[InstanceId, InstanceParams]:
target_instances = deepcopy(current_instances)
if command.instance_id in target_instances:
del target_instances[command.instance_id]
return target_instances
raise ValueError(f"Instance {command.instance_id} not found")
def get_transition_events(
current_instances: Mapping[InstanceId, InstanceParams],
target_instances: Mapping[InstanceId, InstanceParams],
) -> Sequence[Event]: ...
) -> Sequence[Event]:
events: list[Event] = []
# find instances to create
for instance_id, instance_params in target_instances.items():
if instance_id not in current_instances:
events.append(
InstanceCreated(
instance_id=instance_id,
instance_params=instance_params,
instance_type=TypeOfInstance.ACTIVE
)
)
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
events.append(
InstanceDeleted(
instance_id=instance_id,
)
)
return events
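Taken together, the create path places an instance in four steps: enumerate the directed cycles of the topology, drop cycles whose combined available RAM cannot hold the model, keep only the smallest surviving cycles (fewest nodes, hence fewest pipeline hops), and select the one with the most free memory. A condensed sketch of that selection, mirroring the code above:

    def select_cycle(topology: Topology, required_kb: int) -> list[Node]:
        cycles = filter_cycles_by_memory(topology.get_cycles(), required_kb)
        if not cycles:
            raise ValueError("No cycles found with sufficient memory")
        smallest = get_smallest_cycles(cycles)
        # Tie-break the smallest cycles by total available RAM.
        return max(smallest, key=lambda cycle: sum(
            node.node_profile.memory.ram_available
            for node in cycle if node.node_profile is not None))

Note that @singledispatch dispatches on the runtime type of the first argument: the undecorated base function handles CreateInstanceCommand (and acts as the fallback), while the registered overload handles DeleteInstanceCommand.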

master/tests/conftest.py (new file, +46 lines)

@@ -0,0 +1,46 @@
import pytest
from shared.types.common import NodeId
from shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from shared.types.topology import Connection, ConnectionProfile, Node
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> Node:
if node_id is None:
node_id = NodeId()
return Node(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
memory=MemoryPerformanceProfile(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(flops_fp16=1000)
)
)
return _create_node
@pytest.fixture
def create_connection():
def _create_connection(source_node_id: NodeId, sink_node_id: NodeId) -> Connection:
return Connection(
source_node_id=source_node_id,
sink_node_id=sink_node_id,
source_multiaddr="/ip4/127.0.0.1/tcp/1234",
sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
)
return _create_connection
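Both fixtures return factory callables rather than fixed objects, so a test can mint any number of nodes and connections with exactly the memory profile it needs; e.g. create_node(500, NodeId()) yields a node whose ram_available is 500.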


@@ -0,0 +1,155 @@
from typing import Callable
import pytest
from master.placement import get_instance_placements, get_transition_events
from shared.topology import Topology
from shared.types.common import CommandId, NodeId
from shared.types.events._events import (
_EventType, # pyright: ignore[reportPrivateUsage]
)
from shared.types.events.commands import CreateInstanceCommand
from shared.types.models import ModelMetadata
from shared.types.topology import Connection, Node
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import InstanceParams
from shared.types.worker.runners import ShardAssignments
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance_params() -> InstanceParams:
return InstanceParams(
shard_assignments=ShardAssignments(
model_id="test-model",
runner_to_shard={},
node_to_runner={}
),
hosts=[]
)
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id="test-model",
storage_size_kilobytes=1000,
pretty_name="Test Model",
n_layers=10
)
def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
return CreateInstanceCommand(
command_id=CommandId(),
model_meta=model_meta
)
@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7))
])
def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], Node],
create_connection: Callable[[NodeId, NodeId], Connection]
):
# arrange
model_meta.n_layers = total_layers
create_instance_command = CreateInstanceCommand(
command_id=CommandId(),
model_meta=model_meta
)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0], node_id_a), node_id_a)
topology.add_node(create_node(available_memory[1], node_id_b), node_id_b)
topology.add_node(create_node(available_memory[2], node_id_c), node_id_c)
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
# act
placements = get_instance_placements(create_instance_command, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance_params = placements[instance_id]
assert instance_params.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance_params.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance_params.shard_assignments.node_to_runner[node_id_b]
runner_id_c = instance_params.shard_assignments.node_to_runner[node_id_c]
shard_a = instance_params.shard_assignments.runner_to_shard[runner_id_a]
shard_b = instance_params.shard_assignments.runner_to_shard[runner_id_b]
shard_c = instance_params.shard_assignments.runner_to_shard[runner_id_c]
assert shard_a.end_layer - shard_a.start_layer == expected_layers[0]
assert shard_b.end_layer - shard_b.start_layer == expected_layers[1]
assert shard_c.end_layer - shard_c.start_layer == expected_layers[2]
shards = [shard_a, shard_b, shard_c]
shards_sorted = sorted(shards, key=lambda s: s.start_layer)
assert shards_sorted[0].start_layer == 0
assert shards_sorted[-1].end_layer == total_layers
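The expected_layers values follow from the proportional split in get_shard_assignments: each node except the last receives round(total_layers * ram_available / cycle_ram), floored at 1, and the last node takes the remainder. For (500, 500, 1000) with 12 layers the cycle has 2000 available, so the first two nodes get round(12 * 500 / 2000) = 3 layers each and the last gets the remaining 6.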
def test_get_transition_events_no_change(topology: Topology, instance_params: InstanceParams):
# arrange
instance_id = InstanceId()
current_instances = {
instance_id: instance_params
}
target_instances = {
instance_id: instance_params
}
# act
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 0
def test_get_transition_events_create_instance(topology: Topology, instance_params: InstanceParams):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, InstanceParams] = {}
target_instances: dict[InstanceId, InstanceParams] = {
instance_id: instance_params
}
# act
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 1
assert events[0].event_type == _EventType.InstanceCreated
def test_get_transition_events_delete_instance(topology: Topology, instance_params: InstanceParams):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, InstanceParams] = {
instance_id: instance_params
}
target_instances: dict[InstanceId, InstanceParams] = {}
# act
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 1
assert events[0].event_type == _EventType.InstanceDeleted
assert events[0].instance_id == instance_id


@@ -0,0 +1,173 @@
from typing import Callable
import pytest
from master.utils.placement_utils import (
filter_cycles_by_memory,
get_shard_assignments,
get_smallest_cycles,
)
from shared.topology import Topology
from shared.types.common import NodeId
from shared.types.models import ModelMetadata
from shared.types.topology import Connection, Node
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
# arrange
node1_id = NodeId()
node2_id = NodeId()
node1 = create_node(1000, node1_id)
node2 = create_node(1000, node2_id)
topology.add_node(node1, node1_id)
topology.add_node(node2, node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, 1)
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
# arrange
node1_id = NodeId()
node2_id = NodeId()
node1 = create_node(1000, node1_id)
node2 = create_node(1000, node2_id)
topology.add_node(node1, node1_id)
topology.add_node(node2, node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a, node_a_id)
topology.add_node(node_b, node_b_id)
topology.add_node(node_c, node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, 1500)
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n.node_id for n in filtered_cycles[0]) == {node_a_id, node_b_id, node_c_id}
def test_get_smallest_cycles(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection]):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a, node_a_id)
topology.add_node(node_b, node_b_id)
topology.add_node(node_c, node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize("available_memory,total_layers,expected_layers", [
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7))
])
def test_get_shard_assignments(topology: Topology, create_node: Callable[[int, NodeId | None], Node], create_connection: Callable[[NodeId, NodeId], Connection], available_memory: tuple[int, int, int], total_layers: int, expected_layers: tuple[int, int, int]):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(available_memory[0], node_a_id)
node_b = create_node(available_memory[1], node_b_id)
node_c = create_node(available_memory[2], node_c_id)
topology.add_node(node_a, node_a_id)
topology.add_node(node_b, node_b_id)
topology.add_node(node_c, node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id="test-model",
pretty_name="Test Model",
n_layers=total_layers,
storage_size_kilobytes=1000
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(model_meta, selected_cycle)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert shard_assignments.runner_to_shard[runner_id_c].end_layer - shard_assignments.runner_to_shard[runner_id_c].start_layer == expected_layers[2]
assert shard_assignments.runner_to_shard[runner_id_a].end_layer - shard_assignments.runner_to_shard[runner_id_a].start_layer == expected_layers[0]
assert shard_assignments.runner_to_shard[runner_id_b].end_layer - shard_assignments.runner_to_shard[runner_id_b].start_layer == expected_layers[1]


@@ -19,7 +19,7 @@ def connection() -> Connection:
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0)
memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000)
system_profile = SystemPerformanceProfile(flops_fp16=1000)
return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[], system=system_profile)
@@ -57,7 +57,7 @@ def test_update_node_profile(topology: Topology, node_profile: NodePerformancePr
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile), node_id=connection.sink_node_id)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_used=0, swap_total=1000, swap_used=0), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
# act
topology.update_node_profile(connection.source_node_id, node_profile=new_node_profile)


@@ -0,0 +1,77 @@
from typing import TypeGuard, cast
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.models import ModelMetadata
from shared.types.profiling import NodePerformanceProfile
from shared.types.topology import Node
from shared.types.worker.common import RunnerId
from shared.types.worker.runners import ShardAssignments
from shared.types.worker.shards import PipelineShardMetadata
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(cycles: list[list[Node]], required_memory: int) -> list[list[Node]]:
filtered_cycles: list[list[Node]] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
continue
total_mem = sum(node.node_profile.memory.ram_available for node in cycle)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[Node], cycle))
return filtered_cycles
def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[Node],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
cycle_memory = sum(node.node_profile.memory.ram_available for node in selected_cycle)
total_layers = model_meta.n_layers
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
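# Give each node a layer count proportional to its share of the cycle's
# available RAM; the last node absorbs the rounding remainder so the
# shards exactly tile [0, total_layers).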
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(total_layers * (node.node_profile.memory.ram_available / cycle_memory))
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
device_rank=i,
world_size=len(selected_cycle),
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner
)
return shard_assignments


@@ -10,11 +10,8 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
CommandId,
)
from shared.types.common import CommandId, NodeId
from shared.types.events import ChunkGenerated
from shared.types.events.chunks import ChunkType, TokenChunk
# Type ignore comment for all protected member access in this test file


@@ -84,6 +84,15 @@ class Topology(TopologyProto):
del self._edge_id_to_rx_id_map[connection]
del self._rx_id_to_node_id_map[rx_idx]
def get_cycles(self) -> list[list[Node]]:
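# rustworkx returns each elementary cycle of the directed graph as a list
# of node indices; map the indices back to their Node payloads.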
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[Node]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def _is_bridge(self, connection: Connection) -> bool:
edge_idx = self._edge_id_to_rx_id_map[connection]
graph_copy = self._graph.copy().to_undirected()


@@ -19,4 +19,7 @@ class ID(str):
return handler.generate_schema(str)
class NodeId(ID):
pass
pass
class CommandId(ID):
pass


@@ -4,15 +4,10 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter
from shared.openai_compat import FinishReason
from shared.types.common import ID
from shared.types.common import CommandId
from shared.types.models import ModelId
class CommandId(ID):
"""
Newtype around `ID` for command IDs
"""
class ChunkType(str, Enum):
token = "token"
image = "image"


@@ -4,35 +4,36 @@ from typing import Annotated, Callable, Literal, Sequence
from pydantic import BaseModel, Field, TypeAdapter
from shared.types.api import ChatCompletionTaskParams
from shared.types.common import CommandId
from shared.types.events import Event
from shared.types.events.chunks import CommandId
from shared.types.models import ModelMetadata
from shared.types.state import InstanceId, State
# TODO: We need to have a distinction between create instance and spin up instance.
class CommandTypes(str, Enum):
class CommandType(str, Enum):
CHAT_COMPLETION = "CHAT_COMPLETION"
CREATE_INSTANCE = "CREATE_INSTANCE"
DELETE_INSTANCE = "DELETE_INSTANCE"
class _BaseCommand[T: CommandTypes](BaseModel):
class _BaseCommand[T: CommandType](BaseModel):
command_id: CommandId
command_type: T
class ChatCompletionCommand(_BaseCommand[CommandTypes.CHAT_COMPLETION]):
command_type: Literal[CommandTypes.CHAT_COMPLETION] = CommandTypes.CHAT_COMPLETION
class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]):
command_type: Literal[CommandType.CHAT_COMPLETION] = CommandType.CHAT_COMPLETION
request_params: ChatCompletionTaskParams
class CreateInstanceCommand(_BaseCommand[CommandTypes.CREATE_INSTANCE]):
command_type: Literal[CommandTypes.CREATE_INSTANCE] = CommandTypes.CREATE_INSTANCE
model_id: str
class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]):
command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE
model_meta: ModelMetadata
class DeleteInstanceCommand(_BaseCommand[CommandTypes.DELETE_INSTANCE]):
command_type: Literal[CommandTypes.DELETE_INSTANCE] = CommandTypes.DELETE_INSTANCE
class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):
command_type: Literal[CommandType.DELETE_INSTANCE] = CommandType.DELETE_INSTANCE
instance_id: InstanceId
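Because every subclass pins command_type to a distinct Literal member, pydantic can discriminate the Command union on that field. A minimal sketch of how the union would be assembled — the spelling below is an assumption based on the Annotated/Field/TypeAdapter imports, not code from this commit:

    # Assumed shape of the discriminated union; illustrative only.
    Command = Annotated[
        ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand,
        Field(discriminator="command_type"),
    ]
    command_adapter: TypeAdapter[Command] = TypeAdapter(Command)
    cmd = command_adapter.validate_python({
        "command_id": "c1",
        "command_type": "DELETE_INSTANCE",
        "instance_id": "i1",
    })  # validates to a DeleteInstanceCommand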


@@ -3,9 +3,9 @@ from pydantic import BaseModel, Field
class MemoryPerformanceProfile(BaseModel):
ram_total: int
ram_used: int
ram_available: int
swap_total: int
swap_used: int
swap_available: int
class SystemPerformanceProfile(BaseModel):


@@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from pydantic import BaseModel, ConfigDict, Field
@@ -12,10 +11,6 @@ from shared.types.worker.instances import BaseInstance
from shared.types.worker.runners import RunnerId, RunnerStatus
class CachePolicy(str, Enum):
KEEP_ALL = "KEEP_ALL"
class State(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
node_status: Mapping[NodeId, NodeStatus] = {}
@@ -25,5 +20,4 @@ class State(BaseModel):
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
topology: Topology = Topology()
history: Sequence[Topology] = []
cache_policy: CachePolicy = CachePolicy.KEEP_ALL
last_event_applied_idx: int = Field(default=0, ge=0)


@@ -63,3 +63,5 @@ class TopologyProto(Protocol):
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: ...
def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None: ...
def get_cycles(self) -> list[list[Node]]: ...


@@ -29,6 +29,9 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
"""
Pipeline parallelism shard meta.
Layers are represented as a half-open interval [start_layer, end_layer),
where start_layer is inclusive and end_layer is exclusive.
"""
partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
@@ -44,7 +47,7 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
@property
def is_last_layer(self) -> bool:
return self.end_layer == self.n_layers - 1
return self.end_layer == self.n_layers
def __hash__(self) -> int:
return hash((self.model_meta.model_id, self.start_layer, self.end_layer, self.n_layers))
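With the half-open convention, a shard's length is just end_layer - start_layer, and consecutive shards tile [0, n_layers) without gaps or overlap: splitting 12 layers three ways as (0, 3), (3, 6), (6, 12) leaves the last shard with end_layer == n_layers, which is exactly the condition the corrected is_last_layer tests.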


@@ -25,7 +25,7 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_meta.n_layers - 1,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
)
@@ -39,7 +39,7 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
end_layer=base_shard.n_layers - 1,
end_layer=base_shard.n_layers,
n_layers=base_shard.n_layers,
)


@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Callable
from shared.types.events import CommandId
from shared.types.common import CommandId
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.tasks import ChatCompletionTaskParams, Task
from shared.types.worker.commands_runner import (