feat: Update Interfaces

This commit is contained in:
Arbion Halili
2025-07-01 18:41:37 +01:00
parent 73ac8969bc
commit 6de1f2883f
13 changed files with 304 additions and 373 deletions

View File

@@ -7,6 +7,7 @@ from typing import (
Sequence,
Tuple,
TypeVar,
Union,
get_args,
)
from uuid import UUID
@@ -19,41 +20,54 @@ EventId = type("EventId", (UUID,), {})
EventIdParser: TypeAdapter[EventId] = TypeAdapter(_EventId)
class EventTypes(str, Enum):
ChatCompletionsRequestStarted = "ChatCompletionsRequestStarted"
ChatCompletionsRequestCompleted = "ChatCompletionsRequestCompleted"
ChatCompletionsRequestFailed = "ChatCompletionsRequestFailed"
InferenceSagaStarted = "InferenceSagaStarted"
InferencePrepareStarted = "InferencePrepareStarted"
InferencePrepareCompleted = "InferencePrepareCompleted"
InferenceTriggerStarted = "InferenceTriggerStarted"
InferenceTriggerCompleted = "InferenceTriggerCompleted"
InferenceCompleted = "InferenceCompleted"
InferenceSagaCompleted = "InferenceSagaCompleted"
InstanceSetupSagaStarted = "InstanceSetupSagaStarted"
InstanceSetupSagaCompleted = "InstanceSetupSagaCompleted"
InstanceSetupSagaFailed = "InstanceSetupSagaFailed"
ShardAssigned = "ShardAssigned"
ShardAssignFailed = "ShardAssignFailed"
ShardUnassigned = "ShardUnassigned"
ShardUnassignFailed = "ShardUnassignFailed"
ShardKilled = "ShardKilled"
ShardDied = "ShardDied"
ShardSpawned = "ShardSpawned"
ShardSpawnedFailed = "ShardSpawnedFailed"
ShardDespawned = "ShardDespawned"
NodeConnected = "NodeConnected"
NodeConnectionProfiled = "NodeConnectionProfiled"
NodeDisconnected = "NodeDisconnected"
NodeStarted = "NodeStarted"
DeviceRegistered = "DeviceRegistered"
DeviceProfiled = "DeviceProfiled"
TokenGenerated = "TokenGenerated"
RepoProgressEvent = "RepoProgressEvent"
TimerScheduled = "TimerScheduled"
class MLXEventTypes(str, Enum):
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
class TaskEventTypes(str, Enum):
TaskCreated = "TaskCreated"
TaskUpdated = "TaskUpdated"
TaskDeleted = "TaskDeleted"
class StreamingEventTypes(str, Enum):
ChunkGenerated = "ChunkGenerated"
class InstanceEventTypes(str, Enum):
InstanceCreated = "InstanceCreated"
InstanceDeleted = "InstanceDeleted"
InstanceReplacedAtomically = "InstanceReplacedAtomically"
InstanceRunnerStateUpdated = "InstanceRunnerStateUpdated"
class NodeEventTypes(str, Enum):
NodeStateUpdated = "NodeStateUpdated"
NodeProfileUpdated = "NodeProfileUpdated"
class EdgeEventTypes(str, Enum):
EdgeCreated = "EdgeCreated"
EdgeUpdated = "EdgeUpdated"
EdgeDeleted = "EdgeDeleted"
class TimerEventTypes(str, Enum):
TimerCreated = "TimerCreated"
TimerFired = "TimerFired"
EventTypes = Union[
TaskEventTypes,
StreamingEventTypes,
InstanceEventTypes,
NodeEventTypes,
EdgeEventTypes,
TimerEventTypes,
MLXEventTypes,
]
EventTypeT = TypeVar("EventTypeT", bound=EventTypes)
TEventType = TypeVar("TEventType", bound=EventTypes, covariant=True)
@@ -73,7 +87,7 @@ class State(BaseModel, Generic[EventTypeT]):
sequence_number: int = Field(default=0, ge=0)
AnnotatedEventType = Annotated[EventTypes, Field(discriminator="event_type")]
AnnotatedEventType = Annotated[Event[EventTypes], Field(discriminator="event_type")]
EventTypeParser: TypeAdapter[AnnotatedEventType] = TypeAdapter(AnnotatedEventType)
Applicator = Callable[[State[EventTypeT], Event[TEventType]], State[EventTypeT]]
@@ -131,8 +145,8 @@ class CommandTypes(str, Enum):
Delete = "Delete"
CommandTypeT = TypeVar("CommandTypeT", bound=EventTypes)
TCommandType = TypeVar("TCommandType", bound=EventTypes, covariant=True)
CommandTypeT = TypeVar("CommandTypeT", bound=CommandTypes)
TCommandType = TypeVar("TCommandType", bound=CommandTypes, covariant=True)
class Command(BaseModel, Generic[TEventType, TCommandType]):

