Simplify Task type + merge control & data plane types into single type

This commit is contained in:
Seth Howes
2025-07-21 17:10:09 +01:00
committed by GitHub
parent 2f64e30dd1
commit d19aa4f95a
29 changed files with 235 additions and 513 deletions

View File

@@ -15,7 +15,7 @@ from mlx_lm.utils import load_model
from pydantic import RootModel
from engines.mlx.auto_parallel import auto_parallel
from shared.types.tasks.common import CompletionCreateParams
from shared.types.tasks.common import ChatCompletionTaskParams
from shared.types.worker.mlx import Host
from shared.types.worker.shards import ShardMeta
from worker.runner.communication import runner_print
@@ -96,7 +96,7 @@ def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWra
async def apply_chat_template(
mlx_executor: concurrent.futures.ThreadPoolExecutor,
tokenizer: TokenizerWrapper,
chat_task_data: CompletionCreateParams,
chat_task_data: ChatCompletionTaskParams,
) -> str:
loop: AbstractEventLoop = asyncio.get_running_loop()

View File

@@ -1,18 +1,16 @@
from typing import Protocol
from shared.types.graphs.topology import Topology
from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
from shared.types.models.sources import ModelSource
from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
from shared.types.worker.common import InstanceId
from shared.types.worker.downloads import DownloadProgress
from shared.types.worker.instances import Instance
class ControlPlaneAPI(Protocol):
def get_control_plane_topology(self) -> ControlPlaneTopology: ...
def get_data_plane_topology(self) -> DataPlaneTopology: ...
class ClusterAPI(Protocol):
def get_topology(self) -> Topology: ...
def list_instances(self) -> list[Instance]: ...

View File

