mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
feat: Update Interfaces
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
)
|
||||
from uuid import UUID
|
||||
@@ -19,41 +20,54 @@ EventId = type("EventId", (UUID,), {})
|
||||
EventIdParser: TypeAdapter[EventId] = TypeAdapter(_EventId)
|
||||
|
||||
|
||||
class EventTypes(str, Enum):
|
||||
ChatCompletionsRequestStarted = "ChatCompletionsRequestStarted"
|
||||
ChatCompletionsRequestCompleted = "ChatCompletionsRequestCompleted"
|
||||
ChatCompletionsRequestFailed = "ChatCompletionsRequestFailed"
|
||||
InferenceSagaStarted = "InferenceSagaStarted"
|
||||
InferencePrepareStarted = "InferencePrepareStarted"
|
||||
InferencePrepareCompleted = "InferencePrepareCompleted"
|
||||
InferenceTriggerStarted = "InferenceTriggerStarted"
|
||||
InferenceTriggerCompleted = "InferenceTriggerCompleted"
|
||||
InferenceCompleted = "InferenceCompleted"
|
||||
InferenceSagaCompleted = "InferenceSagaCompleted"
|
||||
InstanceSetupSagaStarted = "InstanceSetupSagaStarted"
|
||||
InstanceSetupSagaCompleted = "InstanceSetupSagaCompleted"
|
||||
InstanceSetupSagaFailed = "InstanceSetupSagaFailed"
|
||||
ShardAssigned = "ShardAssigned"
|
||||
ShardAssignFailed = "ShardAssignFailed"
|
||||
ShardUnassigned = "ShardUnassigned"
|
||||
ShardUnassignFailed = "ShardUnassignFailed"
|
||||
ShardKilled = "ShardKilled"
|
||||
ShardDied = "ShardDied"
|
||||
ShardSpawned = "ShardSpawned"
|
||||
ShardSpawnedFailed = "ShardSpawnedFailed"
|
||||
ShardDespawned = "ShardDespawned"
|
||||
NodeConnected = "NodeConnected"
|
||||
NodeConnectionProfiled = "NodeConnectionProfiled"
|
||||
NodeDisconnected = "NodeDisconnected"
|
||||
NodeStarted = "NodeStarted"
|
||||
DeviceRegistered = "DeviceRegistered"
|
||||
DeviceProfiled = "DeviceProfiled"
|
||||
TokenGenerated = "TokenGenerated"
|
||||
RepoProgressEvent = "RepoProgressEvent"
|
||||
TimerScheduled = "TimerScheduled"
|
||||
class MLXEventTypes(str, Enum):
|
||||
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
|
||||
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
|
||||
|
||||
|
||||
class TaskEventTypes(str, Enum):
|
||||
TaskCreated = "TaskCreated"
|
||||
TaskUpdated = "TaskUpdated"
|
||||
TaskDeleted = "TaskDeleted"
|
||||
|
||||
|
||||
class StreamingEventTypes(str, Enum):
|
||||
ChunkGenerated = "ChunkGenerated"
|
||||
|
||||
|
||||
class InstanceEventTypes(str, Enum):
|
||||
InstanceCreated = "InstanceCreated"
|
||||
InstanceDeleted = "InstanceDeleted"
|
||||
InstanceReplacedAtomically = "InstanceReplacedAtomically"
|
||||
InstanceRunnerStateUpdated = "InstanceRunnerStateUpdated"
|
||||
|
||||
|
||||
class NodeEventTypes(str, Enum):
|
||||
NodeStateUpdated = "NodeStateUpdated"
|
||||
NodeProfileUpdated = "NodeProfileUpdated"
|
||||
|
||||
|
||||
class EdgeEventTypes(str, Enum):
|
||||
EdgeCreated = "EdgeCreated"
|
||||
EdgeUpdated = "EdgeUpdated"
|
||||
EdgeDeleted = "EdgeDeleted"
|
||||
|
||||
|
||||
class TimerEventTypes(str, Enum):
|
||||
TimerCreated = "TimerCreated"
|
||||
TimerFired = "TimerFired"
|
||||
|
||||
|
||||
EventTypes = Union[
|
||||
TaskEventTypes,
|
||||
StreamingEventTypes,
|
||||
InstanceEventTypes,
|
||||
NodeEventTypes,
|
||||
EdgeEventTypes,
|
||||
TimerEventTypes,
|
||||
MLXEventTypes,
|
||||
]
|
||||
|
||||
EventTypeT = TypeVar("EventTypeT", bound=EventTypes)
|
||||
TEventType = TypeVar("TEventType", bound=EventTypes, covariant=True)
|
||||
|
||||
@@ -73,7 +87,7 @@ class State(BaseModel, Generic[EventTypeT]):
|
||||
sequence_number: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
AnnotatedEventType = Annotated[EventTypes, Field(discriminator="event_type")]
|
||||
AnnotatedEventType = Annotated[Event[EventTypes], Field(discriminator="event_type")]
|
||||
EventTypeParser: TypeAdapter[AnnotatedEventType] = TypeAdapter(AnnotatedEventType)
|
||||
|
||||
Applicator = Callable[[State[EventTypeT], Event[TEventType]], State[EventTypeT]]
|
||||
@@ -131,8 +145,8 @@ class CommandTypes(str, Enum):
|
||||
Delete = "Delete"
|
||||
|
||||
|
||||
CommandTypeT = TypeVar("CommandTypeT", bound=EventTypes)
|
||||
TCommandType = TypeVar("TCommandType", bound=EventTypes, covariant=True)
|
||||
CommandTypeT = TypeVar("CommandTypeT", bound=CommandTypes)
|
||||
TCommandType = TypeVar("TCommandType", bound=CommandTypes, covariant=True)
|
||||
|
||||
|
||||
class Command(BaseModel, Generic[TEventType, TCommandType]):
|
||||
|
||||
@@ -1,15 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, UuidVersion
|
||||
|
||||
from shared.openai import FinishReason, chat
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import Event, EventTypes
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.events.common import (
|
||||
Event,
|
||||
InstanceEventTypes,
|
||||
MLXEventTypes,
|
||||
NodeEventTypes,
|
||||
StreamingEventTypes,
|
||||
TaskEventTypes,
|
||||
TimerEventTypes,
|
||||
)
|
||||
from shared.types.profiling.common import NodeProfile
|
||||
from shared.types.tasks.common import (
|
||||
TaskData,
|
||||
TaskId,
|
||||
TaskStatusType,
|
||||
TaskType,
|
||||
TaskUpdate,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, NodeState
|
||||
from shared.types.worker.instances import InstanceData
|
||||
from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
|
||||
|
||||
_RequestId = Annotated[UUID, UuidVersion(4)]
|
||||
RequestId = type("RequestId", (UUID,), {})
|
||||
@@ -20,305 +36,107 @@ TimerId = type("TimerId", (UUID,), {})
|
||||
TimerIdParser: TypeAdapter[TimerId] = TypeAdapter(_TimerId)
|
||||
|
||||
|
||||
class Shard(BaseModel):
|
||||
# TODO: this has changed
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class InstanceComputePlan(BaseModel):
|
||||
# TODO: this has changed
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class Timer(BaseModel):
|
||||
class TimerData(BaseModel):
|
||||
timer_id: TimerId
|
||||
|
||||
|
||||
# Chat completions ----------------------------------------------------------------
|
||||
class ChatCompletionsRequestStarted(Event[EventTypes.ChatCompletionsRequestStarted]):
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestStarted] = (
|
||||
EventTypes.ChatCompletionsRequestStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
class TaskCreated(Event[TaskEventTypes.TaskCreated]):
|
||||
event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
|
||||
task_id: TaskId
|
||||
task_data: TaskData[TaskType]
|
||||
task_state: TaskUpdate[Literal[TaskStatusType.Pending]]
|
||||
on_instance: InstanceId
|
||||
|
||||
|
||||
class ChatCompletionsRequestCompleted(
|
||||
Event[EventTypes.ChatCompletionsRequestCompleted]
|
||||
):
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestCompleted] = (
|
||||
EventTypes.ChatCompletionsRequestCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
class TaskUpdated(Event[TaskEventTypes.TaskUpdated]):
|
||||
event_type: Literal[TaskEventTypes.TaskUpdated] = TaskEventTypes.TaskUpdated
|
||||
task_id: TaskId
|
||||
update_data: TaskUpdate[TaskStatusType]
|
||||
|
||||
|
||||
class ChatCompletionsRequestFailed(Event[EventTypes.ChatCompletionsRequestFailed]):
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestFailed] = (
|
||||
EventTypes.ChatCompletionsRequestFailed
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
error_message: str
|
||||
class TaskDeleted(Event[TaskEventTypes.TaskDeleted]):
|
||||
event_type: Literal[TaskEventTypes.TaskDeleted] = TaskEventTypes.TaskDeleted
|
||||
task_id: TaskId
|
||||
|
||||
|
||||
# Inference saga ------------------------------------------------------------------
|
||||
class InferenceSagaStarted(Event[EventTypes.InferenceSagaStarted]):
|
||||
event_type: Literal[EventTypes.InferenceSagaStarted] = (
|
||||
EventTypes.InferenceSagaStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferencePrepareStarted(Event[EventTypes.InferencePrepareStarted]):
|
||||
event_type: Literal[EventTypes.InferencePrepareStarted] = (
|
||||
EventTypes.InferencePrepareStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
shard: Shard # replaces model_id, rank, start_layer, end_layer
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferencePrepareCompleted(Event[EventTypes.InferencePrepareCompleted]):
|
||||
event_type: Literal[EventTypes.InferencePrepareCompleted] = (
|
||||
EventTypes.InferencePrepareCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
shard: Shard
|
||||
|
||||
|
||||
class InferenceTriggerStarted(Event[EventTypes.InferenceTriggerStarted]):
|
||||
event_type: Literal[EventTypes.InferenceTriggerStarted] = (
|
||||
EventTypes.InferenceTriggerStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
shard: Shard
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferenceTriggerCompleted(Event[EventTypes.InferenceTriggerCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceTriggerCompleted] = (
|
||||
EventTypes.InferenceTriggerCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
shard: Shard
|
||||
|
||||
|
||||
class InferenceCompleted(Event[EventTypes.InferenceCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceCompleted] = EventTypes.InferenceCompleted
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class InferenceSagaCompleted(Event[EventTypes.InferenceSagaCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceSagaCompleted] = (
|
||||
EventTypes.InferenceSagaCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
# Instance setup saga ------------------------------------------------------------
|
||||
class InstanceSetupSagaStarted(Event[EventTypes.InstanceSetupSagaStarted]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaStarted] = (
|
||||
EventTypes.InstanceSetupSagaStarted
|
||||
)
|
||||
instance_id: str
|
||||
model_id: ModelId
|
||||
plan: InstanceComputePlan
|
||||
|
||||
|
||||
class InstanceSetupSagaCompleted(Event[EventTypes.InstanceSetupSagaCompleted]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaCompleted] = (
|
||||
EventTypes.InstanceSetupSagaCompleted
|
||||
class InstanceCreated(Event[InstanceEventTypes.InstanceCreated]):
|
||||
event_type: Literal[InstanceEventTypes.InstanceCreated] = (
|
||||
InstanceEventTypes.InstanceCreated
|
||||
)
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
instance_data: InstanceData
|
||||
|
||||
|
||||
class InstanceSetupSagaFailed(Event[EventTypes.InstanceSetupSagaFailed]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaFailed] = (
|
||||
EventTypes.InstanceSetupSagaFailed
|
||||
class InstanceDeleted(Event[InstanceEventTypes.InstanceDeleted]):
|
||||
event_type: Literal[InstanceEventTypes.InstanceDeleted] = (
|
||||
InstanceEventTypes.InstanceDeleted
|
||||
)
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
reason: str
|
||||
|
||||
|
||||
# Shard lifecycle -----------------------------------------------------------------
|
||||
class ShardAssigned(Event[EventTypes.ShardAssigned]):
|
||||
event_type: Literal[EventTypes.ShardAssigned] = EventTypes.ShardAssigned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardAssignFailed(Event[EventTypes.ShardAssignFailed]):
|
||||
event_type: Literal[EventTypes.ShardAssignFailed] = EventTypes.ShardAssignFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
reason: str # e.g. "not enough memory"
|
||||
|
||||
|
||||
class ShardUnassigned(Event[EventTypes.ShardUnassigned]):
|
||||
event_type: Literal[EventTypes.ShardUnassigned] = EventTypes.ShardUnassigned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
reason: str # e.g. "instance did not receive request for 5 mins"
|
||||
|
||||
|
||||
class ShardUnassignFailed(Event[EventTypes.ShardUnassignFailed]):
|
||||
event_type: Literal[EventTypes.ShardUnassignFailed] = EventTypes.ShardUnassignFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
reason: str # e.g. "process refused to quit"
|
||||
|
||||
|
||||
class ShardKilled(Event[EventTypes.ShardKilled]):
|
||||
event_type: Literal[EventTypes.ShardKilled] = EventTypes.ShardKilled
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardDied(Event[EventTypes.ShardDied]):
|
||||
event_type: Literal[EventTypes.ShardDied] = EventTypes.ShardDied
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
error_type: str
|
||||
error_message: str
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
class ShardSpawned(Event[EventTypes.ShardSpawned]):
|
||||
event_type: Literal[EventTypes.ShardSpawned] = EventTypes.ShardSpawned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardSpawnedFailed(Event[EventTypes.ShardSpawnedFailed]):
|
||||
event_type: Literal[EventTypes.ShardSpawnedFailed] = EventTypes.ShardSpawnedFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
reason: str # e.g. "not enough memory"
|
||||
|
||||
|
||||
class ShardDespawned(Event[EventTypes.ShardDespawned]):
|
||||
event_type: Literal[EventTypes.ShardDespawned] = EventTypes.ShardDespawned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
# Node connectivity --------------------------------------------------------------
|
||||
class NodeConnected(Event[EventTypes.NodeConnected]):
|
||||
event_type: Literal[EventTypes.NodeConnected] = EventTypes.NodeConnected
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
multiaddr: str
|
||||
remote_multiaddr: str
|
||||
ip: str
|
||||
remote_ip: str
|
||||
|
||||
|
||||
class NodeConnectionProfiled(Event[EventTypes.NodeConnectionProfiled]):
|
||||
event_type: Literal[EventTypes.NodeConnectionProfiled] = (
|
||||
EventTypes.NodeConnectionProfiled
|
||||
class InstanceRunnerStateUpdated(Event[InstanceEventTypes.InstanceRunnerStateUpdated]):
|
||||
event_type: Literal[InstanceEventTypes.InstanceRunnerStateUpdated] = (
|
||||
InstanceEventTypes.InstanceRunnerStateUpdated
|
||||
)
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
latency_ms: int
|
||||
bandwidth_bytes_per_second: int
|
||||
|
||||
|
||||
class NodeDisconnected(Event[EventTypes.NodeDisconnected]):
|
||||
event_type: Literal[EventTypes.NodeDisconnected] = EventTypes.NodeDisconnected
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
|
||||
|
||||
class NodeStarted(Event[EventTypes.NodeStarted]):
|
||||
event_type: Literal[EventTypes.NodeStarted] = EventTypes.NodeStarted
|
||||
|
||||
|
||||
# Device metrics -----------------------------------------------------------------
|
||||
class DeviceRegistered(Event[EventTypes.DeviceRegistered]):
|
||||
event_type: Literal[EventTypes.DeviceRegistered] = EventTypes.DeviceRegistered
|
||||
device_id: str
|
||||
device_model: str
|
||||
device_type: str
|
||||
total_memory_bytes: int
|
||||
available_memory_bytes: int
|
||||
|
||||
|
||||
class DeviceProfiled(Event[EventTypes.DeviceProfiled]):
|
||||
event_type: Literal[EventTypes.DeviceProfiled] = EventTypes.DeviceProfiled
|
||||
device_id: str
|
||||
total_memory_bytes: int
|
||||
available_memory_bytes: int
|
||||
total_flops_fp16: int
|
||||
|
||||
|
||||
# Token streaming ----------------------------------------------------------------
|
||||
class TokenGenerated(Event[EventTypes.TokenGenerated]):
|
||||
# TODO: replace with matt chunk code
|
||||
event_type: Literal[EventTypes.TokenGenerated] = EventTypes.TokenGenerated
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
hosts: List[str]
|
||||
token: int
|
||||
text: str
|
||||
finish_reason: FinishReason
|
||||
state_update: Tuple[RunnerId, RunnerState[RunnerStateType]]
|
||||
|
||||
|
||||
# Repo download progress ----------------------------------------------------------
|
||||
class RepoProgressEvent(Event[EventTypes.RepoProgressEvent]):
|
||||
event_type: Literal[EventTypes.RepoProgressEvent] = EventTypes.RepoProgressEvent
|
||||
repo_id: str
|
||||
downloaded_bytes: int
|
||||
total_bytes: int
|
||||
speed_bytes_per_second: int
|
||||
class InstanceReplacedAtomically(Event[InstanceEventTypes.InstanceReplacedAtomically]):
|
||||
event_type: Literal[InstanceEventTypes.InstanceReplacedAtomically] = (
|
||||
InstanceEventTypes.InstanceReplacedAtomically
|
||||
)
|
||||
old_instance_id: InstanceId
|
||||
new_instance_id: InstanceId
|
||||
new_instance_data: InstanceData
|
||||
|
||||
|
||||
# Timers -------------------------------------------------------------------------
|
||||
class TimerScheduled(Event[EventTypes.TimerScheduled]):
|
||||
event_type: Literal[EventTypes.TimerScheduled] = EventTypes.TimerScheduled
|
||||
timer: Timer
|
||||
class MLXInferenceSagaPrepare(Event[MLXEventTypes.MLXInferenceSagaPrepare]):
|
||||
event_type: Literal[MLXEventTypes.MLXInferenceSagaPrepare] = (
|
||||
MLXEventTypes.MLXInferenceSagaPrepare
|
||||
)
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TimerFired(Event[EventTypes.TimerFired]):
|
||||
event_type: Literal[EventTypes.TimerFired] = EventTypes.TimerFired
|
||||
timer: Timer
|
||||
class MLXInferenceSagaStartPrepare(Event[MLXEventTypes.MLXInferenceSagaStartPrepare]):
|
||||
event_type: Literal[MLXEventTypes.MLXInferenceSagaStartPrepare] = (
|
||||
MLXEventTypes.MLXInferenceSagaStartPrepare
|
||||
)
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class NodeProfileUpdated(Event[NodeEventTypes.NodeProfileUpdated]):
|
||||
event_type: Literal[NodeEventTypes.NodeProfileUpdated] = (
|
||||
NodeEventTypes.NodeProfileUpdated
|
||||
)
|
||||
node_id: NodeId
|
||||
node_profile: NodeProfile
|
||||
|
||||
|
||||
class NodeStateUpdated(Event[NodeEventTypes.NodeStateUpdated]):
|
||||
event_type: Literal[NodeEventTypes.NodeStateUpdated] = (
|
||||
NodeEventTypes.NodeStateUpdated
|
||||
)
|
||||
node_id: NodeId
|
||||
node_state: NodeState
|
||||
|
||||
|
||||
class ChunkGenerated(Event[StreamingEventTypes.ChunkGenerated]):
|
||||
event_type: Literal[StreamingEventTypes.ChunkGenerated] = (
|
||||
StreamingEventTypes.ChunkGenerated
|
||||
)
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
chunk: Any
|
||||
|
||||
|
||||
class TimerScheduled(Event[TimerEventTypes.TimerCreated]):
|
||||
event_type: Literal[TimerEventTypes.TimerCreated] = TimerEventTypes.TimerCreated
|
||||
timer_data: TimerData
|
||||
|
||||
|
||||
class TimerFired(Event[TimerEventTypes.TimerFired]):
|
||||
event_type: Literal[TimerEventTypes.TimerFired] = TimerEventTypes.TimerFired
|
||||
timer_data: TimerData
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.networking.topology import Topology
|
||||
from shared.types.profiling.common import NodeProfile
|
||||
from shared.types.worker.common import NodeState
|
||||
|
||||
|
||||
class ResourceGraph(BaseModel): ...
|
||||
@@ -12,5 +13,6 @@ class ResourceGraph(BaseModel): ...
|
||||
|
||||
def get_graph_of_compute_resources(
|
||||
network_topology: Topology,
|
||||
node_states: Mapping[NodeId, NodeState],
|
||||
node_profiles: Mapping[NodeId, NodeProfile],
|
||||
) -> ResourceGraph: ...
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Annotated, Generic, NamedTuple, TypeVar, final
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, IPvAnyAddress, TypeAdapter
|
||||
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, TypeAdapter
|
||||
from pydantic.types import UuidVersion
|
||||
|
||||
from shared.types.common import NodeId
|
||||
@@ -13,13 +13,6 @@ EdgeId = type("EdgeId", (UUID,), {})
|
||||
EdgeIdParser: TypeAdapter[EdgeId] = TypeAdapter(_EdgeId)
|
||||
|
||||
|
||||
@final
|
||||
class EdgeDataTransferRate(BaseModel):
|
||||
throughput: float
|
||||
latency: float
|
||||
jitter: float
|
||||
|
||||
|
||||
class AddressingProtocol(str, Enum):
|
||||
IPvAny = "IPvAny"
|
||||
|
||||
@@ -28,14 +21,24 @@ class ApplicationProtocol(str, Enum):
|
||||
MLX = "MLX"
|
||||
|
||||
|
||||
TE = TypeVar("TE", bound=AddressingProtocol)
|
||||
TF = TypeVar("TF", bound=ApplicationProtocol)
|
||||
AdP = TypeVar("AdP", bound=AddressingProtocol)
|
||||
ApP = TypeVar("ApP", bound=ApplicationProtocol)
|
||||
|
||||
|
||||
@final
|
||||
class EdgeType(BaseModel, Generic[TE, TF]):
|
||||
addressing_protocol: TE
|
||||
application_protocol: TF
|
||||
class EdgeDataTransferRate(BaseModel):
|
||||
throughput: float
|
||||
latency: float
|
||||
jitter: float
|
||||
|
||||
|
||||
class EdgeMetadata(BaseModel, Generic[AdP, ApP]): ...
|
||||
|
||||
|
||||
@final
|
||||
class EdgeType(BaseModel, Generic[AdP, ApP]):
|
||||
addressing_protocol: AdP
|
||||
application_protocol: ApP
|
||||
|
||||
|
||||
@final
|
||||
@@ -44,41 +47,63 @@ class EdgeDirection(NamedTuple):
|
||||
sink: NodeId
|
||||
|
||||
|
||||
@dataclass
|
||||
class EdgeMetadata(BaseModel, Generic[TE, TF]): ...
|
||||
|
||||
|
||||
@final
|
||||
class MLXEdgeContext(EdgeMetadata[AddressingProtocol.IPvAny, ApplicationProtocol.MLX]):
|
||||
source_ip: IPvAnyAddress
|
||||
sink_ip: IPvAnyAddress
|
||||
|
||||
|
||||
@final
|
||||
class EdgeInfo(BaseModel, Generic[TE, TF]):
|
||||
edge_type: EdgeType[TE, TF]
|
||||
class EdgeDataType(str, Enum):
|
||||
DISCOVERED = "discovered"
|
||||
PROFILED = "profiled"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
EdgeDataTypeT = TypeVar("EdgeDataTypeT", bound=EdgeDataType)
|
||||
|
||||
|
||||
class EdgeData(BaseModel, Generic[EdgeDataTypeT]):
|
||||
edge_data_type: EdgeDataTypeT
|
||||
|
||||
|
||||
class EdgeProfile(EdgeData[EdgeDataType.PROFILED]):
|
||||
edge_data_transfer_rate: EdgeDataTransferRate
|
||||
edge_metadata: EdgeMetadata[TE, TF]
|
||||
|
||||
|
||||
@final
|
||||
class DirectedEdge(BaseModel, Generic[TE, TF]):
|
||||
def validate_mapping(
|
||||
edge_data: Mapping[EdgeDataType, EdgeData[EdgeDataType]],
|
||||
) -> Mapping[EdgeDataType, EdgeData[EdgeDataType]]:
|
||||
"""Validates that each EdgeData value has an edge_data_type matching its key."""
|
||||
for key, value in edge_data.items():
|
||||
if key != value.edge_data_type:
|
||||
raise ValueError(
|
||||
f"Edge Data Type Mismatch: key {key} != value {value.edge_data_type}"
|
||||
)
|
||||
return edge_data
|
||||
|
||||
|
||||
class Edge(BaseModel, Generic[AdP, ApP, EdgeDataTypeT]):
|
||||
edge_type: EdgeType[AdP, ApP]
|
||||
edge_direction: EdgeDirection
|
||||
edge_identifier: EdgeId
|
||||
edge_info: EdgeInfo[TE, TF]
|
||||
edge_data: Annotated[
|
||||
Mapping[EdgeDataType, EdgeData[EdgeDataType]], AfterValidator(validate_mapping)
|
||||
]
|
||||
edge_metadata: EdgeMetadata[AdP, ApP]
|
||||
|
||||
|
||||
"""
|
||||
an_edge: DirectedEdge[Literal[AddressingProtocol.IPvAny], Literal[ApplicationProtocol.MLX]] = DirectedEdge(
|
||||
edge_identifier=UUID(),
|
||||
edge_direction=EdgeDirection(source=NodeId("1"), sink=NodeId("2")),
|
||||
edge_info=EdgeInfo(
|
||||
an_edge: UniqueEdge[Literal[AddressingProtocol.IPvAny], Literal[ApplicationProtocol.MLX]] = UniqueEdge(
|
||||
edge_identifier=EdgeId(UUID().hex),
|
||||
edge_info=ProfiledEdge(
|
||||
edge_direction=EdgeDirection(source=NodeId("1"), sink=NodeId("2")),
|
||||
edge_type=EdgeType(
|
||||
addressing_protocol=AddressingProtocol.ipv4,
|
||||
application_protocol=ApplicationProtocol.mlx
|
||||
addressing_protocol=AddressingProtocol.IPvAny,
|
||||
application_protocol=ApplicationProtocol.MLX
|
||||
),
|
||||
edge_data_transfer_rate=EdgeDataTransferRate(throughput=1000, latency=0.1, jitter=0.01),
|
||||
edge_metadata=MLXEdgeContext(source_ip=IpV4Addr("192.168.1.1"), sink_ip=IpV4Addr("192.168.1.2"))
|
||||
edge_data=EdgeData(
|
||||
edge_data_transfer_rate=EdgeDataTransferRate(throughput=1000, latency=0.1, jitter=0.01)
|
||||
),
|
||||
edge_metadata=MLXEdgeContext(source_ip=IPv4Address("192.168.1.1"), sink_ip=IPv4Address("192.168.1.2"))
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -6,9 +6,9 @@ from shared.types.common import NodeId
|
||||
from shared.types.networking.edges import (
|
||||
AddressingProtocol,
|
||||
ApplicationProtocol,
|
||||
EdgeDirection,
|
||||
Edge,
|
||||
EdgeDataType,
|
||||
EdgeId,
|
||||
EdgeInfo,
|
||||
)
|
||||
|
||||
TopicName = NewType("TopicName", str)
|
||||
@@ -21,7 +21,8 @@ class WrappedMessage(BaseModel):
|
||||
|
||||
PubSubMessageHandler = Callable[[TopicName, WrappedMessage], None]
|
||||
NodeConnectedHandler = Callable[
|
||||
[EdgeId, EdgeDirection, EdgeInfo[AddressingProtocol, ApplicationProtocol]], None
|
||||
[EdgeId, Edge[AddressingProtocol, ApplicationProtocol, EdgeDataType.DISCOVERED]],
|
||||
None,
|
||||
]
|
||||
NodeDisconnectedHandler = Callable[[EdgeId], None]
|
||||
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.networking.edges import (
|
||||
AddressingProtocol,
|
||||
ApplicationProtocol,
|
||||
EdgeDirection,
|
||||
Edge,
|
||||
EdgeDataType,
|
||||
EdgeId,
|
||||
EdgeInfo,
|
||||
)
|
||||
|
||||
|
||||
class Topology(BaseModel):
|
||||
edges: dict[
|
||||
EdgeId, tuple[EdgeDirection, EdgeInfo[AddressingProtocol, ApplicationProtocol]]
|
||||
edges: Mapping[
|
||||
EdgeId,
|
||||
Edge[AddressingProtocol, ApplicationProtocol, Literal[EdgeDataType.DISCOVERED]],
|
||||
]
|
||||
|
||||
|
||||
class EdgeMap(BaseModel):
|
||||
edges: Mapping[EdgeId, Edge[AddressingProtocol, ApplicationProtocol, EdgeDataType]]
|
||||
|
||||
|
||||
class NetworkState(BaseModel):
|
||||
topology: Topology
|
||||
history: Sequence[Topology]
|
||||
|
||||
@@ -9,6 +9,7 @@ from shared.types.graphs.resource_graph import ResourceGraph
|
||||
from shared.types.networking.topology import NetworkState
|
||||
from shared.types.profiling.common import NodeProfile
|
||||
from shared.types.states.shared import SharedState
|
||||
from shared.types.worker.common import NodeState
|
||||
from shared.types.worker.instances import InstanceData, InstanceId
|
||||
|
||||
|
||||
@@ -17,7 +18,8 @@ class ExternalCommand(BaseModel): ...
|
||||
|
||||
class MasterState(SharedState):
|
||||
network_state: NetworkState
|
||||
node_profiles: dict[NodeId, NodeProfile]
|
||||
node_profiles: Mapping[NodeId, NodeProfile]
|
||||
node_states: Mapping[NodeId, NodeState]
|
||||
job_inbox: Queue[ExternalCommand]
|
||||
job_outbox: Queue[ExternalCommand]
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Mapping
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.tasks.common import Task, TaskId, TaskType
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import InstanceData
|
||||
|
||||
@@ -10,3 +11,4 @@ from shared.types.worker.instances import InstanceData
|
||||
class SharedState(BaseModel):
|
||||
node_id: NodeId
|
||||
compute_instances: Mapping[InstanceId, InstanceData]
|
||||
compute_tasks: dict[TaskId, Task[TaskType]]
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Tuple
|
||||
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.states.shared import SharedState
|
||||
from shared.types.tasks.common import Task, TaskId, TaskType
|
||||
from shared.types.worker.common import NodeState
|
||||
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
|
||||
from shared.types.worker.shards import ShardData, ShardType
|
||||
|
||||
|
||||
class WorkerState(SharedState):
|
||||
download_state: dict[
|
||||
node_state: NodeState
|
||||
download_state: Mapping[
|
||||
Tuple[ModelId, ShardData[ShardType]], BaseDownloadProgress[DownloadStatus]
|
||||
]
|
||||
compute_tasks: dict[TaskId, Task[TaskType]]
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Generic, Literal, TypeVar
|
||||
from typing import Annotated, Any, Generic, Literal, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
import openai.types.chat as openai
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from pydantic.types import UuidVersion
|
||||
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
|
||||
_TaskId = Annotated[UUID, UuidVersion(4)]
|
||||
TaskId = type("TaskId", (UUID,), {})
|
||||
TaskIdParser: TypeAdapter[TaskId] = TypeAdapter(_TaskId)
|
||||
@@ -19,21 +22,60 @@ class TaskType(str, Enum):
|
||||
TaskTypeT = TypeVar("TaskTypeT", bound=TaskType)
|
||||
|
||||
|
||||
class Task(BaseModel, Generic[TaskTypeT]):
|
||||
task_id: TaskId
|
||||
class TaskData(BaseModel, Generic[TaskTypeT]):
|
||||
task_type: TaskTypeT
|
||||
task_data: Any
|
||||
|
||||
|
||||
class ChatCompletionNonStreamingTask(Task[TaskType.ChatCompletionNonStreaming]):
|
||||
class ChatCompletionNonStreamingTask(TaskData[TaskType.ChatCompletionNonStreaming]):
|
||||
task_type: Literal[TaskType.ChatCompletionNonStreaming] = (
|
||||
TaskType.ChatCompletionNonStreaming
|
||||
)
|
||||
task_data: openai.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class ChatCompletionStreamingTask(Task[TaskType.ChatCompletionStreaming]):
|
||||
class ChatCompletionStreamingTask(TaskData[TaskType.ChatCompletionStreaming]):
|
||||
task_type: Literal[TaskType.ChatCompletionStreaming] = (
|
||||
TaskType.ChatCompletionStreaming
|
||||
)
|
||||
task_data: openai.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class TaskStatusType(str, Enum):
|
||||
Pending = "Pending"
|
||||
Running = "Running"
|
||||
Failed = "Failed"
|
||||
Complete = "Complete"
|
||||
|
||||
|
||||
TaskStatusTypeT = TypeVar(
|
||||
"TaskStatusTypeT", bound=Union[TaskStatusType, Literal["Complete"]]
|
||||
)
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel, Generic[TaskStatusTypeT]):
|
||||
task_status: TaskStatusTypeT
|
||||
|
||||
|
||||
class PendingTask(TaskUpdate[TaskStatusType.Pending]):
|
||||
task_status: Literal[TaskStatusType.Pending]
|
||||
|
||||
|
||||
class RunningTask(TaskUpdate[TaskStatusType.Running]):
|
||||
task_status: Literal[TaskStatusType.Running]
|
||||
|
||||
|
||||
class CompletedTask(TaskUpdate[TaskStatusType.Complete]):
|
||||
task_status: Literal[TaskStatusType.Complete]
|
||||
task_artifact: bytes
|
||||
|
||||
|
||||
class FailedTask(TaskUpdate[TaskStatusType.Failed]):
|
||||
task_status: Literal[TaskStatusType.Failed]
|
||||
error_message: Mapping[RunnerId, str]
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_data: TaskData[TaskType]
|
||||
task_status: TaskUpdate[TaskStatusType]
|
||||
on_instance: InstanceId
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
@@ -11,3 +12,9 @@ InstanceIdParser: TypeAdapter[InstanceId] = TypeAdapter(_InstanceId)
|
||||
_RunnerId = Annotated[UUID, UuidVersion(4)]
|
||||
RunnerId = type("RunnerId", (UUID,), {})
|
||||
RunnerIdParser: TypeAdapter[RunnerId] = TypeAdapter(_RunnerId)
|
||||
|
||||
|
||||
class NodeState(str, Enum):
|
||||
Idle = "Idle"
|
||||
Running = "Running"
|
||||
Paused = "Paused"
|
||||
|
||||
@@ -11,14 +11,15 @@ from shared.types.worker.runners import (
|
||||
)
|
||||
|
||||
|
||||
class InstanceBase(BaseModel):
|
||||
instance_id: InstanceId
|
||||
class InstanceState(BaseModel):
|
||||
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
|
||||
|
||||
|
||||
class InstanceData(BaseModel):
|
||||
runner_placements: RunnerPlacement
|
||||
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
|
||||
|
||||
|
||||
class Instance(InstanceBase):
|
||||
class Instance(BaseModel):
|
||||
instance_id: InstanceId
|
||||
instance_data: InstanceData
|
||||
instance_state: InstanceState
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.models.common import ModelId
|
||||
@@ -59,3 +59,13 @@ class RunnerPlacement(BaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, Shard[ShardType]]
|
||||
node_to_runner: Mapping[NodeId, Sequence[RunnerId]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_runners_exist(self) -> "RunnerPlacement":
|
||||
for runners in self.node_to_runner.values():
|
||||
for runner_id in runners:
|
||||
if runner_id not in self.runner_to_shard:
|
||||
raise ValueError(
|
||||
f"Runner {runner_id} in node_to_runner does not exist in runner_to_shard"
|
||||
)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user