View File

@@ -1,15 +1,31 @@
from __future__ import annotations
from typing import Annotated, List, Literal, Optional
from typing import Annotated, Any, Literal, Tuple
from uuid import UUID
from pydantic import BaseModel, TypeAdapter, UuidVersion
from shared.openai import FinishReason, chat
from shared.types.common import NodeId
from shared.types.events.common import Event, EventTypes
from shared.types.models.common import ModelId
from shared.types.worker.common import InstanceId
from shared.types.events.common import (
Event,
InstanceEventTypes,
MLXEventTypes,
NodeEventTypes,
StreamingEventTypes,
TaskEventTypes,
TimerEventTypes,
)
from shared.types.profiling.common import NodeProfile
from shared.types.tasks.common import (
TaskData,
TaskId,
TaskStatusType,
TaskType,
TaskUpdate,
)
from shared.types.worker.common import InstanceId, NodeState
from shared.types.worker.instances import InstanceData
from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
_RequestId = Annotated[UUID, UuidVersion(4)]
RequestId = type("RequestId", (UUID,), {})
@@ -20,305 +36,107 @@ TimerId = type("TimerId", (UUID,), {})
TimerIdParser: TypeAdapter[TimerId] = TypeAdapter(_TimerId)
class Shard(BaseModel):
# TODO: this has changed
model_id: ModelId
class InstanceComputePlan(BaseModel):
# TODO: this has changed
model_id: ModelId
class Timer(BaseModel):
class TimerData(BaseModel):
timer_id: TimerId
# Chat completions ----------------------------------------------------------------
class ChatCompletionsRequestStarted(Event[EventTypes.ChatCompletionsRequestStarted]):
event_type: Literal[EventTypes.ChatCompletionsRequestStarted] = (
EventTypes.ChatCompletionsRequestStarted
)
request_id: RequestId
model_id: ModelId
request: chat.completion_create_params.CompletionCreateParams
class TaskCreated(Event[TaskEventTypes.TaskCreated]):
event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
task_id: TaskId
task_data: TaskData[TaskType]
task_state: TaskUpdate[Literal[TaskStatusType.Pending]]
on_instance: InstanceId
class ChatCompletionsRequestCompleted(
Event[EventTypes.ChatCompletionsRequestCompleted]
):
event_type: Literal[EventTypes.ChatCompletionsRequestCompleted] = (
EventTypes.ChatCompletionsRequestCompleted
)
request_id: RequestId
model_id: ModelId
class TaskUpdated(Event[TaskEventTypes.TaskUpdated]):
event_type: Literal[TaskEventTypes.TaskUpdated] = TaskEventTypes.TaskUpdated
task_id: TaskId
update_data: TaskUpdate[TaskStatusType]
class ChatCompletionsRequestFailed(Event[EventTypes.ChatCompletionsRequestFailed]):
event_type: Literal[EventTypes.ChatCompletionsRequestFailed] = (
EventTypes.ChatCompletionsRequestFailed
)
request_id: RequestId
model_id: ModelId
error_message: str
class TaskDeleted(Event[TaskEventTypes.TaskDeleted]):
event_type: Literal[TaskEventTypes.TaskDeleted] = TaskEventTypes.TaskDeleted
task_id: TaskId
# Inference saga ------------------------------------------------------------------
class InferenceSagaStarted(Event[EventTypes.InferenceSagaStarted]):
event_type: Literal[EventTypes.InferenceSagaStarted] = (
EventTypes.InferenceSagaStarted
)
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
request: chat.completion_create_params.CompletionCreateParams
class InferencePrepareStarted(Event[EventTypes.InferencePrepareStarted]):
event_type: Literal[EventTypes.InferencePrepareStarted] = (
EventTypes.InferencePrepareStarted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
hosts: List[str]
shard: Shard # replaces model_id, rank, start_layer, end_layer
request: chat.completion_create_params.CompletionCreateParams
class InferencePrepareCompleted(Event[EventTypes.InferencePrepareCompleted]):
event_type: Literal[EventTypes.InferencePrepareCompleted] = (
EventTypes.InferencePrepareCompleted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
hosts: List[str]
shard: Shard
class InferenceTriggerStarted(Event[EventTypes.InferenceTriggerStarted]):
event_type: Literal[EventTypes.InferenceTriggerStarted] = (
EventTypes.InferenceTriggerStarted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
hosts: List[str]
shard: Shard
request: chat.completion_create_params.CompletionCreateParams
class InferenceTriggerCompleted(Event[EventTypes.InferenceTriggerCompleted]):
event_type: Literal[EventTypes.InferenceTriggerCompleted] = (
EventTypes.InferenceTriggerCompleted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
hosts: List[str]
shard: Shard
class InferenceCompleted(Event[EventTypes.InferenceCompleted]):
event_type: Literal[EventTypes.InferenceCompleted] = EventTypes.InferenceCompleted
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
class InferenceSagaCompleted(Event[EventTypes.InferenceSagaCompleted]):
event_type: Literal[EventTypes.InferenceSagaCompleted] = (
EventTypes.InferenceSagaCompleted
)
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
# Instance setup saga ------------------------------------------------------------
class InstanceSetupSagaStarted(Event[EventTypes.InstanceSetupSagaStarted]):
event_type: Literal[EventTypes.InstanceSetupSagaStarted] = (
EventTypes.InstanceSetupSagaStarted
)
instance_id: str
model_id: ModelId
plan: InstanceComputePlan
class InstanceSetupSagaCompleted(Event[EventTypes.InstanceSetupSagaCompleted]):
event_type: Literal[EventTypes.InstanceSetupSagaCompleted] = (
EventTypes.InstanceSetupSagaCompleted
class InstanceCreated(Event[InstanceEventTypes.InstanceCreated]):
event_type: Literal[InstanceEventTypes.InstanceCreated] = (
InstanceEventTypes.InstanceCreated
)
instance_id: InstanceId
model_id: ModelId
instance_data: InstanceData
class InstanceSetupSagaFailed(Event[EventTypes.InstanceSetupSagaFailed]):
event_type: Literal[EventTypes.InstanceSetupSagaFailed] = (
EventTypes.InstanceSetupSagaFailed
class InstanceDeleted(Event[InstanceEventTypes.InstanceDeleted]):
event_type: Literal[InstanceEventTypes.InstanceDeleted] = (
InstanceEventTypes.InstanceDeleted
)
instance_id: InstanceId
model_id: ModelId
reason: str
# Shard lifecycle -----------------------------------------------------------------
class ShardAssigned(Event[EventTypes.ShardAssigned]):
event_type: Literal[EventTypes.ShardAssigned] = EventTypes.ShardAssigned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardAssignFailed(Event[EventTypes.ShardAssignFailed]):
event_type: Literal[EventTypes.ShardAssignFailed] = EventTypes.ShardAssignFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
reason: str # e.g. "not enough memory"
class ShardUnassigned(Event[EventTypes.ShardUnassigned]):
event_type: Literal[EventTypes.ShardUnassigned] = EventTypes.ShardUnassigned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
reason: str # e.g. "instance did not receive request for 5 mins"
class ShardUnassignFailed(Event[EventTypes.ShardUnassignFailed]):
event_type: Literal[EventTypes.ShardUnassignFailed] = EventTypes.ShardUnassignFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
reason: str # e.g. "process refused to quit"
class ShardKilled(Event[EventTypes.ShardKilled]):
event_type: Literal[EventTypes.ShardKilled] = EventTypes.ShardKilled
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardDied(Event[EventTypes.ShardDied]):
event_type: Literal[EventTypes.ShardDied] = EventTypes.ShardDied
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
error_type: str
error_message: str
traceback: Optional[str] = None
class ShardSpawned(Event[EventTypes.ShardSpawned]):
event_type: Literal[EventTypes.ShardSpawned] = EventTypes.ShardSpawned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardSpawnedFailed(Event[EventTypes.ShardSpawnedFailed]):
event_type: Literal[EventTypes.ShardSpawnedFailed] = EventTypes.ShardSpawnedFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
reason: str # e.g. "not enough memory"
class ShardDespawned(Event[EventTypes.ShardDespawned]):
event_type: Literal[EventTypes.ShardDespawned] = EventTypes.ShardDespawned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
# Node connectivity --------------------------------------------------------------
class NodeConnected(Event[EventTypes.NodeConnected]):
event_type: Literal[EventTypes.NodeConnected] = EventTypes.NodeConnected
remote_node_id: NodeId
connection_id: str
multiaddr: str
remote_multiaddr: str
ip: str
remote_ip: str
class NodeConnectionProfiled(Event[EventTypes.NodeConnectionProfiled]):
event_type: Literal[EventTypes.NodeConnectionProfiled] = (
EventTypes.NodeConnectionProfiled
class InstanceRunnerStateUpdated(Event[InstanceEventTypes.InstanceRunnerStateUpdated]):
event_type: Literal[InstanceEventTypes.InstanceRunnerStateUpdated] = (
InstanceEventTypes.InstanceRunnerStateUpdated
)
remote_node_id: NodeId
connection_id: str
latency_ms: int
bandwidth_bytes_per_second: int
class NodeDisconnected(Event[EventTypes.NodeDisconnected]):
event_type: Literal[EventTypes.NodeDisconnected] = EventTypes.NodeDisconnected
remote_node_id: NodeId
connection_id: str
class NodeStarted(Event[EventTypes.NodeStarted]):
event_type: Literal[EventTypes.NodeStarted] = EventTypes.NodeStarted
# Device metrics -----------------------------------------------------------------
class DeviceRegistered(Event[EventTypes.DeviceRegistered]):
event_type: Literal[EventTypes.DeviceRegistered] = EventTypes.DeviceRegistered
device_id: str
device_model: str
device_type: str
total_memory_bytes: int
available_memory_bytes: int
class DeviceProfiled(Event[EventTypes.DeviceProfiled]):
event_type: Literal[EventTypes.DeviceProfiled] = EventTypes.DeviceProfiled
device_id: str
total_memory_bytes: int
available_memory_bytes: int
total_flops_fp16: int
# Token streaming ----------------------------------------------------------------
class TokenGenerated(Event[EventTypes.TokenGenerated]):
# TODO: replace with matt chunk code
event_type: Literal[EventTypes.TokenGenerated] = EventTypes.TokenGenerated
request_id: RequestId
instance_id: InstanceId
hosts: List[str]
token: int
text: str
finish_reason: FinishReason
state_update: Tuple[RunnerId, RunnerState[RunnerStateType]]
# Repo download progress ----------------------------------------------------------
class RepoProgressEvent(Event[EventTypes.RepoProgressEvent]):
event_type: Literal[EventTypes.RepoProgressEvent] = EventTypes.RepoProgressEvent
repo_id: str
downloaded_bytes: int
total_bytes: int
speed_bytes_per_second: int
class InstanceReplacedAtomically(Event[InstanceEventTypes.InstanceReplacedAtomically]):
event_type: Literal[InstanceEventTypes.InstanceReplacedAtomically] = (
InstanceEventTypes.InstanceReplacedAtomically
)
old_instance_id: InstanceId
new_instance_id: InstanceId
new_instance_data: InstanceData
# Timers -------------------------------------------------------------------------
class TimerScheduled(Event[EventTypes.TimerScheduled]):
event_type: Literal[EventTypes.TimerScheduled] = EventTypes.TimerScheduled
timer: Timer
class MLXInferenceSagaPrepare(Event[MLXEventTypes.MLXInferenceSagaPrepare]):
event_type: Literal[MLXEventTypes.MLXInferenceSagaPrepare] = (
MLXEventTypes.MLXInferenceSagaPrepare
)
task_id: TaskId
instance_id: InstanceId
class TimerFired(Event[EventTypes.TimerFired]):
event_type: Literal[EventTypes.TimerFired] = EventTypes.TimerFired
timer: Timer
class MLXInferenceSagaStartPrepare(Event[MLXEventTypes.MLXInferenceSagaStartPrepare]):
event_type: Literal[MLXEventTypes.MLXInferenceSagaStartPrepare] = (
MLXEventTypes.MLXInferenceSagaStartPrepare
)
task_id: TaskId
instance_id: InstanceId
class NodeProfileUpdated(Event[NodeEventTypes.NodeProfileUpdated]):
event_type: Literal[NodeEventTypes.NodeProfileUpdated] = (
NodeEventTypes.NodeProfileUpdated
)
node_id: NodeId
node_profile: NodeProfile
class NodeStateUpdated(Event[NodeEventTypes.NodeStateUpdated]):
event_type: Literal[NodeEventTypes.NodeStateUpdated] = (
NodeEventTypes.NodeStateUpdated
)
node_id: NodeId
node_state: NodeState
class ChunkGenerated(Event[StreamingEventTypes.ChunkGenerated]):
event_type: Literal[StreamingEventTypes.ChunkGenerated] = (
StreamingEventTypes.ChunkGenerated
)
task_id: TaskId
instance_id: InstanceId
chunk: Any
class TimerScheduled(Event[TimerEventTypes.TimerCreated]):
event_type: Literal[TimerEventTypes.TimerCreated] = TimerEventTypes.TimerCreated
timer_data: TimerData
class TimerFired(Event[TimerEventTypes.TimerFired]):
event_type: Literal[TimerEventTypes.TimerFired] = TimerEventTypes.TimerFired
timer_data: TimerData

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.networking.topology import Topology
from shared.types.profiling.common import NodeProfile
from shared.types.worker.common import NodeState
class ResourceGraph(BaseModel): ...
@@ -12,5 +13,6 @@ class ResourceGraph(BaseModel): ...
def get_graph_of_compute_resources(
network_topology: Topology,
node_states: Mapping[NodeId, NodeState],
node_profiles: Mapping[NodeId, NodeProfile],
) -> ResourceGraph: ...

View File

@@ -1,9 +1,9 @@
from dataclasses import dataclass
from collections.abc import Mapping
from enum import Enum
from typing import Annotated, Generic, NamedTuple, TypeVar, final
from uuid import UUID
from pydantic import BaseModel, IPvAnyAddress, TypeAdapter
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, TypeAdapter
from pydantic.types import UuidVersion
from shared.types.common import NodeId
@@ -13,13 +13,6 @@ EdgeId = type("EdgeId", (UUID,), {})
EdgeIdParser: TypeAdapter[EdgeId] = TypeAdapter(_EdgeId)
@final
class EdgeDataTransferRate(BaseModel):
throughput: float
latency: float
jitter: float
class AddressingProtocol(str, Enum):
IPvAny = "IPvAny"
@@ -28,14 +21,24 @@ class ApplicationProtocol(str, Enum):
MLX = "MLX"
TE = TypeVar("TE", bound=AddressingProtocol)
TF = TypeVar("TF", bound=ApplicationProtocol)
AdP = TypeVar("AdP", bound=AddressingProtocol)
ApP = TypeVar("ApP", bound=ApplicationProtocol)
@final
class EdgeType(BaseModel, Generic[TE, TF]):
addressing_protocol: TE
application_protocol: TF
class EdgeDataTransferRate(BaseModel):
throughput: float
latency: float
jitter: float
class EdgeMetadata(BaseModel, Generic[AdP, ApP]): ...
@final
class EdgeType(BaseModel, Generic[AdP, ApP]):
addressing_protocol: AdP
application_protocol: ApP
@final
@@ -44,41 +47,63 @@ class EdgeDirection(NamedTuple):
sink: NodeId
@dataclass
class EdgeMetadata(BaseModel, Generic[TE, TF]): ...
@final
class MLXEdgeContext(EdgeMetadata[AddressingProtocol.IPvAny, ApplicationProtocol.MLX]):
source_ip: IPvAnyAddress
sink_ip: IPvAnyAddress
@final
class EdgeInfo(BaseModel, Generic[TE, TF]):
edge_type: EdgeType[TE, TF]
class EdgeDataType(str, Enum):
DISCOVERED = "discovered"
PROFILED = "profiled"
UNKNOWN = "unknown"
EdgeDataTypeT = TypeVar("EdgeDataTypeT", bound=EdgeDataType)
class EdgeData(BaseModel, Generic[EdgeDataTypeT]):
edge_data_type: EdgeDataTypeT
class EdgeProfile(EdgeData[EdgeDataType.PROFILED]):
edge_data_transfer_rate: EdgeDataTransferRate
edge_metadata: EdgeMetadata[TE, TF]
@final
class DirectedEdge(BaseModel, Generic[TE, TF]):
def validate_mapping(
edge_data: Mapping[EdgeDataType, EdgeData[EdgeDataType]],
) -> Mapping[EdgeDataType, EdgeData[EdgeDataType]]:
"""Validates that each EdgeData value has an edge_data_type matching its key."""
for key, value in edge_data.items():
if key != value.edge_data_type:
raise ValueError(
f"Edge Data Type Mismatch: key {key} != value {value.edge_data_type}"
)
return edge_data
class Edge(BaseModel, Generic[AdP, ApP, EdgeDataTypeT]):
edge_type: EdgeType[AdP, ApP]
edge_direction: EdgeDirection
edge_identifier: EdgeId
edge_info: EdgeInfo[TE, TF]
edge_data: Annotated[
Mapping[EdgeDataType, EdgeData[EdgeDataType]], AfterValidator(validate_mapping)
]
edge_metadata: EdgeMetadata[AdP, ApP]
"""
an_edge: DirectedEdge[Literal[AddressingProtocol.IPvAny], Literal[ApplicationProtocol.MLX]] = DirectedEdge(
edge_identifier=UUID(),
edge_direction=EdgeDirection(source=NodeId("1"), sink=NodeId("2")),
edge_info=EdgeInfo(
an_edge: UniqueEdge[Literal[AddressingProtocol.IPvAny], Literal[ApplicationProtocol.MLX]] = UniqueEdge(
edge_identifier=EdgeId(UUID().hex),
edge_info=ProfiledEdge(
edge_direction=EdgeDirection(source=NodeId("1"), sink=NodeId("2")),
edge_type=EdgeType(
addressing_protocol=AddressingProtocol.ipv4,
application_protocol=ApplicationProtocol.mlx
addressing_protocol=AddressingProtocol.IPvAny,
application_protocol=ApplicationProtocol.MLX
),
edge_data_transfer_rate=EdgeDataTransferRate(throughput=1000, latency=0.1, jitter=0.01),
edge_metadata=MLXEdgeContext(source_ip=IpV4Addr("192.168.1.1"), sink_ip=IpV4Addr("192.168.1.2"))
edge_data=EdgeData(
edge_data_transfer_rate=EdgeDataTransferRate(throughput=1000, latency=0.1, jitter=0.01)
),
edge_metadata=MLXEdgeContext(source_ip=IPv4Address("192.168.1.1"), sink_ip=IPv4Address("192.168.1.2"))
)
)
"""

View File

@@ -6,9 +6,9 @@ from shared.types.common import NodeId
from shared.types.networking.edges import (
AddressingProtocol,
ApplicationProtocol,
EdgeDirection,
Edge,
EdgeDataType,
EdgeId,
EdgeInfo,
)
TopicName = NewType("TopicName", str)
@@ -21,7 +21,8 @@ class WrappedMessage(BaseModel):
PubSubMessageHandler = Callable[[TopicName, WrappedMessage], None]
NodeConnectedHandler = Callable[
[EdgeId, EdgeDirection, EdgeInfo[AddressingProtocol, ApplicationProtocol]], None
[EdgeId, Edge[AddressingProtocol, ApplicationProtocol, EdgeDataType.DISCOVERED]],
None,
]
NodeDisconnectedHandler = Callable[[EdgeId], None]

View File

@@ -1,22 +1,28 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Literal
from pydantic import BaseModel
from shared.types.networking.edges import (
AddressingProtocol,
ApplicationProtocol,
EdgeDirection,
Edge,
EdgeDataType,
EdgeId,
EdgeInfo,
)
class Topology(BaseModel):
edges: dict[
EdgeId, tuple[EdgeDirection, EdgeInfo[AddressingProtocol, ApplicationProtocol]]
edges: Mapping[
EdgeId,
Edge[AddressingProtocol, ApplicationProtocol, Literal[EdgeDataType.DISCOVERED]],
]
class EdgeMap(BaseModel):
edges: Mapping[EdgeId, Edge[AddressingProtocol, ApplicationProtocol, EdgeDataType]]
class NetworkState(BaseModel):
topology: Topology
history: Sequence[Topology]

View File

@@ -9,6 +9,7 @@ from shared.types.graphs.resource_graph import ResourceGraph
from shared.types.networking.topology import NetworkState
from shared.types.profiling.common import NodeProfile
from shared.types.states.shared import SharedState
from shared.types.worker.common import NodeState
from shared.types.worker.instances import InstanceData, InstanceId
@@ -17,7 +18,8 @@ class ExternalCommand(BaseModel): ...
class MasterState(SharedState):
network_state: NetworkState
node_profiles: dict[NodeId, NodeProfile]
node_profiles: Mapping[NodeId, NodeProfile]
node_states: Mapping[NodeId, NodeState]
job_inbox: Queue[ExternalCommand]
job_outbox: Queue[ExternalCommand]

View File

@@ -3,6 +3,7 @@ from collections.abc import Mapping
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.tasks.common import Task, TaskId, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import InstanceData
@@ -10,3 +11,4 @@ from shared.types.worker.instances import InstanceData
class SharedState(BaseModel):
node_id: NodeId
compute_instances: Mapping[InstanceId, InstanceData]
compute_tasks: dict[TaskId, Task[TaskType]]

View File

@@ -1,14 +1,15 @@
from collections.abc import Mapping
from typing import Tuple
from shared.types.models.common import ModelId
from shared.types.states.shared import SharedState
from shared.types.tasks.common import Task, TaskId, TaskType
from shared.types.worker.common import NodeState
from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
from shared.types.worker.shards import ShardData, ShardType
class WorkerState(SharedState):
download_state: dict[
node_state: NodeState
download_state: Mapping[
Tuple[ModelId, ShardData[ShardType]], BaseDownloadProgress[DownloadStatus]
]
compute_tasks: dict[TaskId, Task[TaskType]]

View File

@@ -1,11 +1,14 @@
from collections.abc import Mapping
from enum import Enum
from typing import Annotated, Any, Generic, Literal, TypeVar
from typing import Annotated, Any, Generic, Literal, TypeVar, Union
from uuid import UUID
import openai.types.chat as openai
from pydantic import BaseModel, TypeAdapter
from pydantic.types import UuidVersion
from shared.types.worker.common import InstanceId, RunnerId
_TaskId = Annotated[UUID, UuidVersion(4)]
TaskId = type("TaskId", (UUID,), {})
TaskIdParser: TypeAdapter[TaskId] = TypeAdapter(_TaskId)
@@ -19,21 +22,60 @@ class TaskType(str, Enum):
TaskTypeT = TypeVar("TaskTypeT", bound=TaskType)
class Task(BaseModel, Generic[TaskTypeT]):
task_id: TaskId
class TaskData(BaseModel, Generic[TaskTypeT]):
task_type: TaskTypeT
task_data: Any
class ChatCompletionNonStreamingTask(Task[TaskType.ChatCompletionNonStreaming]):
class ChatCompletionNonStreamingTask(TaskData[TaskType.ChatCompletionNonStreaming]):
task_type: Literal[TaskType.ChatCompletionNonStreaming] = (
TaskType.ChatCompletionNonStreaming
)
task_data: openai.completion_create_params.CompletionCreateParams
class ChatCompletionStreamingTask(Task[TaskType.ChatCompletionStreaming]):
class ChatCompletionStreamingTask(TaskData[TaskType.ChatCompletionStreaming]):
task_type: Literal[TaskType.ChatCompletionStreaming] = (
TaskType.ChatCompletionStreaming
)
task_data: openai.completion_create_params.CompletionCreateParams
class TaskStatusType(str, Enum):
Pending = "Pending"
Running = "Running"
Failed = "Failed"
Complete = "Complete"
TaskStatusTypeT = TypeVar(
"TaskStatusTypeT", bound=Union[TaskStatusType, Literal["Complete"]]
)
class TaskUpdate(BaseModel, Generic[TaskStatusTypeT]):
task_status: TaskStatusTypeT
class PendingTask(TaskUpdate[TaskStatusType.Pending]):
task_status: Literal[TaskStatusType.Pending]
class RunningTask(TaskUpdate[TaskStatusType.Running]):
task_status: Literal[TaskStatusType.Running]
class CompletedTask(TaskUpdate[TaskStatusType.Complete]):
task_status: Literal[TaskStatusType.Complete]
task_artifact: bytes
class FailedTask(TaskUpdate[TaskStatusType.Failed]):
task_status: Literal[TaskStatusType.Failed]
error_message: Mapping[RunnerId, str]
class Task(BaseModel):
task_data: TaskData[TaskType]
task_status: TaskUpdate[TaskStatusType]
on_instance: InstanceId

View File

@@ -1,3 +1,4 @@
from enum import Enum
from typing import Annotated
from uuid import UUID
@@ -11,3 +12,9 @@ InstanceIdParser: TypeAdapter[InstanceId] = TypeAdapter(_InstanceId)
_RunnerId = Annotated[UUID, UuidVersion(4)]
RunnerId = type("RunnerId", (UUID,), {})
RunnerIdParser: TypeAdapter[RunnerId] = TypeAdapter(_RunnerId)
class NodeState(str, Enum):
Idle = "Idle"
Running = "Running"
Paused = "Paused"

View File

@@ -11,14 +11,15 @@ from shared.types.worker.runners import (
)
class InstanceBase(BaseModel):
instance_id: InstanceId
class InstanceState(BaseModel):
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
class InstanceData(BaseModel):
runner_placements: RunnerPlacement
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
class Instance(InstanceBase):
class Instance(BaseModel):
instance_id: InstanceId
instance_data: InstanceData
instance_state: InstanceState

View File

@@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from shared.types.common import NodeId
from shared.types.models.common import ModelId
@@ -59,3 +59,13 @@ class RunnerPlacement(BaseModel):
model_id: ModelId
runner_to_shard: Mapping[RunnerId, Shard[ShardType]]
node_to_runner: Mapping[NodeId, Sequence[RunnerId]]
@model_validator(mode="after")
def validate_runners_exist(self) -> "RunnerPlacement":
for runners in self.node_to_runner.values():
for runner_id in runners:
if runner_id not in self.runner_to_shard:
raise ValueError(
f"Runner {runner_id} in node_to_runner does not exist in runner_to_shard"
)
return self