mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
refactor: Use enums
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -18,40 +18,42 @@ _EventId = Annotated[UUID, UuidVersion(4)]
|
||||
EventId = type("EventId", (UUID,), {})
|
||||
EventIdParser: TypeAdapter[EventId] = TypeAdapter(_EventId)
|
||||
|
||||
EventTypes = Literal[
|
||||
"ChatCompletionsRequestStarted",
|
||||
"ChatCompletionsRequestCompleted",
|
||||
"ChatCompletionsRequestFailed",
|
||||
"InferenceSagaStarted",
|
||||
"InferencePrepareStarted",
|
||||
"InferencePrepareCompleted",
|
||||
"InferenceTriggerStarted",
|
||||
"InferenceTriggerCompleted",
|
||||
"InferenceCompleted",
|
||||
"InferenceSagaCompleted",
|
||||
"InstanceSetupSagaStarted",
|
||||
"InstanceSetupSagaCompleted",
|
||||
"InstanceSetupSagaFailed",
|
||||
"ShardAssigned",
|
||||
"ShardAssignFailed",
|
||||
"ShardUnassigned",
|
||||
"ShardUnassignFailed",
|
||||
"ShardKilled",
|
||||
"ShardDied",
|
||||
"ShardSpawned",
|
||||
"ShardSpawnedFailed",
|
||||
"ShardDespawned",
|
||||
"NodeConnected",
|
||||
"NodeConnectionProfiled",
|
||||
"NodeDisconnected",
|
||||
"NodeStarted",
|
||||
"DeviceRegistered",
|
||||
"DeviceProfiled",
|
||||
"TokenGenerated",
|
||||
"RepoProgressEvent",
|
||||
"TimerScheduled",
|
||||
"TimerFired",
|
||||
]
|
||||
|
||||
class EventTypes(str, Enum):
|
||||
ChatCompletionsRequestStarted = "ChatCompletionsRequestStarted"
|
||||
ChatCompletionsRequestCompleted = "ChatCompletionsRequestCompleted"
|
||||
ChatCompletionsRequestFailed = "ChatCompletionsRequestFailed"
|
||||
InferenceSagaStarted = "InferenceSagaStarted"
|
||||
InferencePrepareStarted = "InferencePrepareStarted"
|
||||
InferencePrepareCompleted = "InferencePrepareCompleted"
|
||||
InferenceTriggerStarted = "InferenceTriggerStarted"
|
||||
InferenceTriggerCompleted = "InferenceTriggerCompleted"
|
||||
InferenceCompleted = "InferenceCompleted"
|
||||
InferenceSagaCompleted = "InferenceSagaCompleted"
|
||||
InstanceSetupSagaStarted = "InstanceSetupSagaStarted"
|
||||
InstanceSetupSagaCompleted = "InstanceSetupSagaCompleted"
|
||||
InstanceSetupSagaFailed = "InstanceSetupSagaFailed"
|
||||
ShardAssigned = "ShardAssigned"
|
||||
ShardAssignFailed = "ShardAssignFailed"
|
||||
ShardUnassigned = "ShardUnassigned"
|
||||
ShardUnassignFailed = "ShardUnassignFailed"
|
||||
ShardKilled = "ShardKilled"
|
||||
ShardDied = "ShardDied"
|
||||
ShardSpawned = "ShardSpawned"
|
||||
ShardSpawnedFailed = "ShardSpawnedFailed"
|
||||
ShardDespawned = "ShardDespawned"
|
||||
NodeConnected = "NodeConnected"
|
||||
NodeConnectionProfiled = "NodeConnectionProfiled"
|
||||
NodeDisconnected = "NodeDisconnected"
|
||||
NodeStarted = "NodeStarted"
|
||||
DeviceRegistered = "DeviceRegistered"
|
||||
DeviceProfiled = "DeviceProfiled"
|
||||
TokenGenerated = "TokenGenerated"
|
||||
RepoProgressEvent = "RepoProgressEvent"
|
||||
TimerScheduled = "TimerScheduled"
|
||||
TimerFired = "TimerFired"
|
||||
|
||||
|
||||
EventTypeT = TypeVar("EventTypeT", bound=EventTypes)
|
||||
TEventType = TypeVar("TEventType", bound=EventTypes, covariant=True)
|
||||
|
||||
@@ -122,7 +124,13 @@ _CommandId = Annotated[UUID, UuidVersion(4)]
|
||||
CommandId = type("CommandId", (UUID,), {})
|
||||
CommandIdParser: TypeAdapter[CommandId] = TypeAdapter(_CommandId)
|
||||
|
||||
CommandTypes = Literal["create", "update", "delete"]
|
||||
|
||||
class CommandTypes(str, Enum):
|
||||
Create = "Create"
|
||||
Update = "Update"
|
||||
Delete = "Delete"
|
||||
|
||||
|
||||
CommandTypeT = TypeVar("CommandTypeT", bound=EventTypes)
|
||||
TCommandType = TypeVar("TCommandType", bound=EventTypes, covariant=True)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel, TypeAdapter, UuidVersion
|
||||
|
||||
from shared.openai import FinishReason, chat
|
||||
from shared.types.event_sourcing import Event
|
||||
from shared.types.event_sourcing import Event, EventTypes
|
||||
from shared.types.model import ModelId
|
||||
|
||||
_NodeId = Annotated[UUID, UuidVersion(4)]
|
||||
@@ -41,39 +41,49 @@ class Timer(BaseModel):
|
||||
|
||||
|
||||
# Chat completions ----------------------------------------------------------------
|
||||
class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted"]]):
|
||||
event_type = "ChatCompletionsRequestStarted"
|
||||
class ChatCompletionsRequestStarted(Event[EventTypes.ChatCompletionsRequestStarted]):
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestStarted] = (
|
||||
EventTypes.ChatCompletionsRequestStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class ChatCompletionsRequestCompleted(
|
||||
Event[Literal["ChatCompletionsRequestCompleted"]]
|
||||
Event[EventTypes.ChatCompletionsRequestCompleted]
|
||||
):
|
||||
event_type = "ChatCompletionsRequestCompleted"
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestCompleted] = (
|
||||
EventTypes.ChatCompletionsRequestCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class ChatCompletionsRequestFailed(Event[Literal["ChatCompletionsRequestFailed"]]):
|
||||
event_type = "ChatCompletionsRequestFailed"
|
||||
class ChatCompletionsRequestFailed(Event[EventTypes.ChatCompletionsRequestFailed]):
|
||||
event_type: Literal[EventTypes.ChatCompletionsRequestFailed] = (
|
||||
EventTypes.ChatCompletionsRequestFailed
|
||||
)
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
error_message: str
|
||||
|
||||
|
||||
# Inference saga ------------------------------------------------------------------
|
||||
class InferenceSagaStarted(Event[Literal["InferenceSagaStarted"]]):
|
||||
event_type = "InferenceSagaStarted"
|
||||
class InferenceSagaStarted(Event[EventTypes.InferenceSagaStarted]):
|
||||
event_type: Literal[EventTypes.InferenceSagaStarted] = (
|
||||
EventTypes.InferenceSagaStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferencePrepareStarted(Event[Literal["InferencePrepareStarted"]]):
|
||||
event_type = "InferencePrepareStarted"
|
||||
class InferencePrepareStarted(Event[EventTypes.InferencePrepareStarted]):
|
||||
event_type: Literal[EventTypes.InferencePrepareStarted] = (
|
||||
EventTypes.InferencePrepareStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
@@ -82,8 +92,10 @@ class InferencePrepareStarted(Event[Literal["InferencePrepareStarted"]]):
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferencePrepareCompleted(Event[Literal["InferencePrepareCompleted"]]):
|
||||
event_type = "InferencePrepareCompleted"
|
||||
class InferencePrepareCompleted(Event[EventTypes.InferencePrepareCompleted]):
|
||||
event_type: Literal[EventTypes.InferencePrepareCompleted] = (
|
||||
EventTypes.InferencePrepareCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
@@ -91,8 +103,10 @@ class InferencePrepareCompleted(Event[Literal["InferencePrepareCompleted"]]):
|
||||
shard: Shard
|
||||
|
||||
|
||||
class InferenceTriggerStarted(Event[Literal["InferenceTriggerStarted"]]):
|
||||
event_type = "InferenceTriggerStarted"
|
||||
class InferenceTriggerStarted(Event[EventTypes.InferenceTriggerStarted]):
|
||||
event_type: Literal[EventTypes.InferenceTriggerStarted] = (
|
||||
EventTypes.InferenceTriggerStarted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
@@ -101,8 +115,10 @@ class InferenceTriggerStarted(Event[Literal["InferenceTriggerStarted"]]):
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class InferenceTriggerCompleted(Event[Literal["InferenceTriggerCompleted"]]):
|
||||
event_type = "InferenceTriggerCompleted"
|
||||
class InferenceTriggerCompleted(Event[EventTypes.InferenceTriggerCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceTriggerCompleted] = (
|
||||
EventTypes.InferenceTriggerCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
target_node_id: NodeId
|
||||
@@ -110,52 +126,60 @@ class InferenceTriggerCompleted(Event[Literal["InferenceTriggerCompleted"]]):
|
||||
shard: Shard
|
||||
|
||||
|
||||
class InferenceCompleted(Event[Literal["InferenceCompleted"]]):
|
||||
event_type = "InferenceCompleted"
|
||||
class InferenceCompleted(Event[EventTypes.InferenceCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceCompleted] = EventTypes.InferenceCompleted
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class InferenceSagaCompleted(Event[Literal["InferenceSagaCompleted"]]):
|
||||
event_type = "InferenceSagaCompleted"
|
||||
class InferenceSagaCompleted(Event[EventTypes.InferenceSagaCompleted]):
|
||||
event_type: Literal[EventTypes.InferenceSagaCompleted] = (
|
||||
EventTypes.InferenceSagaCompleted
|
||||
)
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
# Instance setup saga ------------------------------------------------------------
|
||||
class InstanceSetupSagaStarted(Event[Literal["InstanceSetupSagaStarted"]]):
|
||||
event_type = "InstanceSetupSagaStarted"
|
||||
class InstanceSetupSagaStarted(Event[EventTypes.InstanceSetupSagaStarted]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaStarted] = (
|
||||
EventTypes.InstanceSetupSagaStarted
|
||||
)
|
||||
instance_id: str
|
||||
model_id: ModelId
|
||||
plan: InstanceComputePlan
|
||||
|
||||
|
||||
class InstanceSetupSagaCompleted(Event[Literal["InstanceSetupSagaCompleted"]]):
|
||||
event_type = "InstanceSetupSagaCompleted"
|
||||
class InstanceSetupSagaCompleted(Event[EventTypes.InstanceSetupSagaCompleted]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaCompleted] = (
|
||||
EventTypes.InstanceSetupSagaCompleted
|
||||
)
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class InstanceSetupSagaFailed(Event[Literal["InstanceSetupSagaFailed"]]):
|
||||
event_type = "InstanceSetupSagaFailed"
|
||||
class InstanceSetupSagaFailed(Event[EventTypes.InstanceSetupSagaFailed]):
|
||||
event_type: Literal[EventTypes.InstanceSetupSagaFailed] = (
|
||||
EventTypes.InstanceSetupSagaFailed
|
||||
)
|
||||
instance_id: InstanceId
|
||||
model_id: ModelId
|
||||
reason: str
|
||||
|
||||
|
||||
# Shard lifecycle -----------------------------------------------------------------
|
||||
class ShardAssigned(Event[Literal["ShardAssigned"]]):
|
||||
event_type = "ShardAssigned"
|
||||
class ShardAssigned(Event[EventTypes.ShardAssigned]):
|
||||
event_type: Literal[EventTypes.ShardAssigned] = EventTypes.ShardAssigned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardAssignFailed(Event[Literal["ShardAssignFailed"]]):
|
||||
event_type = "ShardAssignFailed"
|
||||
class ShardAssignFailed(Event[EventTypes.ShardAssignFailed]):
|
||||
event_type: Literal[EventTypes.ShardAssignFailed] = EventTypes.ShardAssignFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -163,8 +187,8 @@ class ShardAssignFailed(Event[Literal["ShardAssignFailed"]]):
|
||||
reason: str # e.g. "not enough memory"
|
||||
|
||||
|
||||
class ShardUnassigned(Event[Literal["ShardUnassigned"]]):
|
||||
event_type = "ShardUnassigned"
|
||||
class ShardUnassigned(Event[EventTypes.ShardUnassigned]):
|
||||
event_type: Literal[EventTypes.ShardUnassigned] = EventTypes.ShardUnassigned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -172,8 +196,8 @@ class ShardUnassigned(Event[Literal["ShardUnassigned"]]):
|
||||
reason: str # e.g. "instance did not receive request for 5 mins"
|
||||
|
||||
|
||||
class ShardUnassignFailed(Event[Literal["ShardUnassignFailed"]]):
|
||||
event_type = "ShardUnassignFailed"
|
||||
class ShardUnassignFailed(Event[EventTypes.ShardUnassignFailed]):
|
||||
event_type: Literal[EventTypes.ShardUnassignFailed] = EventTypes.ShardUnassignFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -181,16 +205,16 @@ class ShardUnassignFailed(Event[Literal["ShardUnassignFailed"]]):
|
||||
reason: str # e.g. "process refused to quit"
|
||||
|
||||
|
||||
class ShardKilled(Event[Literal["ShardKilled"]]):
|
||||
event_type = "ShardKilled"
|
||||
class ShardKilled(Event[EventTypes.ShardKilled]):
|
||||
event_type: Literal[EventTypes.ShardKilled] = EventTypes.ShardKilled
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardDied(Event[Literal["ShardDied"]]):
|
||||
event_type = "ShardDied"
|
||||
class ShardDied(Event[EventTypes.ShardDied]):
|
||||
event_type: Literal[EventTypes.ShardDied] = EventTypes.ShardDied
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -200,16 +224,16 @@ class ShardDied(Event[Literal["ShardDied"]]):
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
class ShardSpawned(Event[Literal["ShardSpawned"]]):
|
||||
event_type = "ShardSpawned"
|
||||
class ShardSpawned(Event[EventTypes.ShardSpawned]):
|
||||
event_type: Literal[EventTypes.ShardSpawned] = EventTypes.ShardSpawned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
hosts: List[str]
|
||||
|
||||
|
||||
class ShardSpawnedFailed(Event[Literal["ShardSpawnedFailed"]]):
|
||||
event_type = "ShardSpawnedFailed"
|
||||
class ShardSpawnedFailed(Event[EventTypes.ShardSpawnedFailed]):
|
||||
event_type: Literal[EventTypes.ShardSpawnedFailed] = EventTypes.ShardSpawnedFailed
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -217,8 +241,8 @@ class ShardSpawnedFailed(Event[Literal["ShardSpawnedFailed"]]):
|
||||
reason: str # e.g. "not enough memory"
|
||||
|
||||
|
||||
class ShardDespawned(Event[Literal["ShardDespawned"]]):
|
||||
event_type = "ShardDespawned"
|
||||
class ShardDespawned(Event[EventTypes.ShardDespawned]):
|
||||
event_type: Literal[EventTypes.ShardDespawned] = EventTypes.ShardDespawned
|
||||
instance_id: InstanceId
|
||||
shard: Shard
|
||||
target_node_id: NodeId
|
||||
@@ -226,8 +250,8 @@ class ShardDespawned(Event[Literal["ShardDespawned"]]):
|
||||
|
||||
|
||||
# Node connectivity --------------------------------------------------------------
|
||||
class NodeConnected(Event[Literal["NodeConnected"]]):
|
||||
event_type = "NodeConnected"
|
||||
class NodeConnected(Event[EventTypes.NodeConnected]):
|
||||
event_type: Literal[EventTypes.NodeConnected] = EventTypes.NodeConnected
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
multiaddr: str
|
||||
@@ -236,27 +260,29 @@ class NodeConnected(Event[Literal["NodeConnected"]]):
|
||||
remote_ip: str
|
||||
|
||||
|
||||
class NodeConnectionProfiled(Event[Literal["NodeConnectionProfiled"]]):
|
||||
event_type = "NodeConnectionProfiled"
|
||||
class NodeConnectionProfiled(Event[EventTypes.NodeConnectionProfiled]):
|
||||
event_type: Literal[EventTypes.NodeConnectionProfiled] = (
|
||||
EventTypes.NodeConnectionProfiled
|
||||
)
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
latency_ms: int
|
||||
bandwidth_bytes_per_second: int
|
||||
|
||||
|
||||
class NodeDisconnected(Event[Literal["NodeDisconnected"]]):
|
||||
event_type = "NodeDisconnected"
|
||||
class NodeDisconnected(Event[EventTypes.NodeDisconnected]):
|
||||
event_type: Literal[EventTypes.NodeDisconnected] = EventTypes.NodeDisconnected
|
||||
remote_node_id: NodeId
|
||||
connection_id: str
|
||||
|
||||
|
||||
class NodeStarted(Event[Literal["NodeStarted"]]):
|
||||
event_type = "NodeStarted"
|
||||
class NodeStarted(Event[EventTypes.NodeStarted]):
|
||||
event_type: Literal[EventTypes.NodeStarted] = EventTypes.NodeStarted
|
||||
|
||||
|
||||
# Device metrics -----------------------------------------------------------------
|
||||
class DeviceRegistered(Event[Literal["DeviceRegistered"]]):
|
||||
event_type = "DeviceRegistered"
|
||||
class DeviceRegistered(Event[EventTypes.DeviceRegistered]):
|
||||
event_type: Literal[EventTypes.DeviceRegistered] = EventTypes.DeviceRegistered
|
||||
device_id: str
|
||||
device_model: str
|
||||
device_type: str
|
||||
@@ -264,8 +290,8 @@ class DeviceRegistered(Event[Literal["DeviceRegistered"]]):
|
||||
available_memory_bytes: int
|
||||
|
||||
|
||||
class DeviceProfiled(Event[Literal["DeviceProfiled"]]):
|
||||
event_type = "DeviceProfiled"
|
||||
class DeviceProfiled(Event[EventTypes.DeviceProfiled]):
|
||||
event_type: Literal[EventTypes.DeviceProfiled] = EventTypes.DeviceProfiled
|
||||
device_id: str
|
||||
total_memory_bytes: int
|
||||
available_memory_bytes: int
|
||||
@@ -273,9 +299,9 @@ class DeviceProfiled(Event[Literal["DeviceProfiled"]]):
|
||||
|
||||
|
||||
# Token streaming ----------------------------------------------------------------
|
||||
class TokenGenerated(Event[Literal["TokenGenerated"]]):
|
||||
class TokenGenerated(Event[EventTypes.TokenGenerated]):
|
||||
# TODO: replace with matt chunk code
|
||||
event_type = "TokenGenerated"
|
||||
event_type: Literal[EventTypes.TokenGenerated] = EventTypes.TokenGenerated
|
||||
request_id: RequestId
|
||||
instance_id: InstanceId
|
||||
hosts: List[str]
|
||||
@@ -285,8 +311,8 @@ class TokenGenerated(Event[Literal["TokenGenerated"]]):
|
||||
|
||||
|
||||
# Repo download progress ----------------------------------------------------------
|
||||
class RepoProgressEvent(Event[Literal["RepoProgressEvent"]]):
|
||||
event_type = "RepoProgressEvent"
|
||||
class RepoProgressEvent(Event[EventTypes.RepoProgressEvent]):
|
||||
event_type: Literal[EventTypes.RepoProgressEvent] = EventTypes.RepoProgressEvent
|
||||
repo_id: str
|
||||
downloaded_bytes: int
|
||||
total_bytes: int
|
||||
@@ -294,11 +320,11 @@ class RepoProgressEvent(Event[Literal["RepoProgressEvent"]]):
|
||||
|
||||
|
||||
# Timers -------------------------------------------------------------------------
|
||||
class TimerScheduled(Event[Literal["TimerScheduled"]]):
|
||||
event_type = "TimerScheduled"
|
||||
class TimerScheduled(Event[EventTypes.TimerScheduled]):
|
||||
event_type: Literal[EventTypes.TimerScheduled] = EventTypes.TimerScheduled
|
||||
timer: Timer
|
||||
|
||||
|
||||
class TimerFired(Event[Literal["TimerFired"]]):
|
||||
event_type = "TimerFired"
|
||||
class TimerFired(Event[EventTypes.TimerFired]):
|
||||
event_type: Literal[EventTypes.TimerFired] = EventTypes.TimerFired
|
||||
timer: Timer
|
||||
|
||||
Reference in New Issue
Block a user