feat: Add ResourceGraph, runner types, etc.

This commit is contained in:
Arbion Halili
2025-07-01 13:14:26 +01:00
parent df824e2e87
commit 73ac8969bc
8 changed files with 118 additions and 23 deletions

View File

@@ -0,0 +1,16 @@
from collections.abc import Mapping
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.networking.topology import Topology
from shared.types.profiling.common import NodeProfile
class ResourceGraph(BaseModel):
    # Return type of get_graph_of_compute_resources; no fields are declared here yet.
    ...
def get_graph_of_compute_resources(
    network_topology: Topology,
    node_profiles: Mapping[NodeId, NodeProfile],
) -> ResourceGraph:
    """Signature stub mapping a topology plus node profiles to a ResourceGraph.

    Args:
        network_topology: The network topology passed in by the caller.
        node_profiles: Node profiles keyed by node id.

    Returns:
        A ResourceGraph, per the annotation. No implementation is provided
        here; the body is a bare ``...``.
    """
    ...

View File

@@ -1,5 +1,5 @@
from typing import Annotated, Any, Generic, Literal, TypeVar, Union, final
from enum import Enum
from typing import Annotated, Any, Generic, Literal, TypeVar, Union, final
from pydantic import AnyHttpUrl, BaseModel, Field, TypeAdapter

View File

@@ -49,7 +49,6 @@ class EdgeMetadata(BaseModel, Generic[TE, TF]): ...
@final
@dataclass
class MLXEdgeContext(
    EdgeMetadata[AddressingProtocol.IPvAny, ApplicationProtocol.MLX]
):
    # Edge metadata parameterised with IPvAny addressing and the MLX protocol.
    # NOTE(review): EdgeMetadata is declared as a pydantic BaseModel subclass, so
    # the additional @dataclass decorator looks unusual -- confirm which
    # `dataclass` is imported here and that stacking it is intended.
    source_ip: IPvAnyAddress
    sink_ip: IPvAnyAddress

View File

@@ -1,11 +1,15 @@
from collections.abc import Mapping, Sequence
from queue import Queue
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.events.common import Event, EventTypes
from shared.types.graphs.resource_graph import ResourceGraph
from shared.types.networking.topology import NetworkState
from shared.types.profiling.common import NodeProfile
from shared.types.states.shared import SharedState
from shared.types.worker.instances import InstanceData, InstanceId
class ExternalCommand(BaseModel):
    # Element type of the job_inbox / job_outbox queues; no fields declared here.
    ...
@@ -16,3 +20,20 @@ class MasterState(SharedState):
node_profiles: dict[NodeId, NodeProfile]
job_inbox: Queue[ExternalCommand]
job_outbox: Queue[ExternalCommand]
def get_inference_plan(
    inbox: Queue[ExternalCommand],
    outbox: Queue[ExternalCommand],
    resource_graph: ResourceGraph,
    current_instances: Mapping[InstanceId, InstanceData],
) -> Mapping[InstanceId, InstanceData]:
    """Signature stub for deriving an instance mapping from commands and resources.

    Args:
        inbox: Queue of external commands supplied to the planner.
        outbox: Second queue of external commands.
            (NOTE(review): direction and ownership are not visible here.)
        resource_graph: The ResourceGraph to plan against.
        current_instances: Existing instance data keyed by instance id.

    Returns:
        A mapping of instance id to instance data, per the annotation. No
        implementation is provided here; the body is a bare ``...``.
    """
    ...
# Plain alias of EventTypes; used only to parameterise the Event type returned below.
TransitionEventTypes = EventTypes


def get_transition_events(
    current_instances: Mapping[InstanceId, InstanceData],
    target_instances: Mapping[InstanceId, InstanceData],
) -> Sequence[Event[TransitionEventTypes]]:
    """Signature stub relating a current and a target instance mapping to events.

    Args:
        current_instances: Instance data keyed by instance id, as it is now.
        target_instances: Instance data keyed by instance id, as desired.

    Returns:
        A sequence of events, per the annotation. No implementation is
        provided here; the body is a bare ``...``.
    """
    ...

View File

@@ -1,10 +1,12 @@
from collections.abc import Mapping
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.worker.common import InstanceId
from shared.types.worker.shards import ShardPlacement
from shared.types.worker.instances import InstanceData
class SharedState(BaseModel):
    # State fields common to node roles; MasterState is declared as a subclass.
    node_id: NodeId
    # Single declaration of this field. A class body that annotates the same
    # name twice keeps only the last annotation, so the Mapping form below is
    # the one that was effective all along.
    compute_instances: Mapping[InstanceId, InstanceData]

View File

@@ -1,18 +1,24 @@
from typing import Generic, TypeVar
from collections.abc import Mapping
from pydantic import BaseModel
from shared.types.worker.common import InstanceId
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
from shared.types.worker.shards import ShardPlacement
DownloadStatusT = TypeVar("DownloadStatusT", bound=DownloadStatus)
from shared.types.worker.runners import (
RunnerId,
RunnerPlacement,
RunnerState,
RunnerStateType,
)
class InstanceBase(BaseModel):
    # Carries only the instance identifier; other instance models build on it.
    instance_id: InstanceId
