mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
feat: Add ResourceGraph, runner types, etc.
This commit is contained in:
16
shared/types/graphs/resource_graph.py
Normal file
16
shared/types/graphs/resource_graph.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.networking.topology import Topology
|
||||
from shared.types.profiling.common import NodeProfile
|
||||
|
||||
|
||||
class ResourceGraph(BaseModel): ...
|
||||
|
||||
|
||||
def get_graph_of_compute_resources(
|
||||
network_topology: Topology,
|
||||
node_profiles: Mapping[NodeId, NodeProfile],
|
||||
) -> ResourceGraph: ...
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Annotated, Any, Generic, Literal, TypeVar, Union, final
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Generic, Literal, TypeVar, Union, final
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field, TypeAdapter
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ class EdgeMetadata(BaseModel, Generic[TE, TF]): ...
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MLXEdgeContext(EdgeMetadata[AddressingProtocol.IPvAny, ApplicationProtocol.MLX]):
|
||||
source_ip: IPvAnyAddress
|
||||
sink_ip: IPvAnyAddress
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from queue import Queue
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import Event, EventTypes
|
||||
from shared.types.graphs.resource_graph import ResourceGraph
|
||||
from shared.types.networking.topology import NetworkState
|
||||
from shared.types.profiling.common import NodeProfile
|
||||
from shared.types.states.shared import SharedState
|
||||
from shared.types.worker.instances import InstanceData, InstanceId
|
||||
|
||||
|
||||
class ExternalCommand(BaseModel): ...
|
||||
@@ -16,3 +20,20 @@ class MasterState(SharedState):
|
||||
node_profiles: dict[NodeId, NodeProfile]
|
||||
job_inbox: Queue[ExternalCommand]
|
||||
job_outbox: Queue[ExternalCommand]
|
||||
|
||||
|
||||
def get_inference_plan(
|
||||
inbox: Queue[ExternalCommand],
|
||||
outbox: Queue[ExternalCommand],
|
||||
resource_graph: ResourceGraph,
|
||||
current_instances: Mapping[InstanceId, InstanceData],
|
||||
) -> Mapping[InstanceId, InstanceData]: ...
|
||||
|
||||
|
||||
TransitionEventTypes = EventTypes
|
||||
|
||||
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, InstanceData],
|
||||
target_instances: Mapping[InstanceId, InstanceData],
|
||||
) -> Sequence[Event[TransitionEventTypes]]: ...
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.shards import ShardPlacement
|
||||
from shared.types.worker.instances import InstanceData
|
||||
|
||||
|
||||
class SharedState(BaseModel):
|
||||
node_id: NodeId
|
||||
compute_instances: dict[InstanceId, ShardPlacement]
|
||||
compute_instances: Mapping[InstanceId, InstanceData]
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
from typing import Generic, TypeVar
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
|
||||
from shared.types.worker.shards import ShardPlacement
|
||||
|
||||
DownloadStatusT = TypeVar("DownloadStatusT", bound=DownloadStatus)
|
||||
from shared.types.worker.runners import (
|
||||
RunnerId,
|
||||
RunnerPlacement,
|
||||
RunnerState,
|
||||
RunnerStateType,
|
||||
)
|
||||
|
||||
|
||||
class Instance(ShardPlacement):
|
||||
class InstanceBase(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class InstanceDownloadProgress(BaseModel, Generic[DownloadStatusT]):
|
||||
instance_id: InstanceId
|
||||
download_progress: BaseDownloadProgress[DownloadStatusT]
|
||||
class InstanceData(BaseModel):
|
||||
runner_placements: RunnerPlacement
|
||||
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
|
||||
|
||||
|
||||
class Instance(InstanceBase):
|
||||
instance_data: InstanceData
|
||||
|
||||
61
shared/types/worker/runners.py
Normal file
61
shared/types/worker/runners.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
|
||||
from shared.types.worker.shards import Shard, ShardType
|
||||
|
||||
|
||||
class RunnerStateType(str, Enum):
|
||||
Rejected = "Rejected"
|
||||
Starting = "Starting"
|
||||
Downloading = "Downloading"
|
||||
Running = "Running"
|
||||
Failed = "Failed"
|
||||
|
||||
|
||||
RunnerStateTypeT = TypeVar("RunnerStateTypeT", bound=RunnerStateType)
|
||||
|
||||
|
||||
class RunnerState(BaseModel, Generic[RunnerStateTypeT]):
|
||||
runner_state: RunnerStateTypeT
|
||||
|
||||
|
||||
class RejectedRunnerState(RunnerState[RunnerStateType.Rejected]):
|
||||
runner_state: Literal[RunnerStateType.Rejected]
|
||||
|
||||
|
||||
class StartingRunnerState(RunnerState[RunnerStateType.Starting]):
|
||||
runner_state: Literal[RunnerStateType.Starting]
|
||||
|
||||
|
||||
class DownloadingRunnerState(RunnerState[RunnerStateType.Downloading]):
|
||||
runner_state: Literal[RunnerStateType.Downloading]
|
||||
download_progress: BaseDownloadProgress[DownloadStatus]
|
||||
|
||||
|
||||
class RunningRunnerState(RunnerState[RunnerStateType.Running]):
|
||||
runner_state: Literal[RunnerStateType.Running]
|
||||
|
||||
|
||||
class FailedRunnerState(RunnerState[RunnerStateType.Failed]):
|
||||
runner_state: Literal[RunnerStateType.Failed]
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class RunnerData(BaseModel):
|
||||
runner_id: RunnerId
|
||||
runner_state: RunnerState[RunnerStateType] = RunnerState(
|
||||
runner_state=RunnerStateType.Starting
|
||||
)
|
||||
|
||||
|
||||
class RunnerPlacement(BaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, Shard[ShardType]]
|
||||
node_to_runner: Mapping[NodeId, Sequence[RunnerId]]
|
||||
@@ -3,10 +3,6 @@ from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.worker.common import RunnerId
|
||||
|
||||
|
||||
class ShardType(str, Enum):
|
||||
PipelineParallel = "PipelineParallel"
|
||||
@@ -21,9 +17,3 @@ class ShardData(BaseModel, Generic[ShardTypeT]):
|
||||
|
||||
class Shard(BaseModel, Generic[ShardTypeT]):
|
||||
shard_data: ShardData[ShardTypeT]
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ShardPlacement(BaseModel):
|
||||
model_id: ModelId
|
||||
shard_assignments: dict[NodeId, Shard[ShardType]]
|
||||
|
||||
Reference in New Issue
Block a user