refactor: Use enums

This commit is contained in:
Arbion Halili
2025-06-30 23:45:27 +01:00
parent b758df83cf
commit 53d5d23898
2 changed files with 135 additions and 101 deletions

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import (
Annotated,
Callable,
Generic,
Literal,
Protocol,
Sequence,
Tuple,
@@ -18,40 +18,42 @@ _EventId = Annotated[UUID, UuidVersion(4)]
EventId = type("EventId", (UUID,), {})
EventIdParser: TypeAdapter[EventId] = TypeAdapter(_EventId)
EventTypes = Literal[
"ChatCompletionsRequestStarted",
"ChatCompletionsRequestCompleted",
"ChatCompletionsRequestFailed",
"InferenceSagaStarted",
"InferencePrepareStarted",
"InferencePrepareCompleted",
"InferenceTriggerStarted",
"InferenceTriggerCompleted",
"InferenceCompleted",
"InferenceSagaCompleted",
"InstanceSetupSagaStarted",
"InstanceSetupSagaCompleted",
"InstanceSetupSagaFailed",
"ShardAssigned",
"ShardAssignFailed",
"ShardUnassigned",
"ShardUnassignFailed",
"ShardKilled",
"ShardDied",
"ShardSpawned",
"ShardSpawnedFailed",
"ShardDespawned",
"NodeConnected",
"NodeConnectionProfiled",
"NodeDisconnected",
"NodeStarted",
"DeviceRegistered",
"DeviceProfiled",
"TokenGenerated",
"RepoProgressEvent",
"TimerScheduled",
"TimerFired",
]
class EventTypes(str, Enum):
ChatCompletionsRequestStarted = "ChatCompletionsRequestStarted"
ChatCompletionsRequestCompleted = "ChatCompletionsRequestCompleted"
ChatCompletionsRequestFailed = "ChatCompletionsRequestFailed"
InferenceSagaStarted = "InferenceSagaStarted"
InferencePrepareStarted = "InferencePrepareStarted"
InferencePrepareCompleted = "InferencePrepareCompleted"
InferenceTriggerStarted = "InferenceTriggerStarted"
InferenceTriggerCompleted = "InferenceTriggerCompleted"
InferenceCompleted = "InferenceCompleted"
InferenceSagaCompleted = "InferenceSagaCompleted"
InstanceSetupSagaStarted = "InstanceSetupSagaStarted"
InstanceSetupSagaCompleted = "InstanceSetupSagaCompleted"
InstanceSetupSagaFailed = "InstanceSetupSagaFailed"
ShardAssigned = "ShardAssigned"
ShardAssignFailed = "ShardAssignFailed"
ShardUnassigned = "ShardUnassigned"
ShardUnassignFailed = "ShardUnassignFailed"
ShardKilled = "ShardKilled"
ShardDied = "ShardDied"
ShardSpawned = "ShardSpawned"
ShardSpawnedFailed = "ShardSpawnedFailed"
ShardDespawned = "ShardDespawned"
NodeConnected = "NodeConnected"
NodeConnectionProfiled = "NodeConnectionProfiled"
NodeDisconnected = "NodeDisconnected"
NodeStarted = "NodeStarted"
DeviceRegistered = "DeviceRegistered"
DeviceProfiled = "DeviceProfiled"
TokenGenerated = "TokenGenerated"
RepoProgressEvent = "RepoProgressEvent"
TimerScheduled = "TimerScheduled"
TimerFired = "TimerFired"
EventTypeT = TypeVar("EventTypeT", bound=EventTypes)
TEventType = TypeVar("TEventType", bound=EventTypes, covariant=True)
@@ -122,7 +124,13 @@ _CommandId = Annotated[UUID, UuidVersion(4)]
CommandId = type("CommandId", (UUID,), {})
CommandIdParser: TypeAdapter[CommandId] = TypeAdapter(_CommandId)
CommandTypes = Literal["create", "update", "delete"]
class CommandTypes(str, Enum):
Create = "Create"
Update = "Update"
Delete = "Delete"
CommandTypeT = TypeVar("CommandTypeT", bound=EventTypes)
TCommandType = TypeVar("TCommandType", bound=EventTypes, covariant=True)

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from pydantic import BaseModel, TypeAdapter, UuidVersion
from shared.openai import FinishReason, chat
from shared.types.event_sourcing import Event
from shared.types.event_sourcing import Event, EventTypes
from shared.types.model import ModelId
_NodeId = Annotated[UUID, UuidVersion(4)]
@@ -41,39 +41,49 @@ class Timer(BaseModel):
# Chat completions ----------------------------------------------------------------
class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted"]]):
event_type = "ChatCompletionsRequestStarted"
class ChatCompletionsRequestStarted(Event[EventTypes.ChatCompletionsRequestStarted]):
event_type: Literal[EventTypes.ChatCompletionsRequestStarted] = (
EventTypes.ChatCompletionsRequestStarted
)
request_id: RequestId
model_id: ModelId
request: chat.completion_create_params.CompletionCreateParams
class ChatCompletionsRequestCompleted(
Event[Literal["ChatCompletionsRequestCompleted"]]
Event[EventTypes.ChatCompletionsRequestCompleted]
):
event_type = "ChatCompletionsRequestCompleted"
event_type: Literal[EventTypes.ChatCompletionsRequestCompleted] = (
EventTypes.ChatCompletionsRequestCompleted
)
request_id: RequestId
model_id: ModelId
class ChatCompletionsRequestFailed(Event[Literal["ChatCompletionsRequestFailed"]]):
event_type = "ChatCompletionsRequestFailed"
class ChatCompletionsRequestFailed(Event[EventTypes.ChatCompletionsRequestFailed]):
event_type: Literal[EventTypes.ChatCompletionsRequestFailed] = (
EventTypes.ChatCompletionsRequestFailed
)
request_id: RequestId
model_id: ModelId
error_message: str
# Inference saga ------------------------------------------------------------------
class InferenceSagaStarted(Event[Literal["InferenceSagaStarted"]]):
event_type = "InferenceSagaStarted"
class InferenceSagaStarted(Event[EventTypes.InferenceSagaStarted]):
event_type: Literal[EventTypes.InferenceSagaStarted] = (
EventTypes.InferenceSagaStarted
)
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
request: chat.completion_create_params.CompletionCreateParams
class InferencePrepareStarted(Event[Literal["InferencePrepareStarted"]]):
event_type = "InferencePrepareStarted"
class InferencePrepareStarted(Event[EventTypes.InferencePrepareStarted]):
event_type: Literal[EventTypes.InferencePrepareStarted] = (
EventTypes.InferencePrepareStarted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
@@ -82,8 +92,10 @@ class InferencePrepareStarted(Event[Literal["InferencePrepareStarted"]]):
request: chat.completion_create_params.CompletionCreateParams
class InferencePrepareCompleted(Event[Literal["InferencePrepareCompleted"]]):
event_type = "InferencePrepareCompleted"
class InferencePrepareCompleted(Event[EventTypes.InferencePrepareCompleted]):
event_type: Literal[EventTypes.InferencePrepareCompleted] = (
EventTypes.InferencePrepareCompleted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
@@ -91,8 +103,10 @@ class InferencePrepareCompleted(Event[Literal["InferencePrepareCompleted"]]):
shard: Shard
class InferenceTriggerStarted(Event[Literal["InferenceTriggerStarted"]]):
event_type = "InferenceTriggerStarted"
class InferenceTriggerStarted(Event[EventTypes.InferenceTriggerStarted]):
event_type: Literal[EventTypes.InferenceTriggerStarted] = (
EventTypes.InferenceTriggerStarted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
@@ -101,8 +115,10 @@ class InferenceTriggerStarted(Event[Literal["InferenceTriggerStarted"]]):
request: chat.completion_create_params.CompletionCreateParams
class InferenceTriggerCompleted(Event[Literal["InferenceTriggerCompleted"]]):
event_type = "InferenceTriggerCompleted"
class InferenceTriggerCompleted(Event[EventTypes.InferenceTriggerCompleted]):
event_type: Literal[EventTypes.InferenceTriggerCompleted] = (
EventTypes.InferenceTriggerCompleted
)
request_id: RequestId
instance_id: InstanceId
target_node_id: NodeId
@@ -110,52 +126,60 @@ class InferenceTriggerCompleted(Event[Literal["InferenceTriggerCompleted"]]):
shard: Shard
class InferenceCompleted(Event[Literal["InferenceCompleted"]]):
event_type = "InferenceCompleted"
class InferenceCompleted(Event[EventTypes.InferenceCompleted]):
event_type: Literal[EventTypes.InferenceCompleted] = EventTypes.InferenceCompleted
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
class InferenceSagaCompleted(Event[Literal["InferenceSagaCompleted"]]):
event_type = "InferenceSagaCompleted"
class InferenceSagaCompleted(Event[EventTypes.InferenceSagaCompleted]):
event_type: Literal[EventTypes.InferenceSagaCompleted] = (
EventTypes.InferenceSagaCompleted
)
request_id: RequestId
instance_id: InstanceId
model_id: ModelId
# Instance setup saga ------------------------------------------------------------
class InstanceSetupSagaStarted(Event[Literal["InstanceSetupSagaStarted"]]):
event_type = "InstanceSetupSagaStarted"
class InstanceSetupSagaStarted(Event[EventTypes.InstanceSetupSagaStarted]):
event_type: Literal[EventTypes.InstanceSetupSagaStarted] = (
EventTypes.InstanceSetupSagaStarted
)
instance_id: str
model_id: ModelId
plan: InstanceComputePlan
class InstanceSetupSagaCompleted(Event[Literal["InstanceSetupSagaCompleted"]]):
event_type = "InstanceSetupSagaCompleted"
class InstanceSetupSagaCompleted(Event[EventTypes.InstanceSetupSagaCompleted]):
event_type: Literal[EventTypes.InstanceSetupSagaCompleted] = (
EventTypes.InstanceSetupSagaCompleted
)
instance_id: InstanceId
model_id: ModelId
class InstanceSetupSagaFailed(Event[Literal["InstanceSetupSagaFailed"]]):
event_type = "InstanceSetupSagaFailed"
class InstanceSetupSagaFailed(Event[EventTypes.InstanceSetupSagaFailed]):
event_type: Literal[EventTypes.InstanceSetupSagaFailed] = (
EventTypes.InstanceSetupSagaFailed
)
instance_id: InstanceId
model_id: ModelId
reason: str
# Shard lifecycle -----------------------------------------------------------------
class ShardAssigned(Event[Literal["ShardAssigned"]]):
event_type = "ShardAssigned"
class ShardAssigned(Event[EventTypes.ShardAssigned]):
event_type: Literal[EventTypes.ShardAssigned] = EventTypes.ShardAssigned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardAssignFailed(Event[Literal["ShardAssignFailed"]]):
event_type = "ShardAssignFailed"
class ShardAssignFailed(Event[EventTypes.ShardAssignFailed]):
event_type: Literal[EventTypes.ShardAssignFailed] = EventTypes.ShardAssignFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -163,8 +187,8 @@ class ShardAssignFailed(Event[Literal["ShardAssignFailed"]]):
reason: str # e.g. "not enough memory"
class ShardUnassigned(Event[Literal["ShardUnassigned"]]):
event_type = "ShardUnassigned"
class ShardUnassigned(Event[EventTypes.ShardUnassigned]):
event_type: Literal[EventTypes.ShardUnassigned] = EventTypes.ShardUnassigned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -172,8 +196,8 @@ class ShardUnassigned(Event[Literal["ShardUnassigned"]]):
reason: str # e.g. "instance did not receive request for 5 mins"
class ShardUnassignFailed(Event[Literal["ShardUnassignFailed"]]):
event_type = "ShardUnassignFailed"
class ShardUnassignFailed(Event[EventTypes.ShardUnassignFailed]):
event_type: Literal[EventTypes.ShardUnassignFailed] = EventTypes.ShardUnassignFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -181,16 +205,16 @@ class ShardUnassignFailed(Event[Literal["ShardUnassignFailed"]]):
reason: str # e.g. "process refused to quit"
class ShardKilled(Event[Literal["ShardKilled"]]):
event_type = "ShardKilled"
class ShardKilled(Event[EventTypes.ShardKilled]):
event_type: Literal[EventTypes.ShardKilled] = EventTypes.ShardKilled
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardDied(Event[Literal["ShardDied"]]):
event_type = "ShardDied"
class ShardDied(Event[EventTypes.ShardDied]):
event_type: Literal[EventTypes.ShardDied] = EventTypes.ShardDied
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -200,16 +224,16 @@ class ShardDied(Event[Literal["ShardDied"]]):
traceback: Optional[str] = None
class ShardSpawned(Event[Literal["ShardSpawned"]]):
event_type = "ShardSpawned"
class ShardSpawned(Event[EventTypes.ShardSpawned]):
event_type: Literal[EventTypes.ShardSpawned] = EventTypes.ShardSpawned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
hosts: List[str]
class ShardSpawnedFailed(Event[Literal["ShardSpawnedFailed"]]):
event_type = "ShardSpawnedFailed"
class ShardSpawnedFailed(Event[EventTypes.ShardSpawnedFailed]):
event_type: Literal[EventTypes.ShardSpawnedFailed] = EventTypes.ShardSpawnedFailed
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -217,8 +241,8 @@ class ShardSpawnedFailed(Event[Literal["ShardSpawnedFailed"]]):
reason: str # e.g. "not enough memory"
class ShardDespawned(Event[Literal["ShardDespawned"]]):
event_type = "ShardDespawned"
class ShardDespawned(Event[EventTypes.ShardDespawned]):
event_type: Literal[EventTypes.ShardDespawned] = EventTypes.ShardDespawned
instance_id: InstanceId
shard: Shard
target_node_id: NodeId
@@ -226,8 +250,8 @@ class ShardDespawned(Event[Literal["ShardDespawned"]]):
# Node connectivity --------------------------------------------------------------
class NodeConnected(Event[Literal["NodeConnected"]]):
event_type = "NodeConnected"
class NodeConnected(Event[EventTypes.NodeConnected]):
event_type: Literal[EventTypes.NodeConnected] = EventTypes.NodeConnected
remote_node_id: NodeId
connection_id: str
multiaddr: str
@@ -236,27 +260,29 @@ class NodeConnected(Event[Literal["NodeConnected"]]):
remote_ip: str
class NodeConnectionProfiled(Event[Literal["NodeConnectionProfiled"]]):
event_type = "NodeConnectionProfiled"
class NodeConnectionProfiled(Event[EventTypes.NodeConnectionProfiled]):
event_type: Literal[EventTypes.NodeConnectionProfiled] = (
EventTypes.NodeConnectionProfiled
)
remote_node_id: NodeId
connection_id: str
latency_ms: int
bandwidth_bytes_per_second: int
class NodeDisconnected(Event[Literal["NodeDisconnected"]]):
event_type = "NodeDisconnected"
class NodeDisconnected(Event[EventTypes.NodeDisconnected]):
event_type: Literal[EventTypes.NodeDisconnected] = EventTypes.NodeDisconnected
remote_node_id: NodeId
connection_id: str
class NodeStarted(Event[Literal["NodeStarted"]]):
event_type = "NodeStarted"
class NodeStarted(Event[EventTypes.NodeStarted]):
event_type: Literal[EventTypes.NodeStarted] = EventTypes.NodeStarted
# Device metrics -----------------------------------------------------------------
class DeviceRegistered(Event[Literal["DeviceRegistered"]]):
event_type = "DeviceRegistered"
class DeviceRegistered(Event[EventTypes.DeviceRegistered]):
event_type: Literal[EventTypes.DeviceRegistered] = EventTypes.DeviceRegistered
device_id: str
device_model: str
device_type: str
@@ -264,8 +290,8 @@ class DeviceRegistered(Event[Literal["DeviceRegistered"]]):
available_memory_bytes: int
class DeviceProfiled(Event[Literal["DeviceProfiled"]]):
event_type = "DeviceProfiled"
class DeviceProfiled(Event[EventTypes.DeviceProfiled]):
event_type: Literal[EventTypes.DeviceProfiled] = EventTypes.DeviceProfiled
device_id: str
total_memory_bytes: int
available_memory_bytes: int
@@ -273,9 +299,9 @@ class DeviceProfiled(Event[Literal["DeviceProfiled"]]):
# Token streaming ----------------------------------------------------------------
class TokenGenerated(Event[Literal["TokenGenerated"]]):
class TokenGenerated(Event[EventTypes.TokenGenerated]):
# TODO: replace with matt chunk code
event_type = "TokenGenerated"
event_type: Literal[EventTypes.TokenGenerated] = EventTypes.TokenGenerated
request_id: RequestId
instance_id: InstanceId
hosts: List[str]
@@ -285,8 +311,8 @@ class TokenGenerated(Event[Literal["TokenGenerated"]]):
# Repo download progress ----------------------------------------------------------
class RepoProgressEvent(Event[Literal["RepoProgressEvent"]]):
event_type = "RepoProgressEvent"
class RepoProgressEvent(Event[EventTypes.RepoProgressEvent]):
event_type: Literal[EventTypes.RepoProgressEvent] = EventTypes.RepoProgressEvent
repo_id: str
downloaded_bytes: int
total_bytes: int
@@ -294,11 +320,11 @@ class RepoProgressEvent(Event[Literal["RepoProgressEvent"]]):
# Timers -------------------------------------------------------------------------
class TimerScheduled(Event[Literal["TimerScheduled"]]):
event_type = "TimerScheduled"
class TimerScheduled(Event[EventTypes.TimerScheduled]):
event_type: Literal[EventTypes.TimerScheduled] = EventTypes.TimerScheduled
timer: Timer
class TimerFired(Event[Literal["TimerFired"]]):
event_type = "TimerFired"
class TimerFired(Event[EventTypes.TimerFired]):
event_type: Literal[EventTypes.TimerFired] = EventTypes.TimerFired
timer: Timer