class InstanceDownloadProgress(BaseModel, Generic[DownloadStatusT]):
    # Pairs an instance id with a download-progress record; both are generic
    # over the same DownloadStatusT parameter.
    instance_id: InstanceId
    download_progress: BaseDownloadProgress[DownloadStatusT]
class InstanceData(BaseModel):
    # A single RunnerPlacement value (despite the plural field name).
    runner_placements: RunnerPlacement
    # State of each runner, keyed by runner id.
    runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
class Instance(InstanceBase):
    # Adds the InstanceData payload to the fields inherited from InstanceBase.
    instance_data: InstanceData

View File

@@ -0,0 +1,61 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.models.common import ModelId
from shared.types.worker.common import RunnerId
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
from shared.types.worker.shards import Shard, ShardType
class RunnerStateType(str, Enum):
    """String-valued tags for runner states; each value equals its member name."""

    Rejected = "Rejected"
    Starting = "Starting"
    Downloading = "Downloading"
    Running = "Running"
    Failed = "Failed"
# Type parameter for RunnerState; bounded by the RunnerStateType enum.
RunnerStateTypeT = TypeVar("RunnerStateTypeT", bound=RunnerStateType)


class RunnerState(BaseModel, Generic[RunnerStateTypeT]):
    # Generic base model holding one field typed by the class's type parameter.
    runner_state: RunnerStateTypeT
class RejectedRunnerState(RunnerState[RunnerStateType.Rejected]):
    # Narrows runner_state to the single literal value Rejected.
    # NOTE(review): the base is parameterised with an enum member, not a type;
    # confirm the type checker accepts this (Literal[...] may be needed).
    runner_state: Literal[RunnerStateType.Rejected]
class StartingRunnerState(RunnerState[RunnerStateType.Starting]):
    # Narrows runner_state to the single literal value Starting.
    # NOTE(review): the base is parameterised with an enum member, not a type;
    # confirm the type checker accepts this (Literal[...] may be needed).
    runner_state: Literal[RunnerStateType.Starting]
class DownloadingRunnerState(RunnerState[RunnerStateType.Downloading]):
    # Narrows runner_state to the single literal value Downloading.
    # NOTE(review): the base is parameterised with an enum member, not a type;
    # confirm the type checker accepts this (Literal[...] may be needed).
    runner_state: Literal[RunnerStateType.Downloading]
    # Required progress record carried alongside the Downloading tag.
    download_progress: BaseDownloadProgress[DownloadStatus]
class RunningRunnerState(RunnerState[RunnerStateType.Running]):
    # Narrows runner_state to the single literal value Running.
    # NOTE(review): the base is parameterised with an enum member, not a type;
    # confirm the type checker accepts this (Literal[...] may be needed).
    runner_state: Literal[RunnerStateType.Running]
class FailedRunnerState(RunnerState[RunnerStateType.Failed]):
    # Narrows runner_state to the single literal value Failed.
    # NOTE(review): the base is parameterised with an enum member, not a type;
    # confirm the type checker accepts this (Literal[...] may be needed).
    runner_state: Literal[RunnerStateType.Failed]
    # Optional failure description; None when no message is supplied.
    error_message: str | None = None
class RunnerData(BaseModel):
    # Identifier of the runner this record describes.
    runner_id: RunnerId
    # Defaults to an unparameterised RunnerState tagged Starting. The default
    # object is constructed once, when this class body executes.
    runner_state: RunnerState[RunnerStateType] = RunnerState(
        runner_state=RunnerStateType.Starting,
    )
class RunnerPlacement(BaseModel):
    # Model this placement refers to.
    model_id: ModelId
    # One shard per runner, keyed by runner id.
    runner_to_shard: Mapping[RunnerId, Shard[ShardType]]
    # Runner ids grouped per node, keyed by node id.
    node_to_runner: Mapping[NodeId, Sequence[RunnerId]]

View File

@@ -3,10 +3,6 @@ from typing import Generic, TypeVar
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.models.common import ModelId
from shared.types.worker.common import RunnerId
class ShardType(str, Enum):
PipelineParallel = "PipelineParallel"
@@ -21,9 +17,3 @@ class ShardData(BaseModel, Generic[ShardTypeT]):
class Shard(BaseModel, Generic[ShardTypeT]):
    # Shard payload, generic over the same ShardTypeT parameter as this class.
    shard_data: ShardData[ShardTypeT]
    # Id of the runner recorded on this shard.
    runner_id: RunnerId
class ShardPlacement(BaseModel):
    # Model this placement refers to.
    model_id: ModelId
    # One shard per node, keyed by node id.
    shard_assignments: dict[NodeId, Shard[ShardType]]