mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Simplify Task type + merge control & data plane types into single type
This commit is contained in:
@@ -15,7 +15,7 @@ from mlx_lm.utils import load_model
|
||||
from pydantic import RootModel
|
||||
|
||||
from engines.mlx.auto_parallel import auto_parallel
|
||||
from shared.types.tasks.common import CompletionCreateParams
|
||||
from shared.types.tasks.common import ChatCompletionTaskParams
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.shards import ShardMeta
|
||||
from worker.runner.communication import runner_print
|
||||
@@ -96,7 +96,7 @@ def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWra
|
||||
async def apply_chat_template(
|
||||
mlx_executor: concurrent.futures.ThreadPoolExecutor,
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: CompletionCreateParams,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
loop: AbstractEventLoop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
from typing import Protocol
|
||||
|
||||
from shared.types.graphs.topology import Topology
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.models.model import ModelInfo
|
||||
from shared.types.models.sources import ModelSource
|
||||
from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.downloads import DownloadProgress
|
||||
from shared.types.worker.instances import Instance
|
||||
|
||||
|
||||
class ControlPlaneAPI(Protocol):
|
||||
def get_control_plane_topology(self) -> ControlPlaneTopology: ...
|
||||
|
||||
def get_data_plane_topology(self) -> DataPlaneTopology: ...
|
||||
class ClusterAPI(Protocol):
|
||||
def get_topology(self) -> Topology: ...
|
||||
|
||||
def list_instances(self) -> list[Instance]: ...
|
||||
|
||||
|
||||
@@ -54,13 +54,12 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
|
||||
|
||||
# What The Master Cares About
|
||||
MasterEventCategories = (
|
||||
Literal[EventCategoryEnum.MutatesControlPlaneState]
|
||||
Literal[EventCategoryEnum.MutatesTopologyState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskSagaState]
|
||||
| Literal[EventCategoryEnum.MutatesRunnerStatus]
|
||||
| Literal[EventCategoryEnum.MutatesInstanceState]
|
||||
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
|
||||
| Literal[EventCategoryEnum.MutatesDataPlaneState]
|
||||
)
|
||||
|
||||
|
||||
@@ -119,13 +118,8 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/topology/control_plane")
|
||||
def get_control_plane_topology():
|
||||
return {"message": "Hello, World!"}
|
||||
|
||||
|
||||
@app.get("/topology/data_plane")
|
||||
def get_data_plane_topology():
|
||||
@app.get("/topology")
|
||||
def get_topology():
|
||||
return {"message": "Hello, World!"}
|
||||
|
||||
|
||||
|
||||
23
master/placement.py
Normal file
23
master/placement.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from queue import Queue
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
from shared.types.events.common import BaseEvent, EventCategory
|
||||
from shared.types.graphs.topology import Topology
|
||||
from shared.types.states.master import CachePolicy, CachePolicyType
|
||||
from shared.types.tasks.common import Task
|
||||
from shared.types.worker.instances import InstanceId, InstanceParams
|
||||
|
||||
|
||||
def get_instance_placement(
|
||||
inbox: Queue[Task],
|
||||
outbox: Queue[Task],
|
||||
topology: Topology,
|
||||
current_instances: Mapping[InstanceId, InstanceParams],
|
||||
cache_policy: CachePolicy[CachePolicyType],
|
||||
) -> Mapping[InstanceId, InstanceParams]: ...
|
||||
|
||||
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, InstanceParams],
|
||||
target_instances: Mapping[InstanceId, InstanceParams],
|
||||
) -> Sequence[BaseEvent[EventCategory]]: ...
|
||||
@@ -108,11 +108,8 @@ class AsyncStateManagerMapping(TypedDict):
|
||||
MutatesTaskSagaState: AsyncStateManager[
|
||||
Literal[EventCategoryEnum.MutatesTaskSagaState]
|
||||
]
|
||||
MutatesControlPlaneState: AsyncStateManager[
|
||||
Literal[EventCategoryEnum.MutatesControlPlaneState]
|
||||
]
|
||||
MutatesDataPlaneState: AsyncStateManager[
|
||||
Literal[EventCategoryEnum.MutatesDataPlaneState]
|
||||
MutatesTopologyState: AsyncStateManager[
|
||||
Literal[EventCategoryEnum.MutatesTopologyState]
|
||||
]
|
||||
MutatesRunnerStatus: AsyncStateManager[
|
||||
Literal[EventCategoryEnum.MutatesRunnerStatus]
|
||||
|
||||
@@ -7,8 +7,7 @@ from shared.types.events.common import EventCategoryEnum, State
|
||||
class SyncStateManagerMapping(TypedDict):
|
||||
MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]]
|
||||
MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]]
|
||||
MutatesControlPlaneState: State[Literal[EventCategoryEnum.MutatesControlPlaneState]]
|
||||
MutatesDataPlaneState: State[Literal[EventCategoryEnum.MutatesDataPlaneState]]
|
||||
MutatesTopologyState: State[Literal[EventCategoryEnum.MutatesTopologyState]]
|
||||
MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]]
|
||||
MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]]
|
||||
MutatesNodePerformanceState: State[
|
||||
|
||||
@@ -33,7 +33,7 @@ class _EdgeWrapper[EdgeTypeT, EdgeIdT]:
|
||||
edge_data: EdgeData[EdgeTypeT]
|
||||
|
||||
|
||||
class NetworkXGraph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
|
||||
class Graph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
|
||||
edge_base: TypeAdapter[EdgeTypeT]
|
||||
vertex_base: TypeAdapter[VertexTypeT]
|
||||
|
||||
@@ -2,10 +2,10 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.tasks.common import CompletionCreateParams, TaskId
|
||||
from shared.types.tasks.common import ChatCompletionTaskParams, TaskId
|
||||
|
||||
|
||||
class ChatTask(BaseModel):
|
||||
task_id: TaskId
|
||||
kind: Literal["chat"] = "chat"
|
||||
task_data: CompletionCreateParams
|
||||
task_data: ChatCompletionTaskParams
|
||||
|
||||
@@ -60,13 +60,10 @@ class NodePerformanceEventTypes(str, Enum):
|
||||
NodePerformanceMeasured = "NodePerformanceMeasured"
|
||||
|
||||
|
||||
class DataPlaneEventTypes(str, Enum):
|
||||
DataPlaneEdgeCreated = "DataPlaneEdgeCreated"
|
||||
DataPlaneEdgeReplacedAtomically = "DataPlaneEdgeReplacedAtomically"
|
||||
DataPlaneEdgeDeleted = "DataPlaneEdgeDeleted"
|
||||
|
||||
|
||||
class ControlPlaneEventTypes(str, Enum):
|
||||
class TopologyEventTypes(str, Enum):
|
||||
TopologyEdgeCreated = "TopologyEdgeCreated"
|
||||
TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically"
|
||||
TopologyEdgeDeleted = "TopologyEdgeDeleted"
|
||||
WorkerConnected = "WorkerConnected"
|
||||
WorkerStatusUpdated = "WorkerStatusUpdated"
|
||||
WorkerDisconnected = "WorkerDisconnected"
|
||||
@@ -84,8 +81,7 @@ EVENT_TYPE_ENUMS = [
|
||||
InstanceEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
NodePerformanceEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
ControlPlaneEventTypes,
|
||||
TopologyEventTypes,
|
||||
TimerEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
]
|
||||
@@ -98,8 +94,7 @@ EventTypes = (
|
||||
| InstanceEventTypes
|
||||
| RunnerStatusEventTypes
|
||||
| NodePerformanceEventTypes
|
||||
| ControlPlaneEventTypes
|
||||
| DataPlaneEventTypes
|
||||
| TopologyEventTypes
|
||||
| TimerEventTypes
|
||||
| TaskSagaEventTypes
|
||||
)
|
||||
@@ -114,18 +109,17 @@ class EventCategoryEnum(StrEnum):
|
||||
MutatesRunnerStatus = "MutatesRunnerStatus"
|
||||
MutatesInstanceState = "MutatesInstanceState"
|
||||
MutatesNodePerformanceState = "MutatesNodePerformanceState"
|
||||
MutatesControlPlaneState = "MutatesControlPlaneState"
|
||||
MutatesDataPlaneState = "MutatesDataPlaneState"
|
||||
MutatesTopologyState = "MutatesTopologyState"
|
||||
|
||||
|
||||
EventCategory = (
|
||||
Literal[EventCategoryEnum.MutatesControlPlaneState]
|
||||
Literal[EventCategoryEnum.MutatesTopologyState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskSagaState]
|
||||
| Literal[EventCategoryEnum.MutatesRunnerStatus]
|
||||
| Literal[EventCategoryEnum.MutatesInstanceState]
|
||||
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
|
||||
| Literal[EventCategoryEnum.MutatesDataPlaneState]
|
||||
| Literal[EventCategoryEnum.MutatesTopologyState]
|
||||
)
|
||||
|
||||
EventCategories = FrozenSet[EventCategory]
|
||||
|
||||
@@ -6,8 +6,6 @@ from shared.types.common import NodeId
|
||||
from shared.types.events.chunks import GenerationChunk
|
||||
from shared.types.events.common import (
|
||||
BaseEvent,
|
||||
ControlPlaneEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
EventCategoryEnum,
|
||||
InstanceEventTypes,
|
||||
NodePerformanceEventTypes,
|
||||
@@ -15,33 +13,23 @@ from shared.types.events.common import (
|
||||
StreamingEventTypes,
|
||||
TaskEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
TopologyEventTypes,
|
||||
)
|
||||
from shared.types.networking.control_plane import (
|
||||
ControlPlaneEdgeId,
|
||||
ControlPlaneEdgeType,
|
||||
)
|
||||
from shared.types.networking.data_plane import (
|
||||
DataPlaneEdge,
|
||||
DataPlaneEdgeId,
|
||||
DataPlaneEdgeProfile,
|
||||
from shared.types.graphs.topology import (
|
||||
TopologyEdge,
|
||||
TopologyEdgeId,
|
||||
TopologyEdgeProfile,
|
||||
TopologyNode,
|
||||
)
|
||||
from shared.types.profiling.common import NodePerformanceProfile
|
||||
from shared.types.tasks.common import (
|
||||
BaseTaskData,
|
||||
TaskId,
|
||||
TaskState,
|
||||
TaskStatusOtherType,
|
||||
TaskStatusType,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.tasks.common import Task, TaskId, TaskStatus
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
from shared.types.worker.instances import InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
|
||||
TaskEvent = BaseEvent[EventCategoryEnum.MutatesTaskState]
|
||||
InstanceEvent = BaseEvent[EventCategoryEnum.MutatesInstanceState]
|
||||
ControlPlaneEvent = BaseEvent[EventCategoryEnum.MutatesControlPlaneState]
|
||||
DataPlaneEvent = BaseEvent[EventCategoryEnum.MutatesDataPlaneState]
|
||||
TopologyEvent = BaseEvent[EventCategoryEnum.MutatesTopologyState]
|
||||
NodePerformanceEvent = BaseEvent[EventCategoryEnum.MutatesNodePerformanceState]
|
||||
|
||||
|
||||
@@ -49,9 +37,7 @@ class TaskCreated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEven
|
||||
event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
|
||||
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
|
||||
task_id: TaskId
|
||||
task_data: BaseTaskData[TaskType]
|
||||
task_state: TaskState[Literal[TaskStatusOtherType.Pending], TaskType]
|
||||
on_instance: InstanceId
|
||||
task: Task
|
||||
|
||||
|
||||
# Covers Cancellation Of Task, Non-Cancelled Tasks Perist
|
||||
@@ -64,7 +50,8 @@ class TaskDeleted(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEven
|
||||
class TaskStateUpdated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskStateUpdated]]):
|
||||
event_type: Literal[TaskEventTypes.TaskStateUpdated] = TaskEventTypes.TaskStateUpdated
|
||||
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
|
||||
task_state: TaskState[TaskStatusType, TaskType]
|
||||
task_id: TaskId
|
||||
task_status: TaskStatus
|
||||
|
||||
|
||||
class InstanceCreated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceCreated]]):
|
||||
@@ -130,23 +117,23 @@ class NodePerformanceMeasured(BaseEvent[EventCategoryEnum.MutatesNodePerformance
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerConnected]]):
|
||||
event_type: Literal[ControlPlaneEventTypes.WorkerConnected] = ControlPlaneEventTypes.WorkerConnected
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
|
||||
edge: DataPlaneEdge
|
||||
class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerConnected]]):
|
||||
event_type: Literal[TopologyEventTypes.WorkerConnected] = TopologyEventTypes.WorkerConnected
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
edge: TopologyEdge
|
||||
|
||||
|
||||
class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerStatusUpdated]]):
|
||||
event_type: Literal[ControlPlaneEventTypes.WorkerStatusUpdated] = ControlPlaneEventTypes.WorkerStatusUpdated
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
|
||||
class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerStatusUpdated]]):
|
||||
event_type: Literal[TopologyEventTypes.WorkerStatusUpdated] = TopologyEventTypes.WorkerStatusUpdated
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
node_id: NodeId
|
||||
node_state: NodeStatus
|
||||
|
||||
|
||||
class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerDisconnected]]):
|
||||
event_type: Literal[ControlPlaneEventTypes.WorkerDisconnected] = ControlPlaneEventTypes.WorkerDisconnected
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
|
||||
vertex_id: ControlPlaneEdgeId
|
||||
class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerDisconnected]]):
|
||||
event_type: Literal[TopologyEventTypes.WorkerDisconnected] = TopologyEventTypes.WorkerDisconnected
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
vertex_id: NodeId
|
||||
|
||||
|
||||
class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[StreamingEventTypes.ChunkGenerated]]):
|
||||
@@ -156,23 +143,23 @@ class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[Strea
|
||||
chunk: GenerationChunk
|
||||
|
||||
|
||||
class DataPlaneEdgeCreated(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeCreated]]):
|
||||
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeCreated] = DataPlaneEventTypes.DataPlaneEdgeCreated
|
||||
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
|
||||
vertex: ControlPlaneEdgeType
|
||||
class TopologyEdgeCreated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeCreated]]):
|
||||
event_type: Literal[TopologyEventTypes.TopologyEdgeCreated] = TopologyEventTypes.TopologyEdgeCreated
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
vertex: TopologyNode
|
||||
|
||||
|
||||
class DataPlaneEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically]]):
|
||||
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically] = DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically
|
||||
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
|
||||
edge_id: DataPlaneEdgeId
|
||||
edge_profile: DataPlaneEdgeProfile
|
||||
class TopologyEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically]]):
|
||||
event_type: Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically] = TopologyEventTypes.TopologyEdgeReplacedAtomically
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
edge_id: TopologyEdgeId
|
||||
edge_profile: TopologyEdgeProfile
|
||||
|
||||
|
||||
class DataPlaneEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeDeleted]]):
|
||||
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeDeleted] = DataPlaneEventTypes.DataPlaneEdgeDeleted
|
||||
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
|
||||
edge_id: DataPlaneEdgeId
|
||||
class TopologyEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeDeleted]]):
|
||||
event_type: Literal[TopologyEventTypes.TopologyEdgeDeleted] = TopologyEventTypes.TopologyEdgeDeleted
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
|
||||
edge_id: TopologyEdgeId
|
||||
|
||||
"""
|
||||
TEST_EVENT_CATEGORIES_TYPE = FrozenSet[
|
||||
|
||||
@@ -6,8 +6,6 @@ from pydantic import Field, TypeAdapter
|
||||
from shared.constants import get_error_reporting_message
|
||||
from shared.types.events.common import (
|
||||
BaseEvent,
|
||||
ControlPlaneEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
EventCategories,
|
||||
EventTypes,
|
||||
InstanceEventTypes,
|
||||
@@ -16,12 +14,10 @@ from shared.types.events.common import (
|
||||
StreamingEventTypes,
|
||||
TaskEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
TopologyEventTypes,
|
||||
)
|
||||
from shared.types.events.events import (
|
||||
ChunkGenerated,
|
||||
DataPlaneEdgeCreated,
|
||||
DataPlaneEdgeDeleted,
|
||||
DataPlaneEdgeReplacedAtomically,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceReplacedAtomically,
|
||||
@@ -32,6 +28,9 @@ from shared.types.events.events import (
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStateUpdated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
TopologyEdgeReplacedAtomically,
|
||||
WorkerConnected,
|
||||
WorkerDisconnected,
|
||||
WorkerStatusUpdated,
|
||||
@@ -59,13 +58,13 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
|
||||
InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
|
||||
RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
|
||||
NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
|
||||
ControlPlaneEventTypes.WorkerConnected: WorkerConnected,
|
||||
ControlPlaneEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
|
||||
ControlPlaneEventTypes.WorkerDisconnected: WorkerDisconnected,
|
||||
TopologyEventTypes.WorkerConnected: WorkerConnected,
|
||||
TopologyEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
|
||||
TopologyEventTypes.WorkerDisconnected: WorkerDisconnected,
|
||||
StreamingEventTypes.ChunkGenerated: ChunkGenerated,
|
||||
DataPlaneEventTypes.DataPlaneEdgeCreated: DataPlaneEdgeCreated,
|
||||
DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically: DataPlaneEdgeReplacedAtomically,
|
||||
DataPlaneEventTypes.DataPlaneEdgeDeleted: DataPlaneEdgeDeleted,
|
||||
TopologyEventTypes.TopologyEdgeCreated: TopologyEdgeCreated,
|
||||
TopologyEventTypes.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
|
||||
TopologyEventTypes.TopologyEdgeDeleted: TopologyEdgeDeleted,
|
||||
TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
|
||||
TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
|
||||
}
|
||||
@@ -115,9 +114,9 @@ Event = (
|
||||
| WorkerStatusUpdated
|
||||
| WorkerDisconnected
|
||||
| ChunkGenerated
|
||||
| DataPlaneEdgeCreated
|
||||
| DataPlaneEdgeReplacedAtomically
|
||||
| DataPlaneEdgeDeleted
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeReplacedAtomically
|
||||
| TopologyEdgeDeleted
|
||||
| MLXInferenceSagaPrepare
|
||||
| MLXInferenceSagaStartPrepare
|
||||
)
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
|
||||
from shared.types.profiling.common import NodePerformanceProfile
|
||||
|
||||
|
||||
class ResourceGraph(BaseModel): ...
|
||||
|
||||
|
||||
def get_graph_of_compute_resources(
|
||||
control_plane_topology: ControlPlaneTopology,
|
||||
data_plane_topology: DataPlaneTopology,
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||
) -> ResourceGraph: ...
|
||||
48
shared/types/graphs/topology.py
Normal file
48
shared/types/graphs/topology.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from pydantic import BaseModel, IPvAnyAddress
|
||||
|
||||
from shared.graphs import Graph
|
||||
from shared.types.common import NewUUID, NodeId
|
||||
from shared.types.profiling.common import NodePerformanceProfile
|
||||
|
||||
|
||||
class TopologyEdgeId(NewUUID):
|
||||
pass
|
||||
|
||||
|
||||
class TopologyEdgeProfile(BaseModel):
|
||||
throughput: float
|
||||
latency: float
|
||||
jitter: float
|
||||
|
||||
|
||||
class TopologyEdge(BaseModel):
|
||||
source_ip: IPvAnyAddress
|
||||
sink_ip: IPvAnyAddress
|
||||
edge_profile: TopologyEdgeProfile
|
||||
|
||||
|
||||
class TopologyNode(BaseModel):
|
||||
node_id: NodeId
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
class Topology(
|
||||
Graph[
|
||||
TopologyEdge,
|
||||
TopologyNode,
|
||||
TopologyEdgeId,
|
||||
NodeId,
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class OrphanedPartOfTopology(
|
||||
Graph[
|
||||
TopologyEdge,
|
||||
TopologyNode,
|
||||
TopologyEdgeId,
|
||||
NodeId,
|
||||
]
|
||||
):
|
||||
pass
|
||||
@@ -1,11 +0,0 @@
|
||||
from typing import TypeAlias
|
||||
|
||||
from shared.types.common import NewUUID, NodeId
|
||||
from shared.types.graphs.common import Edge
|
||||
|
||||
|
||||
class ControlPlaneEdgeId(NewUUID):
|
||||
pass
|
||||
|
||||
|
||||
ControlPlaneEdgeType: TypeAlias = Edge[None, ControlPlaneEdgeId, NodeId]
|
||||
@@ -1,68 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal, TypeVar, Union, final
|
||||
|
||||
from pydantic import BaseModel, Field, IPvAnyAddress, TypeAdapter
|
||||
|
||||
from shared.types.common import NewUUID, NodeId
|
||||
from shared.types.graphs.common import Edge
|
||||
|
||||
|
||||
class DataPlaneEdgeId(NewUUID):
|
||||
pass
|
||||
|
||||
|
||||
class AddressingProtocol(str, Enum):
|
||||
IPvAnyAddress = "IPvAnyAddress"
|
||||
|
||||
|
||||
class ApplicationProtocol(str, Enum):
|
||||
MLX = "MLX"
|
||||
|
||||
|
||||
AdP = TypeVar("AdP", bound=AddressingProtocol)
|
||||
ApP = TypeVar("ApP", bound=ApplicationProtocol)
|
||||
|
||||
|
||||
@final
|
||||
class DataPlaneEdgeProfile(BaseModel):
|
||||
throughput: float
|
||||
latency: float
|
||||
jitter: float
|
||||
|
||||
|
||||
class CommonDataPlaneEdgeData(BaseModel):
|
||||
edge_data_transfer_rate: DataPlaneEdgeProfile | None = None
|
||||
|
||||
|
||||
class MlxEdgeMetadata(BaseModel):
|
||||
source_ip: IPvAnyAddress
|
||||
sink_ip: IPvAnyAddress
|
||||
|
||||
|
||||
class BaseDataPlaneEdgeData[AdP: AddressingProtocol, ApP: ApplicationProtocol](
|
||||
BaseModel
|
||||
):
|
||||
addressing_protocol: AdP
|
||||
application_protocol: ApP
|
||||
common_data: CommonDataPlaneEdgeData
|
||||
|
||||
|
||||
class MlxEdge(
|
||||
BaseDataPlaneEdgeData[AddressingProtocol.IPvAnyAddress, ApplicationProtocol.MLX]
|
||||
):
|
||||
addressing_protocol: Literal[AddressingProtocol.IPvAnyAddress] = (
|
||||
AddressingProtocol.IPvAnyAddress
|
||||
)
|
||||
application_protocol: Literal[ApplicationProtocol.MLX] = ApplicationProtocol.MLX
|
||||
mlx_metadata: MlxEdgeMetadata
|
||||
|
||||
|
||||
DataPlaneEdgeData = Union[MlxEdge]
|
||||
|
||||
_DataPlaneEdgeData = Annotated[
|
||||
DataPlaneEdgeData,
|
||||
Field(discriminator="addressing_protocol"),
|
||||
]
|
||||
DataPlaneEdgeAdapter: TypeAdapter[DataPlaneEdgeData] = TypeAdapter(_DataPlaneEdgeData)
|
||||
|
||||
DataPlaneEdge = Edge[DataPlaneEdgeData, DataPlaneEdgeId, NodeId]
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Callable, NewType, Protocol
|
||||
|
||||
from shared.types.networking.control_plane import (
|
||||
ControlPlaneEdgeId,
|
||||
ControlPlaneEdgeType,
|
||||
)
|
||||
|
||||
TopicName = NewType("TopicName", str)
|
||||
|
||||
PubSubMessageHandler = Callable[[TopicName, object], None]
|
||||
NodeConnectedHandler = Callable[
|
||||
[
|
||||
ControlPlaneEdgeId,
|
||||
ControlPlaneEdgeType,
|
||||
],
|
||||
None,
|
||||
]
|
||||
NodeDisconnectedHandler = Callable[[ControlPlaneEdgeId], None]
|
||||
|
||||
|
||||
class DiscoveryService(Protocol):
|
||||
def on_node_connected(self, handler: NodeConnectedHandler) -> None: ...
|
||||
def on_node_disconnected(self, handler: NodeDisconnectedHandler) -> None: ...
|
||||
|
||||
|
||||
class PubSubService(Protocol):
|
||||
def on_message_received(
|
||||
self, topic_name: TopicName, handler: PubSubMessageHandler
|
||||
) -> None: ...
|
||||
@@ -1,45 +0,0 @@
|
||||
from shared.graphs.networkx import NetworkXGraph
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.networking.control_plane import ControlPlaneEdgeId
|
||||
from shared.types.networking.data_plane import (
|
||||
DataPlaneEdgeData,
|
||||
DataPlaneEdgeId,
|
||||
)
|
||||
from shared.types.worker.common import NodeStatus
|
||||
|
||||
|
||||
class DataPlaneTopology(
|
||||
NetworkXGraph[
|
||||
DataPlaneEdgeData,
|
||||
None,
|
||||
DataPlaneEdgeId,
|
||||
NodeId,
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class OrphanedPartOfDataPlaneTopology(
|
||||
NetworkXGraph[
|
||||
DataPlaneEdgeData,
|
||||
None,
|
||||
DataPlaneEdgeId,
|
||||
NodeId,
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class ControlPlaneTopology(NetworkXGraph[None, NodeStatus, ControlPlaneEdgeId, NodeId]):
|
||||
pass
|
||||
|
||||
|
||||
class OrphanedPartOfControlPlaneTopology(
|
||||
NetworkXGraph[
|
||||
None,
|
||||
NodeStatus,
|
||||
ControlPlaneEdgeId,
|
||||
NodeId,
|
||||
]
|
||||
):
|
||||
pass
|
||||
@@ -6,29 +6,17 @@ from typing import Generic, Literal, TypeVar
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import (
|
||||
BaseEvent,
|
||||
EventCategory,
|
||||
EventCategoryEnum,
|
||||
State,
|
||||
)
|
||||
from shared.types.graphs.resource_graph import ResourceGraph
|
||||
from shared.types.networking.data_plane import (
|
||||
DataPlaneEdge,
|
||||
DataPlaneEdgeAdapter,
|
||||
DataPlaneEdgeId,
|
||||
)
|
||||
from shared.types.networking.topology import (
|
||||
ControlPlaneTopology,
|
||||
DataPlaneTopology,
|
||||
OrphanedPartOfControlPlaneTopology,
|
||||
OrphanedPartOfDataPlaneTopology,
|
||||
from shared.types.events.common import EventCategoryEnum, State
|
||||
from shared.types.graphs.topology import (
|
||||
OrphanedPartOfTopology,
|
||||
Topology,
|
||||
TopologyEdge,
|
||||
TopologyEdgeId,
|
||||
TopologyNode,
|
||||
)
|
||||
from shared.types.profiling.common import NodePerformanceProfile
|
||||
from shared.types.states.shared import SharedState
|
||||
from shared.types.tasks.common import BaseTaskData, TaskType
|
||||
from shared.types.worker.common import NodeStatus
|
||||
from shared.types.worker.instances import InstanceId, InstanceParams
|
||||
from shared.types.tasks.common import Task
|
||||
|
||||
|
||||
class ExternalCommand(BaseModel): ...
|
||||
@@ -49,52 +37,23 @@ class NodePerformanceProfileState(State[EventCategoryEnum.MutatesNodePerformance
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile]
|
||||
|
||||
|
||||
class DataPlaneNetworkState(State[EventCategoryEnum.MutatesDataPlaneState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = (
|
||||
EventCategoryEnum.MutatesDataPlaneState
|
||||
class TopologyState(State[EventCategoryEnum.MutatesTopologyState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = (
|
||||
EventCategoryEnum.MutatesTopologyState
|
||||
)
|
||||
topology: DataPlaneTopology = DataPlaneTopology(
|
||||
edge_base=DataPlaneEdgeAdapter, vertex_base=TypeAdapter(None)
|
||||
topology: Topology = Topology(
|
||||
edge_base=TypeAdapter(TopologyEdge), vertex_base=TypeAdapter(TopologyNode)
|
||||
)
|
||||
history: Sequence[OrphanedPartOfDataPlaneTopology] = []
|
||||
history: Sequence[OrphanedPartOfTopology] = []
|
||||
|
||||
def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ...
|
||||
def add_edge(self, edge: DataPlaneEdge) -> None: ...
|
||||
|
||||
|
||||
class ControlPlaneNetworkState(State[EventCategoryEnum.MutatesControlPlaneState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
|
||||
EventCategoryEnum.MutatesControlPlaneState
|
||||
)
|
||||
topology: ControlPlaneTopology = ControlPlaneTopology(
|
||||
edge_base=TypeAdapter(None), vertex_base=TypeAdapter(NodeStatus)
|
||||
)
|
||||
history: Sequence[OrphanedPartOfControlPlaneTopology] = []
|
||||
|
||||
def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ...
|
||||
def add_edge(self, edge: DataPlaneEdge) -> None: ...
|
||||
def delete_edge(self, edge_id: TopologyEdgeId) -> None: ...
|
||||
def add_edge(self, edge: TopologyEdge) -> None: ...
|
||||
|
||||
|
||||
class MasterState(SharedState):
|
||||
data_plane_network_state: DataPlaneNetworkState = DataPlaneNetworkState()
|
||||
control_plane_network_state: ControlPlaneNetworkState = ControlPlaneNetworkState()
|
||||
job_inbox: Queue[BaseTaskData[TaskType]] = Queue()
|
||||
job_outbox: Queue[BaseTaskData[TaskType]] = Queue()
|
||||
topology_state: TopologyState = TopologyState()
|
||||
task_inbox: Queue[Task] = Queue()
|
||||
task_outbox: Queue[Task] = Queue()
|
||||
cache_policy: CachePolicy[CachePolicyType] = CachePolicy[CachePolicyType](
|
||||
policy_type=CachePolicyType.KeepAll
|
||||
)
|
||||
|
||||
|
||||
def get_shard_assignments(
|
||||
inbox: Queue[ExternalCommand],
|
||||
outbox: Queue[ExternalCommand],
|
||||
resource_graph: ResourceGraph,
|
||||
current_instances: Mapping[InstanceId, InstanceParams],
|
||||
cache_policy: CachePolicy[CachePolicyType],
|
||||
) -> Mapping[InstanceId, InstanceParams]: ...
|
||||
|
||||
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, InstanceParams],
|
||||
target_instances: Mapping[InstanceId, InstanceParams],
|
||||
) -> Sequence[BaseEvent[EventCategory]]: ...
|
||||
|
||||
@@ -5,13 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import EventCategoryEnum, State
|
||||
from shared.types.tasks.common import (
|
||||
Task,
|
||||
TaskId,
|
||||
TaskSagaEntry,
|
||||
TaskStatusType,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.tasks.common import Task, TaskId, TaskSagaEntry
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import BaseInstance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
@@ -28,7 +22,7 @@ class Tasks(State[EventCategoryEnum.MutatesTaskState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesTaskState] = (
|
||||
EventCategoryEnum.MutatesTaskState
|
||||
)
|
||||
tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
|
||||
|
||||
class TaskSagas(State[EventCategoryEnum.MutatesTaskSagaState]):
|
||||
@@ -55,4 +49,4 @@ class SharedState(BaseModel):
|
||||
|
||||
def get_tasks_by_instance(
|
||||
self, instance_id: InstanceId
|
||||
) -> Sequence[Task[TaskType, TaskStatusType]]: ...
|
||||
) -> Sequence[Task]: ...
|
||||
|
||||
@@ -10,9 +10,9 @@ from shared.types.states.shared import SharedState
|
||||
from shared.types.worker.common import NodeStatus
|
||||
|
||||
|
||||
class NodeStatusState(State[EventCategoryEnum.MutatesControlPlaneState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
|
||||
EventCategoryEnum.MutatesControlPlaneState
|
||||
class NodeStatusState(State[EventCategoryEnum.MutatesRunnerStatus]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = (
|
||||
EventCategoryEnum.MutatesRunnerStatus
|
||||
)
|
||||
node_status: Mapping[NodeId, NodeStatus]
|
||||
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import ( # noqa: E402
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
)
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NewUUID
|
||||
from shared.types.worker.common import InstanceId
|
||||
@@ -20,35 +11,17 @@ class TaskId(NewUUID):
|
||||
pass
|
||||
|
||||
|
||||
## TASK TYPES
|
||||
@final
|
||||
class TaskType(str, Enum):
|
||||
ChatCompletion = "ChatCompletion"
|
||||
|
||||
TaskTypeT = TypeVar("TaskTypeT", bound=TaskType, covariant=True)
|
||||
|
||||
## TASK STATUSES
|
||||
@final
|
||||
class TaskStatusFailedType(str, Enum):
|
||||
class TaskStatus(str, Enum):
|
||||
Pending = "Pending"
|
||||
Running = "Running"
|
||||
Complete = "Complete"
|
||||
Failed = "Failed"
|
||||
|
||||
|
||||
@final
|
||||
class TaskStatusCompleteType(str, Enum):
|
||||
Complete = "Complete"
|
||||
|
||||
|
||||
@final
|
||||
class TaskStatusOtherType(str, Enum):
|
||||
Pending = "Pending"
|
||||
Running = "Running"
|
||||
|
||||
|
||||
TaskStatusType = TaskStatusCompleteType | TaskStatusFailedType | TaskStatusOtherType
|
||||
TaskStatusTypeT = TypeVar("TaskStatusTypeT", bound=TaskStatusType)#, covariant=True)
|
||||
|
||||
|
||||
## Peripherals
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
|
||||
content: str | None = None
|
||||
@@ -57,10 +30,12 @@ class ChatCompletionMessage(BaseModel):
|
||||
tool_call_id: str | None = None
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
class CompletionCreateParams(BaseModel):
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
task_type: Literal[TaskType.ChatCompletion] = TaskType.ChatCompletion
|
||||
model: str
|
||||
messages: list[ChatCompletionMessage]
|
||||
frequency_penalty: float | None = None
|
||||
messages: list[ChatCompletionMessage]
|
||||
logit_bias: dict[str, int] | None = None
|
||||
logprobs: bool | None = None
|
||||
top_logprobs: int | None = None
|
||||
@@ -79,69 +54,14 @@ class CompletionCreateParams(BaseModel):
|
||||
user: str | None = None
|
||||
|
||||
|
||||
## Task Data is stored in task, one-to-one with task type
|
||||
|
||||
class BaseTaskData(BaseModel, Generic[TaskTypeT]): ...
|
||||
|
||||
@final
|
||||
class ChatCompletionTaskData(BaseTaskData[TaskType.ChatCompletion]):
|
||||
task_type: Literal[TaskType.ChatCompletion] = (
|
||||
TaskType.ChatCompletion
|
||||
)
|
||||
task_params: CompletionCreateParams
|
||||
|
||||
TaskData: TypeAlias = ChatCompletionTaskData
|
||||
|
||||
|
||||
## TASKS
|
||||
|
||||
class TaskArtifact[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel): ...
|
||||
|
||||
|
||||
@final
|
||||
class NoTaskArtifact[TaskTypeT: TaskType](TaskArtifact[TaskTypeT, TaskStatusOtherType]):
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
class FailedTaskArtifact[TaskTypeT: TaskType](
|
||||
TaskArtifact[TaskTypeT, TaskStatusFailedType]
|
||||
):
|
||||
error_message: str
|
||||
|
||||
|
||||
@final
|
||||
class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel):
|
||||
task_status: TaskStatusTypeT
|
||||
task_artifact: TaskArtifact[TaskTypeT, TaskStatusTypeT]
|
||||
|
||||
|
||||
class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel):
|
||||
task_type: TaskTypeT
|
||||
task_data: TaskData # Really this should be BaseTaskData[TaskTypeT], but this causes a bunch of errors that I don't know how to fix yet.
|
||||
task_state: TaskState[TaskStatusTypeT, TaskTypeT]
|
||||
on_instance: InstanceId
|
||||
|
||||
|
||||
BaseTaskAnnotated = Annotated[
|
||||
Union[
|
||||
BaseTask[Literal[TaskType.ChatCompletion], TaskStatusType],
|
||||
],
|
||||
Field(discriminator="task_type"),
|
||||
]
|
||||
|
||||
BaseTaskParser: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter(
|
||||
BaseTaskAnnotated
|
||||
)
|
||||
class Task(BaseModel):
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
task_type: TaskType
|
||||
task_status: TaskStatus
|
||||
task_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class TaskSagaEntry(BaseModel):
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
@final
|
||||
class Task[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](
|
||||
BaseTask[TaskTypeT, TaskStatusTypeT]
|
||||
):
|
||||
task_id: TaskId
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Generic, Literal, TypeVar
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from shared.openai_compat import FinishReason
|
||||
from shared.types.tasks.common import ChatCompletionTaskData
|
||||
from shared.types.tasks.common import ChatCompletionTaskParams
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.shards import ShardMetadata
|
||||
|
||||
@@ -35,7 +35,7 @@ class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
|
||||
type: Literal[MessageType.ChatTask] = Field(
|
||||
default=MessageType.ChatTask, frozen=True
|
||||
)
|
||||
task_data: ChatCompletionTaskData
|
||||
task_data: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ExitMessage(BaseRunnerMessage[MessageType.Exit]):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Generic, Literal, TypeVar, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from shared.types.events.events import InstanceId
|
||||
from shared.types.tasks.common import Task, TaskStatusType, TaskType
|
||||
from shared.types.tasks.common import Task
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.shards import ShardMetadata
|
||||
@@ -52,7 +52,7 @@ class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
|
||||
class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]):
|
||||
op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(default=RunnerOpType.CHAT_COMPLETION, frozen=True)
|
||||
runner_id: RunnerId
|
||||
task: Task[TaskType, TaskStatusType]
|
||||
task: Task
|
||||
|
||||
|
||||
# Aggregate all runner operations into a single, strictly-typed union for dispatching.
|
||||
|
||||
@@ -11,7 +11,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx
|
||||
from shared.openai_compat import FinishReason
|
||||
from shared.types.tasks.common import ChatCompletionTaskData, CompletionCreateParams
|
||||
from shared.types.tasks.common import ChatCompletionTaskParams
|
||||
from shared.types.worker.commands_runner import (
|
||||
ChatTaskMessage,
|
||||
ExitMessage,
|
||||
@@ -34,7 +34,7 @@ async def _mlx_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
task: ChatCompletionTaskData,
|
||||
task: ChatCompletionTaskParams,
|
||||
) -> AsyncGenerator[GenerationResponse]:
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue[GenerationResponse | Exception | object] = asyncio.Queue()
|
||||
@@ -63,17 +63,15 @@ async def _mlx_generate(
|
||||
_ = loop.call_soon_threadsafe(queue.put_nowait, sentinel)
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
task_data: CompletionCreateParams = task.task_params
|
||||
|
||||
runner_print(f"task_data: {task_data}")
|
||||
runner_print(f"task_params: {task}")
|
||||
|
||||
prompt = await apply_chat_template(
|
||||
mlx_executor=mlx_executor,
|
||||
tokenizer=tokenizer,
|
||||
chat_task_data=task_data,
|
||||
chat_task_data=task,
|
||||
)
|
||||
|
||||
max_tokens = task.task_params.max_tokens or 100
|
||||
max_tokens = task.max_tokens or 100
|
||||
generation_fn = partial(_generate_tokens, prompt, max_tokens)
|
||||
|
||||
future = loop.run_in_executor(mlx_executor, generation_fn)
|
||||
@@ -120,10 +118,10 @@ async def main():
|
||||
while True:
|
||||
message: RunnerMessage = await runner_read_message()
|
||||
match message:
|
||||
case ChatTaskMessage(task_data=task_data):
|
||||
runner_print(f"received chat request: {task_data}")
|
||||
case ChatTaskMessage(task_data=task):
|
||||
runner_print(f"received chat request: {task}")
|
||||
# Ensure we have a chat-completion task subtype
|
||||
prompt = task_data.task_params.messages[0]
|
||||
prompt = task.messages[0]
|
||||
if prompt.content is not None and 'EXO RUNNER MUST FAIL' in prompt.content:
|
||||
raise Exception('Artificial runner exception - for testing purposes only.')
|
||||
|
||||
@@ -133,7 +131,7 @@ async def main():
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
task=task_data,
|
||||
task=task,
|
||||
):
|
||||
runner_write_response(generation_response)
|
||||
|
||||
|
||||
@@ -7,10 +7,8 @@ from typing import Any, Callable
|
||||
|
||||
from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
|
||||
from shared.types.tasks.common import (
|
||||
ChatCompletionTaskData,
|
||||
ChatCompletionTaskParams,
|
||||
Task,
|
||||
TaskStatusTypeT,
|
||||
TaskTypeT,
|
||||
)
|
||||
from shared.types.worker.commands_runner import (
|
||||
ChatTaskMessage,
|
||||
@@ -148,7 +146,7 @@ class RunnerSupervisor:
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
task: Task[TaskTypeT, TaskStatusTypeT],
|
||||
task: Task,
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
|
||||
) -> AsyncGenerator[GenerationChunk]:
|
||||
"""
|
||||
@@ -159,12 +157,12 @@ class RunnerSupervisor:
|
||||
if not self.healthy:
|
||||
raise RuntimeError("Runner process was found to be dead")
|
||||
|
||||
task_data = task.task_data
|
||||
assert isinstance(task_data, ChatCompletionTaskData) # this is messy for now.
|
||||
task_params = task.task_params
|
||||
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
|
||||
await supervisor_write_message(
|
||||
proc=self.runner_process,
|
||||
message=ChatTaskMessage(
|
||||
task_data=task_data,
|
||||
task_data=task_params,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import uuid
|
||||
from logging import Logger, getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -11,13 +11,10 @@ from shared.types.models.common import ModelId
|
||||
from shared.types.states.worker import NodeStatusState, WorkerState
|
||||
from shared.types.tasks.common import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionTaskData,
|
||||
CompletionCreateParams,
|
||||
ChatCompletionTaskParams,
|
||||
Task,
|
||||
TaskArtifact,
|
||||
TaskId,
|
||||
TaskState,
|
||||
TaskStatusOtherType,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
@@ -32,12 +29,6 @@ from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.main import Worker
|
||||
|
||||
|
||||
class PendingStreamingTaskArtifact(
|
||||
TaskArtifact[Literal[TaskType.ChatCompletion], Literal[TaskStatusOtherType.Pending]]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_shard_meta():
|
||||
def _pipeline_shard_meta(
|
||||
@@ -97,35 +88,30 @@ def user_message():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completion_create_params(user_message: str) -> CompletionCreateParams:
|
||||
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
|
||||
"""Creates ChatCompletionParams with the given message"""
|
||||
return CompletionCreateParams(
|
||||
return ChatCompletionTaskParams(
|
||||
model="gpt-4",
|
||||
messages=[ChatCompletionMessage(role="user", content=user_message)],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_completion_task(completion_create_params: CompletionCreateParams) -> ChatCompletionTaskData:
|
||||
def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> Task:
|
||||
"""Creates a ChatCompletionTask directly for serdes testing"""
|
||||
return ChatCompletionTaskData(task_params=completion_create_params)
|
||||
return Task(task_id=TaskId(), instance_id=InstanceId(), task_type=TaskType.ChatCompletion, task_status=TaskStatus.Pending, task_params=completion_create_params)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_task(
|
||||
completion_create_params: CompletionCreateParams,
|
||||
) -> Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType]:
|
||||
completion_create_params: ChatCompletionTaskParams,
|
||||
) -> Task:
|
||||
"""Creates the final Task object"""
|
||||
return Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType](
|
||||
return Task(
|
||||
task_id=TaskId(),
|
||||
instance_id=InstanceId(),
|
||||
task_type=TaskType.ChatCompletion,
|
||||
task_data=ChatCompletionTaskData(
|
||||
task_params=completion_create_params
|
||||
),
|
||||
task_state=TaskState[TaskStatusOtherType, Literal[TaskType.ChatCompletion]](
|
||||
task_status=TaskStatusOtherType.Pending,
|
||||
task_artifact=PendingStreamingTaskArtifact(),
|
||||
),
|
||||
on_instance=InstanceId(),
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=completion_create_params,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Callable, TypeVar
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from shared.types.tasks.common import ChatCompletionTaskData
|
||||
from shared.types.tasks.common import Task
|
||||
from shared.types.worker.commands_runner import (
|
||||
ChatTaskMessage,
|
||||
RunnerMessageTypeAdapter,
|
||||
@@ -35,9 +35,9 @@ def test_supervisor_setup_message_serdes(
|
||||
|
||||
|
||||
def test_supervisor_task_message_serdes(
|
||||
chat_completion_task: ChatCompletionTaskData,
|
||||
chat_completion_task: Task,
|
||||
):
|
||||
task_message = ChatTaskMessage(
|
||||
task_data=chat_completion_task,
|
||||
task_data=chat_completion_task.task_params,
|
||||
)
|
||||
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import asyncio
|
||||
from typing import Callable, Literal
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.openai_compat import FinishReason
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.tasks.common import (
|
||||
ChatCompletionTaskData,
|
||||
ChatCompletionTaskParams,
|
||||
Task,
|
||||
TaskStatusOtherType,
|
||||
TaskStatusType,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.worker.mlx import Host
|
||||
@@ -27,7 +25,7 @@ def user_message():
|
||||
async def test_supervisor_single_node_response(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_task: Task[TaskType, TaskStatusType],
|
||||
chat_task: Task,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -63,7 +61,7 @@ async def test_supervisor_single_node_response(
|
||||
async def test_supervisor_two_node_response(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_task: Task[TaskType, TaskStatusType],
|
||||
chat_task: Task,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
supervisor_0 = await RunnerSupervisor.create(
|
||||
@@ -117,7 +115,7 @@ async def test_supervisor_two_node_response(
|
||||
async def test_supervisor_early_stopping(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_task: Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType],
|
||||
chat_task: Task,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -129,16 +127,16 @@ async def test_supervisor_early_stopping(
|
||||
|
||||
max_tokens = 50
|
||||
assert chat_task.task_type == TaskType.ChatCompletion
|
||||
print(f'chat_task.task_data: {type(chat_task.task_data)}')
|
||||
assert isinstance(chat_task.task_data, ChatCompletionTaskData)
|
||||
task_data: ChatCompletionTaskData = chat_task.task_data
|
||||
print(f'chat_task.task_params: {chat_task.task_params}')
|
||||
assert isinstance(chat_task.task_params, ChatCompletionTaskParams)
|
||||
task_params: ChatCompletionTaskParams = chat_task.task_params
|
||||
|
||||
try:
|
||||
task_data.task_params.max_tokens = max_tokens
|
||||
task_params.max_tokens = max_tokens
|
||||
# Convert messages to a list to allow indexing, then update the first message's content
|
||||
messages = list(task_data.task_params.messages)
|
||||
messages = list(task_params.messages)
|
||||
messages[0].content = "Please count from 1 to 100"
|
||||
task_data.task_params.messages = messages
|
||||
task_params.messages = messages
|
||||
|
||||
full_response = ""
|
||||
count = 0
|
||||
@@ -167,7 +165,7 @@ async def test_supervisor_early_stopping(
|
||||
async def test_supervisor_handles_terminated_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_task: Task[TaskType, TaskStatusType],
|
||||
chat_task: Task,
|
||||
):
|
||||
"""Test that the supervisor handles a terminated runner"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -191,7 +189,7 @@ async def test_supervisor_handles_terminated_runner(
|
||||
async def test_supervisor_handles_killed_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_task: Task[TaskType, TaskStatusType],
|
||||
chat_task: Task,
|
||||
):
|
||||
"""Test that the supervisor handles a killed runner"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
|
||||
@@ -9,7 +9,7 @@ from shared.types.common import NodeId
|
||||
from shared.types.events.chunks import TokenChunk, TokenChunkData
|
||||
from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated
|
||||
from shared.types.events.registry import Event
|
||||
from shared.types.tasks.common import Task, TaskStatusType, TaskType
|
||||
from shared.types.tasks.common import Task
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.ops import (
|
||||
@@ -84,7 +84,7 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
assert len(events) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_task: Task[TaskType, TaskStatusType]):
|
||||
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_task: Task):
|
||||
worker, runner_id, _ = worker_with_assigned_runner
|
||||
|
||||
runner_up_op = RunnerUpOp(runner_id=runner_id)
|
||||
@@ -153,7 +153,7 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_op(
|
||||
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
|
||||
chat_task: Task[TaskType, TaskStatusType]):
|
||||
chat_task: Task):
|
||||
worker, runner_id, _ = worker_with_running_runner
|
||||
|
||||
execute_task_op = ExecuteTaskOp(
|
||||
@@ -187,10 +187,10 @@ async def test_execute_task_op(
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_fails(
|
||||
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
|
||||
chat_task: Task[TaskType, TaskStatusType]):
|
||||
chat_task: Task):
|
||||
worker, runner_id, _ = worker_with_running_runner
|
||||
|
||||
messages = chat_task.task_data.task_params.messages
|
||||
messages = chat_task.task_params.messages
|
||||
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
|
||||
|
||||
execute_task_op = ExecuteTaskOp(
|
||||
|
||||
Reference in New Issue
Block a user