@@ -54,13 +54,12 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
# What The Master Cares About
MasterEventCategories = (
Literal[EventCategoryEnum.MutatesControlPlaneState]
Literal[EventCategoryEnum.MutatesTopologyState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
| Literal[EventCategoryEnum.MutatesDataPlaneState]
)
@@ -119,13 +118,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
@app.get("/topology/control_plane")
def get_control_plane_topology():
return {"message": "Hello, World!"}
@app.get("/topology/data_plane")
def get_data_plane_topology():
@app.get("/topology")
def get_topology():
return {"message": "Hello, World!"}

23
master/placement.py Normal file
View File

@@ -0,0 +1,23 @@
from queue import Queue
from typing import Mapping, Sequence
from shared.types.events.common import BaseEvent, EventCategory
from shared.types.graphs.topology import Topology
from shared.types.states.master import CachePolicy, CachePolicyType
from shared.types.tasks.common import Task
from shared.types.worker.instances import InstanceId, InstanceParams
def get_instance_placement(
inbox: Queue[Task],
outbox: Queue[Task],
topology: Topology,
current_instances: Mapping[InstanceId, InstanceParams],
cache_policy: CachePolicy[CachePolicyType],
) -> Mapping[InstanceId, InstanceParams]: ...
def get_transition_events(
current_instances: Mapping[InstanceId, InstanceParams],
target_instances: Mapping[InstanceId, InstanceParams],
) -> Sequence[BaseEvent[EventCategory]]: ...

View File

@@ -108,11 +108,8 @@ class AsyncStateManagerMapping(TypedDict):
MutatesTaskSagaState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesTaskSagaState]
]
MutatesControlPlaneState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesControlPlaneState]
]
MutatesDataPlaneState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesDataPlaneState]
MutatesTopologyState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesTopologyState]
]
MutatesRunnerStatus: AsyncStateManager[
Literal[EventCategoryEnum.MutatesRunnerStatus]

View File

@@ -7,8 +7,7 @@ from shared.types.events.common import EventCategoryEnum, State
class SyncStateManagerMapping(TypedDict):
MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]]
MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]]
MutatesControlPlaneState: State[Literal[EventCategoryEnum.MutatesControlPlaneState]]
MutatesDataPlaneState: State[Literal[EventCategoryEnum.MutatesDataPlaneState]]
MutatesTopologyState: State[Literal[EventCategoryEnum.MutatesTopologyState]]
MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]]
MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]]
MutatesNodePerformanceState: State[

View File

@@ -33,7 +33,7 @@ class _EdgeWrapper[EdgeTypeT, EdgeIdT]:
edge_data: EdgeData[EdgeTypeT]
class NetworkXGraph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
class Graph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
edge_base: TypeAdapter[EdgeTypeT]
vertex_base: TypeAdapter[VertexTypeT]

View File

@@ -2,10 +2,10 @@ from typing import Literal
from pydantic import BaseModel
from shared.types.tasks.common import CompletionCreateParams, TaskId
from shared.types.tasks.common import ChatCompletionTaskParams, TaskId
class ChatTask(BaseModel):
task_id: TaskId
kind: Literal["chat"] = "chat"
task_data: CompletionCreateParams
task_data: ChatCompletionTaskParams

View File

@@ -60,13 +60,10 @@ class NodePerformanceEventTypes(str, Enum):
NodePerformanceMeasured = "NodePerformanceMeasured"
class DataPlaneEventTypes(str, Enum):
DataPlaneEdgeCreated = "DataPlaneEdgeCreated"
DataPlaneEdgeReplacedAtomically = "DataPlaneEdgeReplacedAtomically"
DataPlaneEdgeDeleted = "DataPlaneEdgeDeleted"
class ControlPlaneEventTypes(str, Enum):
class TopologyEventTypes(str, Enum):
TopologyEdgeCreated = "TopologyEdgeCreated"
TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically"
TopologyEdgeDeleted = "TopologyEdgeDeleted"
WorkerConnected = "WorkerConnected"
WorkerStatusUpdated = "WorkerStatusUpdated"
WorkerDisconnected = "WorkerDisconnected"
@@ -84,8 +81,7 @@ EVENT_TYPE_ENUMS = [
InstanceEventTypes,
RunnerStatusEventTypes,
NodePerformanceEventTypes,
DataPlaneEventTypes,
ControlPlaneEventTypes,
TopologyEventTypes,
TimerEventTypes,
TaskSagaEventTypes,
]
@@ -98,8 +94,7 @@ EventTypes = (
| InstanceEventTypes
| RunnerStatusEventTypes
| NodePerformanceEventTypes
| ControlPlaneEventTypes
| DataPlaneEventTypes
| TopologyEventTypes
| TimerEventTypes
| TaskSagaEventTypes
)
@@ -114,18 +109,17 @@ class EventCategoryEnum(StrEnum):
MutatesRunnerStatus = "MutatesRunnerStatus"
MutatesInstanceState = "MutatesInstanceState"
MutatesNodePerformanceState = "MutatesNodePerformanceState"
MutatesControlPlaneState = "MutatesControlPlaneState"
MutatesDataPlaneState = "MutatesDataPlaneState"
MutatesTopologyState = "MutatesTopologyState"
EventCategory = (
Literal[EventCategoryEnum.MutatesControlPlaneState]
Literal[EventCategoryEnum.MutatesTopologyState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
| Literal[EventCategoryEnum.MutatesDataPlaneState]
| Literal[EventCategoryEnum.MutatesTopologyState]
)
EventCategories = FrozenSet[EventCategory]

View File

@@ -6,8 +6,6 @@ from shared.types.common import NodeId
from shared.types.events.chunks import GenerationChunk
from shared.types.events.common import (
BaseEvent,
ControlPlaneEventTypes,
DataPlaneEventTypes,
EventCategoryEnum,
InstanceEventTypes,
NodePerformanceEventTypes,
@@ -15,33 +13,23 @@ from shared.types.events.common import (
StreamingEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
TopologyEventTypes,
)
from shared.types.networking.control_plane import (
ControlPlaneEdgeId,
ControlPlaneEdgeType,
)
from shared.types.networking.data_plane import (
DataPlaneEdge,
DataPlaneEdgeId,
DataPlaneEdgeProfile,
from shared.types.graphs.topology import (
TopologyEdge,
TopologyEdgeId,
TopologyEdgeProfile,
TopologyNode,
)
from shared.types.profiling.common import NodePerformanceProfile
from shared.types.tasks.common import (
BaseTaskData,
TaskId,
TaskState,
TaskStatusOtherType,
TaskStatusType,
TaskType,
)
from shared.types.tasks.common import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.runners import RunnerId, RunnerStatus
TaskEvent = BaseEvent[EventCategoryEnum.MutatesTaskState]
InstanceEvent = BaseEvent[EventCategoryEnum.MutatesInstanceState]
ControlPlaneEvent = BaseEvent[EventCategoryEnum.MutatesControlPlaneState]
DataPlaneEvent = BaseEvent[EventCategoryEnum.MutatesDataPlaneState]
TopologyEvent = BaseEvent[EventCategoryEnum.MutatesTopologyState]
NodePerformanceEvent = BaseEvent[EventCategoryEnum.MutatesNodePerformanceState]
@@ -49,9 +37,7 @@ class TaskCreated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEven
event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
task_id: TaskId
task_data: BaseTaskData[TaskType]
task_state: TaskState[Literal[TaskStatusOtherType.Pending], TaskType]
on_instance: InstanceId
task: Task
# Covers Cancellation Of Task, Non-Cancelled Tasks Perist
@@ -64,7 +50,8 @@ class TaskDeleted(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEven
class TaskStateUpdated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskStateUpdated]]):
event_type: Literal[TaskEventTypes.TaskStateUpdated] = TaskEventTypes.TaskStateUpdated
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
task_state: TaskState[TaskStatusType, TaskType]
task_id: TaskId
task_status: TaskStatus
class InstanceCreated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceCreated]]):
@@ -130,23 +117,23 @@ class NodePerformanceMeasured(BaseEvent[EventCategoryEnum.MutatesNodePerformance
node_profile: NodePerformanceProfile
class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerConnected]]):
event_type: Literal[ControlPlaneEventTypes.WorkerConnected] = ControlPlaneEventTypes.WorkerConnected
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
edge: DataPlaneEdge
class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerConnected]]):
event_type: Literal[TopologyEventTypes.WorkerConnected] = TopologyEventTypes.WorkerConnected
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
edge: TopologyEdge
class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerStatusUpdated]]):
event_type: Literal[ControlPlaneEventTypes.WorkerStatusUpdated] = ControlPlaneEventTypes.WorkerStatusUpdated
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerStatusUpdated]]):
event_type: Literal[TopologyEventTypes.WorkerStatusUpdated] = TopologyEventTypes.WorkerStatusUpdated
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
node_id: NodeId
node_state: NodeStatus
class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesControlPlaneState, Literal[ControlPlaneEventTypes.WorkerDisconnected]]):
event_type: Literal[ControlPlaneEventTypes.WorkerDisconnected] = ControlPlaneEventTypes.WorkerDisconnected
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = EventCategoryEnum.MutatesControlPlaneState
vertex_id: ControlPlaneEdgeId
class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerDisconnected]]):
event_type: Literal[TopologyEventTypes.WorkerDisconnected] = TopologyEventTypes.WorkerDisconnected
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
vertex_id: NodeId
class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[StreamingEventTypes.ChunkGenerated]]):
@@ -156,23 +143,23 @@ class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[Strea
chunk: GenerationChunk
class DataPlaneEdgeCreated(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeCreated]]):
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeCreated] = DataPlaneEventTypes.DataPlaneEdgeCreated
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
vertex: ControlPlaneEdgeType
class TopologyEdgeCreated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeCreated]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeCreated] = TopologyEventTypes.TopologyEdgeCreated
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
vertex: TopologyNode
class DataPlaneEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically]]):
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically] = DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
edge_id: DataPlaneEdgeId
edge_profile: DataPlaneEdgeProfile
class TopologyEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically] = TopologyEventTypes.TopologyEdgeReplacedAtomically
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
edge_id: TopologyEdgeId
edge_profile: TopologyEdgeProfile
class DataPlaneEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesDataPlaneState, Literal[DataPlaneEventTypes.DataPlaneEdgeDeleted]]):
event_type: Literal[DataPlaneEventTypes.DataPlaneEdgeDeleted] = DataPlaneEventTypes.DataPlaneEdgeDeleted
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = EventCategoryEnum.MutatesDataPlaneState
edge_id: DataPlaneEdgeId
class TopologyEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeDeleted]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeDeleted] = TopologyEventTypes.TopologyEdgeDeleted
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
edge_id: TopologyEdgeId
"""
TEST_EVENT_CATEGORIES_TYPE = FrozenSet[

View File

@@ -6,8 +6,6 @@ from pydantic import Field, TypeAdapter
from shared.constants import get_error_reporting_message
from shared.types.events.common import (
BaseEvent,
ControlPlaneEventTypes,
DataPlaneEventTypes,
EventCategories,
EventTypes,
InstanceEventTypes,
@@ -16,12 +14,10 @@ from shared.types.events.common import (
StreamingEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
TopologyEventTypes,
)
from shared.types.events.events import (
ChunkGenerated,
DataPlaneEdgeCreated,
DataPlaneEdgeDeleted,
DataPlaneEdgeReplacedAtomically,
InstanceCreated,
InstanceDeleted,
InstanceReplacedAtomically,
@@ -32,6 +28,9 @@ from shared.types.events.events import (
TaskCreated,
TaskDeleted,
TaskStateUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TopologyEdgeReplacedAtomically,
WorkerConnected,
WorkerDisconnected,
WorkerStatusUpdated,
@@ -59,13 +58,13 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
ControlPlaneEventTypes.WorkerConnected: WorkerConnected,
ControlPlaneEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
ControlPlaneEventTypes.WorkerDisconnected: WorkerDisconnected,
TopologyEventTypes.WorkerConnected: WorkerConnected,
TopologyEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
TopologyEventTypes.WorkerDisconnected: WorkerDisconnected,
StreamingEventTypes.ChunkGenerated: ChunkGenerated,
DataPlaneEventTypes.DataPlaneEdgeCreated: DataPlaneEdgeCreated,
DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically: DataPlaneEdgeReplacedAtomically,
DataPlaneEventTypes.DataPlaneEdgeDeleted: DataPlaneEdgeDeleted,
TopologyEventTypes.TopologyEdgeCreated: TopologyEdgeCreated,
TopologyEventTypes.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
TopologyEventTypes.TopologyEdgeDeleted: TopologyEdgeDeleted,
TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
}
@@ -115,9 +114,9 @@ Event = (
| WorkerStatusUpdated
| WorkerDisconnected
| ChunkGenerated
| DataPlaneEdgeCreated
| DataPlaneEdgeReplacedAtomically
| DataPlaneEdgeDeleted
| TopologyEdgeCreated
| TopologyEdgeReplacedAtomically
| TopologyEdgeDeleted
| MLXInferenceSagaPrepare
| MLXInferenceSagaStartPrepare
)

View File

@@ -1,17 +0,0 @@
from collections.abc import Mapping
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
from shared.types.profiling.common import NodePerformanceProfile
class ResourceGraph(BaseModel): ...
def get_graph_of_compute_resources(
control_plane_topology: ControlPlaneTopology,
data_plane_topology: DataPlaneTopology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ResourceGraph: ...

View File

@@ -0,0 +1,48 @@
from pydantic import BaseModel, IPvAnyAddress
from shared.graphs import Graph
from shared.types.common import NewUUID, NodeId
from shared.types.profiling.common import NodePerformanceProfile
class TopologyEdgeId(NewUUID):
pass
class TopologyEdgeProfile(BaseModel):
throughput: float
latency: float
jitter: float
class TopologyEdge(BaseModel):
source_ip: IPvAnyAddress
sink_ip: IPvAnyAddress
edge_profile: TopologyEdgeProfile
class TopologyNode(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
class Topology(
Graph[
TopologyEdge,
TopologyNode,
TopologyEdgeId,
NodeId,
]
):
pass
class OrphanedPartOfTopology(
Graph[
TopologyEdge,
TopologyNode,
TopologyEdgeId,
NodeId,
]
):
pass

View File

@@ -1,11 +0,0 @@
from typing import TypeAlias
from shared.types.common import NewUUID, NodeId
from shared.types.graphs.common import Edge
class ControlPlaneEdgeId(NewUUID):
pass
ControlPlaneEdgeType: TypeAlias = Edge[None, ControlPlaneEdgeId, NodeId]

View File

@@ -1,68 +0,0 @@
from enum import Enum
from typing import Annotated, Literal, TypeVar, Union, final
from pydantic import BaseModel, Field, IPvAnyAddress, TypeAdapter
from shared.types.common import NewUUID, NodeId
from shared.types.graphs.common import Edge
class DataPlaneEdgeId(NewUUID):
pass
class AddressingProtocol(str, Enum):
IPvAnyAddress = "IPvAnyAddress"
class ApplicationProtocol(str, Enum):
MLX = "MLX"
AdP = TypeVar("AdP", bound=AddressingProtocol)
ApP = TypeVar("ApP", bound=ApplicationProtocol)
@final
class DataPlaneEdgeProfile(BaseModel):
throughput: float
latency: float
jitter: float
class CommonDataPlaneEdgeData(BaseModel):
edge_data_transfer_rate: DataPlaneEdgeProfile | None = None
class MlxEdgeMetadata(BaseModel):
source_ip: IPvAnyAddress
sink_ip: IPvAnyAddress
class BaseDataPlaneEdgeData[AdP: AddressingProtocol, ApP: ApplicationProtocol](
BaseModel
):
addressing_protocol: AdP
application_protocol: ApP
common_data: CommonDataPlaneEdgeData
class MlxEdge(
BaseDataPlaneEdgeData[AddressingProtocol.IPvAnyAddress, ApplicationProtocol.MLX]
):
addressing_protocol: Literal[AddressingProtocol.IPvAnyAddress] = (
AddressingProtocol.IPvAnyAddress
)
application_protocol: Literal[ApplicationProtocol.MLX] = ApplicationProtocol.MLX
mlx_metadata: MlxEdgeMetadata
DataPlaneEdgeData = Union[MlxEdge]
_DataPlaneEdgeData = Annotated[
DataPlaneEdgeData,
Field(discriminator="addressing_protocol"),
]
DataPlaneEdgeAdapter: TypeAdapter[DataPlaneEdgeData] = TypeAdapter(_DataPlaneEdgeData)
DataPlaneEdge = Edge[DataPlaneEdgeData, DataPlaneEdgeId, NodeId]

View File

@@ -1,29 +0,0 @@
from typing import Callable, NewType, Protocol
from shared.types.networking.control_plane import (
ControlPlaneEdgeId,
ControlPlaneEdgeType,
)
TopicName = NewType("TopicName", str)
PubSubMessageHandler = Callable[[TopicName, object], None]
NodeConnectedHandler = Callable[
[
ControlPlaneEdgeId,
ControlPlaneEdgeType,
],
None,
]
NodeDisconnectedHandler = Callable[[ControlPlaneEdgeId], None]
class DiscoveryService(Protocol):
def on_node_connected(self, handler: NodeConnectedHandler) -> None: ...
def on_node_disconnected(self, handler: NodeDisconnectedHandler) -> None: ...
class PubSubService(Protocol):
def on_message_received(
self, topic_name: TopicName, handler: PubSubMessageHandler
) -> None: ...

View File

@@ -1,45 +0,0 @@
from shared.graphs.networkx import NetworkXGraph
from shared.types.common import NodeId
from shared.types.networking.control_plane import ControlPlaneEdgeId
from shared.types.networking.data_plane import (
DataPlaneEdgeData,
DataPlaneEdgeId,
)
from shared.types.worker.common import NodeStatus
class DataPlaneTopology(
NetworkXGraph[
DataPlaneEdgeData,
None,
DataPlaneEdgeId,
NodeId,
]
):
pass
class OrphanedPartOfDataPlaneTopology(
NetworkXGraph[
DataPlaneEdgeData,
None,
DataPlaneEdgeId,
NodeId,
]
):
pass
class ControlPlaneTopology(NetworkXGraph[None, NodeStatus, ControlPlaneEdgeId, NodeId]):
pass
class OrphanedPartOfControlPlaneTopology(
NetworkXGraph[
None,
NodeStatus,
ControlPlaneEdgeId,
NodeId,
]
):
pass

View File

@@ -6,29 +6,17 @@ from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, TypeAdapter
from shared.types.common import NodeId
from shared.types.events.common import (
BaseEvent,
EventCategory,
EventCategoryEnum,
State,
)
from shared.types.graphs.resource_graph import ResourceGraph
from shared.types.networking.data_plane import (
DataPlaneEdge,
DataPlaneEdgeAdapter,
DataPlaneEdgeId,
)
from shared.types.networking.topology import (
ControlPlaneTopology,
DataPlaneTopology,
OrphanedPartOfControlPlaneTopology,
OrphanedPartOfDataPlaneTopology,
from shared.types.events.common import EventCategoryEnum, State
from shared.types.graphs.topology import (
OrphanedPartOfTopology,
Topology,
TopologyEdge,
TopologyEdgeId,
TopologyNode,
)
from shared.types.profiling.common import NodePerformanceProfile
from shared.types.states.shared import SharedState
from shared.types.tasks.common import BaseTaskData, TaskType
from shared.types.worker.common import NodeStatus
from shared.types.worker.instances import InstanceId, InstanceParams
from shared.types.tasks.common import Task
class ExternalCommand(BaseModel): ...
@@ -49,52 +37,23 @@ class NodePerformanceProfileState(State[EventCategoryEnum.MutatesNodePerformance
node_profiles: Mapping[NodeId, NodePerformanceProfile]
class DataPlaneNetworkState(State[EventCategoryEnum.MutatesDataPlaneState]):
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = (
EventCategoryEnum.MutatesDataPlaneState
class TopologyState(State[EventCategoryEnum.MutatesTopologyState]):
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = (
EventCategoryEnum.MutatesTopologyState
)
topology: DataPlaneTopology = DataPlaneTopology(
edge_base=DataPlaneEdgeAdapter, vertex_base=TypeAdapter(None)
topology: Topology = Topology(
edge_base=TypeAdapter(TopologyEdge), vertex_base=TypeAdapter(TopologyNode)
)
history: Sequence[OrphanedPartOfDataPlaneTopology] = []
history: Sequence[OrphanedPartOfTopology] = []
def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ...
def add_edge(self, edge: DataPlaneEdge) -> None: ...
class ControlPlaneNetworkState(State[EventCategoryEnum.MutatesControlPlaneState]):
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
EventCategoryEnum.MutatesControlPlaneState
)
topology: ControlPlaneTopology = ControlPlaneTopology(
edge_base=TypeAdapter(None), vertex_base=TypeAdapter(NodeStatus)
)
history: Sequence[OrphanedPartOfControlPlaneTopology] = []
def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ...
def add_edge(self, edge: DataPlaneEdge) -> None: ...
def delete_edge(self, edge_id: TopologyEdgeId) -> None: ...
def add_edge(self, edge: TopologyEdge) -> None: ...
class MasterState(SharedState):
data_plane_network_state: DataPlaneNetworkState = DataPlaneNetworkState()
control_plane_network_state: ControlPlaneNetworkState = ControlPlaneNetworkState()
job_inbox: Queue[BaseTaskData[TaskType]] = Queue()
job_outbox: Queue[BaseTaskData[TaskType]] = Queue()
topology_state: TopologyState = TopologyState()
task_inbox: Queue[Task] = Queue()
task_outbox: Queue[Task] = Queue()
cache_policy: CachePolicy[CachePolicyType] = CachePolicy[CachePolicyType](
policy_type=CachePolicyType.KeepAll
)
def get_shard_assignments(
inbox: Queue[ExternalCommand],
outbox: Queue[ExternalCommand],
resource_graph: ResourceGraph,
current_instances: Mapping[InstanceId, InstanceParams],
cache_policy: CachePolicy[CachePolicyType],
) -> Mapping[InstanceId, InstanceParams]: ...
def get_transition_events(
current_instances: Mapping[InstanceId, InstanceParams],
target_instances: Mapping[InstanceId, InstanceParams],
) -> Sequence[BaseEvent[EventCategory]]: ...

View File

@@ -5,13 +5,7 @@ from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.events.common import EventCategoryEnum, State
from shared.types.tasks.common import (
Task,
TaskId,
TaskSagaEntry,
TaskStatusType,
TaskType,
)
from shared.types.tasks.common import Task, TaskId, TaskSagaEntry
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import BaseInstance
from shared.types.worker.runners import RunnerId, RunnerStatus
@@ -28,7 +22,7 @@ class Tasks(State[EventCategoryEnum.MutatesTaskState]):
event_category: Literal[EventCategoryEnum.MutatesTaskState] = (
EventCategoryEnum.MutatesTaskState
)
tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] = {}
tasks: Mapping[TaskId, Task] = {}
class TaskSagas(State[EventCategoryEnum.MutatesTaskSagaState]):
@@ -55,4 +49,4 @@ class SharedState(BaseModel):
def get_tasks_by_instance(
self, instance_id: InstanceId
) -> Sequence[Task[TaskType, TaskStatusType]]: ...
) -> Sequence[Task]: ...

View File

@@ -10,9 +10,9 @@ from shared.types.states.shared import SharedState
from shared.types.worker.common import NodeStatus
class NodeStatusState(State[EventCategoryEnum.MutatesControlPlaneState]):
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
EventCategoryEnum.MutatesControlPlaneState
class NodeStatusState(State[EventCategoryEnum.MutatesRunnerStatus]):
event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = (
EventCategoryEnum.MutatesRunnerStatus
)
node_status: Mapping[NodeId, NodeStatus]

View File

@@ -1,16 +1,7 @@
from enum import Enum
from typing import ( # noqa: E402
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
Union,
final,
)
from typing import Any, Literal
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel
from shared.types.common import NewUUID
from shared.types.worker.common import InstanceId
@@ -20,35 +11,17 @@ class TaskId(NewUUID):
pass
## TASK TYPES
@final
class TaskType(str, Enum):
ChatCompletion = "ChatCompletion"
TaskTypeT = TypeVar("TaskTypeT", bound=TaskType, covariant=True)
## TASK STATUSES
@final
class TaskStatusFailedType(str, Enum):
class TaskStatus(str, Enum):
Pending = "Pending"
Running = "Running"
Complete = "Complete"
Failed = "Failed"
@final
class TaskStatusCompleteType(str, Enum):
Complete = "Complete"
@final
class TaskStatusOtherType(str, Enum):
Pending = "Pending"
Running = "Running"
TaskStatusType = TaskStatusCompleteType | TaskStatusFailedType | TaskStatusOtherType
TaskStatusTypeT = TypeVar("TaskStatusTypeT", bound=TaskStatusType)#, covariant=True)
## Peripherals
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: str | None = None
@@ -57,10 +30,12 @@ class ChatCompletionMessage(BaseModel):
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None
class CompletionCreateParams(BaseModel):
class ChatCompletionTaskParams(BaseModel):
task_type: Literal[TaskType.ChatCompletion] = TaskType.ChatCompletion
model: str
messages: list[ChatCompletionMessage]
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
logit_bias: dict[str, int] | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
@@ -79,69 +54,14 @@ class CompletionCreateParams(BaseModel):
user: str | None = None
## Task Data is stored in task, one-to-one with task type
class BaseTaskData(BaseModel, Generic[TaskTypeT]): ...
@final
class ChatCompletionTaskData(BaseTaskData[TaskType.ChatCompletion]):
task_type: Literal[TaskType.ChatCompletion] = (
TaskType.ChatCompletion
)
task_params: CompletionCreateParams
TaskData: TypeAlias = ChatCompletionTaskData
## TASKS
class TaskArtifact[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel): ...
@final
class NoTaskArtifact[TaskTypeT: TaskType](TaskArtifact[TaskTypeT, TaskStatusOtherType]):
pass
@final
class FailedTaskArtifact[TaskTypeT: TaskType](
TaskArtifact[TaskTypeT, TaskStatusFailedType]
):
error_message: str
@final
class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel):
task_status: TaskStatusTypeT
task_artifact: TaskArtifact[TaskTypeT, TaskStatusTypeT]
class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel):
task_type: TaskTypeT
task_data: TaskData # Really this should be BaseTaskData[TaskTypeT], but this causes a bunch of errors that I don't know how to fix yet.
task_state: TaskState[TaskStatusTypeT, TaskTypeT]
on_instance: InstanceId
BaseTaskAnnotated = Annotated[
Union[
BaseTask[Literal[TaskType.ChatCompletion], TaskStatusType],
],
Field(discriminator="task_type"),
]
BaseTaskParser: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter(
BaseTaskAnnotated
)
class Task(BaseModel):
task_id: TaskId
instance_id: InstanceId
task_type: TaskType
task_status: TaskStatus
task_params: ChatCompletionTaskParams
class TaskSagaEntry(BaseModel):
task_id: TaskId
instance_id: InstanceId
@final
class Task[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](
BaseTask[TaskTypeT, TaskStatusTypeT]
):
task_id: TaskId

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Generic, Literal, TypeVar
from pydantic import BaseModel, Field, TypeAdapter
from shared.openai_compat import FinishReason
from shared.types.tasks.common import ChatCompletionTaskData
from shared.types.tasks.common import ChatCompletionTaskParams
from shared.types.worker.mlx import Host
from shared.types.worker.shards import ShardMetadata
@@ -35,7 +35,7 @@ class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
type: Literal[MessageType.ChatTask] = Field(
default=MessageType.ChatTask, frozen=True
)
task_data: ChatCompletionTaskData
task_data: ChatCompletionTaskParams
class ExitMessage(BaseRunnerMessage[MessageType.Exit]):

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Generic, Literal, TypeVar, Union
from pydantic import BaseModel, Field
from shared.types.events.events import InstanceId
from shared.types.tasks.common import Task, TaskStatusType, TaskType
from shared.types.tasks.common import Task
from shared.types.worker.common import RunnerId
from shared.types.worker.mlx import Host
from shared.types.worker.shards import ShardMetadata
@@ -52,7 +52,7 @@ class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]):
op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(default=RunnerOpType.CHAT_COMPLETION, frozen=True)
runner_id: RunnerId
task: Task[TaskType, TaskStatusType]
task: Task
# Aggregate all runner operations into a single, strictly-typed union for dispatching.

View File

@@ -11,7 +11,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx
from shared.openai_compat import FinishReason
from shared.types.tasks.common import ChatCompletionTaskData, CompletionCreateParams
from shared.types.tasks.common import ChatCompletionTaskParams
from shared.types.worker.commands_runner import (
ChatTaskMessage,
ExitMessage,
@@ -34,7 +34,7 @@ async def _mlx_generate(
model: nn.Module,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
task: ChatCompletionTaskData,
task: ChatCompletionTaskParams,
) -> AsyncGenerator[GenerationResponse]:
loop = asyncio.get_running_loop()
queue: asyncio.Queue[GenerationResponse | Exception | object] = asyncio.Queue()
@@ -63,17 +63,15 @@ async def _mlx_generate(
_ = loop.call_soon_threadsafe(queue.put_nowait, sentinel)
# Currently we support chat-completion tasks only.
task_data: CompletionCreateParams = task.task_params
runner_print(f"task_data: {task_data}")
runner_print(f"task_params: {task}")
prompt = await apply_chat_template(
mlx_executor=mlx_executor,
tokenizer=tokenizer,
chat_task_data=task_data,
chat_task_data=task,
)
max_tokens = task.task_params.max_tokens or 100
max_tokens = task.max_tokens or 100
generation_fn = partial(_generate_tokens, prompt, max_tokens)
future = loop.run_in_executor(mlx_executor, generation_fn)
@@ -120,10 +118,10 @@ async def main():
while True:
message: RunnerMessage = await runner_read_message()
match message:
case ChatTaskMessage(task_data=task_data):
runner_print(f"received chat request: {task_data}")
case ChatTaskMessage(task_data=task):
runner_print(f"received chat request: {task}")
# Ensure we have a chat-completion task subtype
prompt = task_data.task_params.messages[0]
prompt = task.messages[0]
if prompt.content is not None and 'EXO RUNNER MUST FAIL' in prompt.content:
raise Exception('Artificial runner exception - for testing purposes only.')
@@ -133,7 +131,7 @@ async def main():
model=model,
tokenizer=tokenizer,
sampler=sampler,
task=task_data,
task=task,
):
runner_write_response(generation_response)

View File

@@ -7,10 +7,8 @@ from typing import Any, Callable
from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
from shared.types.tasks.common import (
ChatCompletionTaskData,
ChatCompletionTaskParams,
Task,
TaskStatusTypeT,
TaskTypeT,
)
from shared.types.worker.commands_runner import (
ChatTaskMessage,
@@ -148,7 +146,7 @@ class RunnerSupervisor:
async def stream_response(
self,
task: Task[TaskTypeT, TaskStatusTypeT],
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
) -> AsyncGenerator[GenerationChunk]:
"""
@@ -159,12 +157,12 @@ class RunnerSupervisor:
if not self.healthy:
raise RuntimeError("Runner process was found to be dead")
task_data = task.task_data
assert isinstance(task_data, ChatCompletionTaskData) # this is messy for now.
task_params = task.task_params
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
await supervisor_write_message(
proc=self.runner_process,
message=ChatTaskMessage(
task_data=task_data,
task_data=task_params,
),
)

View File

@@ -2,7 +2,7 @@ import asyncio
import uuid
from logging import Logger, getLogger
from pathlib import Path
from typing import Callable, Literal
from typing import Callable
import pytest
@@ -11,13 +11,10 @@ from shared.types.models.common import ModelId
from shared.types.states.worker import NodeStatusState, WorkerState
from shared.types.tasks.common import (
ChatCompletionMessage,
ChatCompletionTaskData,
CompletionCreateParams,
ChatCompletionTaskParams,
Task,
TaskArtifact,
TaskId,
TaskState,
TaskStatusOtherType,
TaskStatus,
TaskType,
)
from shared.types.worker.common import InstanceId, NodeStatus
@@ -32,12 +29,6 @@ from shared.types.worker.shards import PipelineShardMetadata
from worker.main import Worker
class PendingStreamingTaskArtifact(
TaskArtifact[Literal[TaskType.ChatCompletion], Literal[TaskStatusOtherType.Pending]]
):
pass
@pytest.fixture
def pipeline_shard_meta():
def _pipeline_shard_meta(
@@ -97,35 +88,30 @@ def user_message():
@pytest.fixture
def completion_create_params(user_message: str) -> CompletionCreateParams:
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
"""Creates ChatCompletionParams with the given message"""
return CompletionCreateParams(
return ChatCompletionTaskParams(
model="gpt-4",
messages=[ChatCompletionMessage(role="user", content=user_message)],
stream=True,
)
@pytest.fixture
def chat_completion_task(completion_create_params: CompletionCreateParams) -> ChatCompletionTaskData:
def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> Task:
"""Creates a ChatCompletionTask directly for serdes testing"""
return ChatCompletionTaskData(task_params=completion_create_params)
return Task(task_id=TaskId(), instance_id=InstanceId(), task_type=TaskType.ChatCompletion, task_status=TaskStatus.Pending, task_params=completion_create_params)
@pytest.fixture
def chat_task(
completion_create_params: CompletionCreateParams,
) -> Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType]:
completion_create_params: ChatCompletionTaskParams,
) -> Task:
"""Creates the final Task object"""
return Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType](
return Task(
task_id=TaskId(),
instance_id=InstanceId(),
task_type=TaskType.ChatCompletion,
task_data=ChatCompletionTaskData(
task_params=completion_create_params
),
task_state=TaskState[TaskStatusOtherType, Literal[TaskType.ChatCompletion]](
task_status=TaskStatusOtherType.Pending,
task_artifact=PendingStreamingTaskArtifact(),
),
on_instance=InstanceId(),
task_status=TaskStatus.Pending,
task_params=completion_create_params,
)
@pytest.fixture

View File

@@ -2,7 +2,7 @@ from typing import Callable, TypeVar
from pydantic import BaseModel, TypeAdapter
from shared.types.tasks.common import ChatCompletionTaskData
from shared.types.tasks.common import Task
from shared.types.worker.commands_runner import (
ChatTaskMessage,
RunnerMessageTypeAdapter,
@@ -35,9 +35,9 @@ def test_supervisor_setup_message_serdes(
def test_supervisor_task_message_serdes(
chat_completion_task: ChatCompletionTaskData,
chat_completion_task: Task,
):
task_message = ChatTaskMessage(
task_data=chat_completion_task,
task_data=chat_completion_task.task_params,
)
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)

View File

@@ -1,15 +1,13 @@
import asyncio
from typing import Callable, Literal
from typing import Callable
import pytest
from shared.openai_compat import FinishReason
from shared.types.events.chunks import TokenChunk
from shared.types.tasks.common import (
ChatCompletionTaskData,
ChatCompletionTaskParams,
Task,
TaskStatusOtherType,
TaskStatusType,
TaskType,
)
from shared.types.worker.mlx import Host
@@ -27,7 +25,7 @@ def user_message():
async def test_supervisor_single_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_task: Task[TaskType, TaskStatusType],
chat_task: Task,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -63,7 +61,7 @@ async def test_supervisor_single_node_response(
async def test_supervisor_two_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_task: Task[TaskType, TaskStatusType],
chat_task: Task,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
supervisor_0 = await RunnerSupervisor.create(
@@ -117,7 +115,7 @@ async def test_supervisor_two_node_response(
async def test_supervisor_early_stopping(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_task: Task[Literal[TaskType.ChatCompletion], TaskStatusOtherType],
chat_task: Task,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -129,16 +127,16 @@ async def test_supervisor_early_stopping(
max_tokens = 50
assert chat_task.task_type == TaskType.ChatCompletion
print(f'chat_task.task_data: {type(chat_task.task_data)}')
assert isinstance(chat_task.task_data, ChatCompletionTaskData)
task_data: ChatCompletionTaskData = chat_task.task_data
print(f'chat_task.task_params: {chat_task.task_params}')
assert isinstance(chat_task.task_params, ChatCompletionTaskParams)
task_params: ChatCompletionTaskParams = chat_task.task_params
try:
task_data.task_params.max_tokens = max_tokens
task_params.max_tokens = max_tokens
# Convert messages to a list to allow indexing, then update the first message's content
messages = list(task_data.task_params.messages)
messages = list(task_params.messages)
messages[0].content = "Please count from 1 to 100"
task_data.task_params.messages = messages
task_params.messages = messages
full_response = ""
count = 0
@@ -167,7 +165,7 @@ async def test_supervisor_early_stopping(
async def test_supervisor_handles_terminated_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_task: Task[TaskType, TaskStatusType],
chat_task: Task,
):
"""Test that the supervisor handles a terminated runner"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -191,7 +189,7 @@ async def test_supervisor_handles_terminated_runner(
async def test_supervisor_handles_killed_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_task: Task[TaskType, TaskStatusType],
chat_task: Task,
):
"""Test that the supervisor handles a killed runner"""
model_shard_meta = pipeline_shard_meta(1, 0)

View File

@@ -9,7 +9,7 @@ from shared.types.common import NodeId
from shared.types.events.chunks import TokenChunk, TokenChunkData
from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated
from shared.types.events.registry import Event
from shared.types.tasks.common import Task, TaskStatusType, TaskType
from shared.types.tasks.common import Task
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance
from shared.types.worker.ops import (
@@ -84,7 +84,7 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
assert len(events) == 0
@pytest.mark.asyncio
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_task: Task[TaskType, TaskStatusType]):
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_task: Task):
worker, runner_id, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
@@ -153,7 +153,7 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
@pytest.mark.asyncio
async def test_execute_task_op(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_task: Task[TaskType, TaskStatusType]):
chat_task: Task):
worker, runner_id, _ = worker_with_running_runner
execute_task_op = ExecuteTaskOp(
@@ -187,10 +187,10 @@ async def test_execute_task_op(
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_task: Task[TaskType, TaskStatusType]):
chat_task: Task):
worker, runner_id, _ = worker_with_running_runner
messages = chat_task.task_data.task_params.messages
messages = chat_task.task_params.messages
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
execute_task_op = ExecuteTaskOp(