Matt's interfaces

Added interfaces for chunks, worker, runner, supervisor, resourcemonitor, etc.
Matt Beton
2025-07-07 16:42:52 +01:00
committed by GitHub
parent 367e76c8fa
commit 03a1cf59a6
18 changed files with 407 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
 from typing import Protocol
-from shared.types.models.common import Model, ModelId
+from shared.types.models.common import ModelId
+from shared.types.models.model import ModelInfo
 from shared.types.models.sources import ModelSource
 from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
 from shared.types.worker.common import InstanceId
@@ -21,7 +22,7 @@ class ControlPlaneAPI(Protocol):
     def remove_instance(self, instance_id: InstanceId) -> None: ...
-    def get_model_data(self, model_id: ModelId) -> Model: ...
+    def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
     def download_model(self, model_id: ModelId, model_source: ModelSource) -> None: ...

pyproject.toml
View File

@@ -66,6 +66,7 @@ only-include = ["pyproject.toml", "README.md"]
 
 [tool.basedpyright]
 typeCheckingMode = "strict"
 failOnWarnings = true
+stubPath = "stubs"
 reportAny = "error"
 reportUnknownVariableType = "error"

shared/types/api.py Normal file (+10)
View File

@@ -0,0 +1,10 @@
from typing import Literal

from pydantic import BaseModel
from openai.types.chat.completion_create_params import CompletionCreateParams
from shared.types.tasks.common import TaskId

class ChatTask(BaseModel):
    task_id: TaskId
    kind: Literal["chat"] = "chat"
    task_data: CompletionCreateParams
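
A minimal construction sketch; the values are made up, and it assumes TaskId (a NewUUID subclass) validates from a UUID string:

from uuid import uuid4
from shared.types.api import ChatTask

task = ChatTask.model_validate({
    "task_id": str(uuid4()),  # assumes NewUUID-backed ids validate from UUID strings
    "task_data": {
        "model": "llama-3.1",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
})
print(task.kind)  # "chat", defaulted and fixed by the Literal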

View File

@@ -0,0 +1,89 @@
from typing import Any, Literal, TypeVar, Generic, Annotated
from collections.abc import AsyncGenerator
from enum import Enum

from pydantic import BaseModel, Field, TypeAdapter

from shared.types.tasks.common import TaskId
from shared.types.models.common import ModelId
from shared.openai import FinishReason

class ChunkType(str, Enum):
    token = 'token'
    image = 'image'

ChunkT = TypeVar('ChunkT', bound=ChunkType)

class BaseChunk(BaseModel, Generic[ChunkT]):
    task_id: TaskId
    idx: int
    model: ModelId

###

class TokenChunkData(BaseModel):
    text: str
    token_id: int
    finish_reason: FinishReason | None = None

class ImageChunkData(BaseModel):
    data: bytes

###

class TokenChunk(BaseChunk[ChunkType.token]):
    chunk_data: TokenChunkData
    chunk_type: Literal[ChunkType.token] = Field(
        default=ChunkType.token, frozen=True
    )

class ImageChunk(BaseChunk[ChunkType.image]):
    chunk_data: ImageChunkData
    chunk_type: Literal[ChunkType.image] = Field(
        default=ChunkType.image, frozen=True
    )

###

GenerationChunk = Annotated[
    TokenChunk | ImageChunk,
    Field(discriminator="chunk_type")
]
GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk)

# my_chunk: dict[str, Any] = TokenChunk(
#     task_id=TaskId('nicerid'),
#     idx=0,
#     chunk_data=TokenChunkData(
#         text='hello',
#         token_id=12,
#     ),
#     chunk_type=ChunkType.token,
#     model='llama-3.1',
# ).model_dump()
# print(my_chunk)
# restored = GenerationChunkTypeAdapter.validate_python(my_chunk)
# print(restored)

#### OpenAI API Interfaces ###

from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

OpenAIResponse = ChatCompletion | ChatCompletionChunk  ## Currently we only support chat completions

def send_task(task: Any) -> AsyncGenerator[GenerationChunk]:
    """
    This is the 'command' - turns the task into an event and pushes it to the event queue.
    Tokens are then read off the event queue and streamed back to the API via an AsyncGenerator.
    """
    ...

def parse_chunk_to_openai_response(chunk: GenerationChunk) -> OpenAIResponse:
    ...

async def handle_task(task: Any) -> AsyncGenerator[OpenAIResponse]:
    ## In our api call function, we will do:
    generator: AsyncGenerator[GenerationChunk] = send_task(task)
    async for chunk in generator:
        yield parse_chunk_to_openai_response(chunk)
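
A possible fleshing-out of parse_chunk_to_openai_response for the token case. parse_token_chunk is a hypothetical name, and the sketch assumes the openai 1.x constructor shapes:

import time
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta

def parse_token_chunk(chunk: TokenChunk) -> ChatCompletionChunk:
    return ChatCompletionChunk(
        id=str(chunk.task_id),
        object="chat.completion.chunk",
        created=int(time.time()),
        model=str(chunk.model),
        choices=[
            Choice(
                index=0,  # single-choice streaming; chunk.idx orders tokens, not choices
                delta=ChoiceDelta(content=chunk.chunk_data.text),
                finish_reason=chunk.chunk_data.finish_reason,
            )
        ],
    )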

shared/types/events/common.py
View File

@@ -16,8 +16,8 @@ from pydantic import BaseModel, Field, TypeAdapter, model_validator
 from shared.types.common import NewUUID, NodeId
 
-class EventId(NewUUID):
-    pass
+class EventId(NewUUID): pass
+class TimerId(NewUUID): pass
 
 class MLXEventTypes(str, Enum):
@@ -67,6 +67,9 @@ class TimerEventTypes(str, Enum):
     TimerCreated = "TimerCreated"
     TimerFired = "TimerFired"
 
+class ResourceEventTypes(str, Enum):
+    ResourceProfiled = "ResourceProfiled"
+
 EventTypes = Union[
     TaskEventTypes,
@@ -78,6 +81,7 @@ EventTypes = Union[
     DataPlaneEventTypes,
     TimerEventTypes,
     MLXEventTypes,
+    ResourceEventTypes,
 ]
 
 EventTypeT = TypeVar("EventTypeT", bound=EventTypes)

shared/types/events/events.py
View File

@@ -4,7 +4,8 @@ from typing import Any, Literal, Tuple
 from pydantic import BaseModel
 
-from shared.types.common import NewUUID, NodeId
+from shared.types.common import NodeId
+from shared.types.events.common import TimerId
 from shared.types.events.common import (
     ControlPlaneEventTypes,
     DataPlaneEventTypes,
@@ -16,6 +17,7 @@ from shared.types.events.common import (
     StreamingEventTypes,
     TaskEventTypes,
     TimerEventTypes,
+    ResourceEventTypes,
 )
 from shared.types.networking.control_plane import (
     ControlPlaneEdgeId,
@@ -41,14 +43,7 @@ from shared.types.tasks.common import (
 from shared.types.worker.common import InstanceId, NodeStatus
 from shared.types.worker.instances import InstanceData, InstanceStatus
 from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
-
-class RequestId(NewUUID):
-    pass
-
-class TimerId(NewUUID):
-    pass
+from shared.types.profiling.common import ProfiledResourceName
 
 class TimerData(BaseModel):
@@ -205,3 +200,10 @@ class TimerScheduled(Event[TimerEventTypes.TimerCreated]):
 class TimerFired(Event[TimerEventTypes.TimerFired]):
     event_type: Literal[TimerEventTypes.TimerFired] = TimerEventTypes.TimerFired
     timer_data: TimerData
+
+class ResourceProfiled(Event[ResourceEventTypes.ResourceProfiled]):
+    event_type: Literal[ResourceEventTypes.ResourceProfiled] = (
+        ResourceEventTypes.ResourceProfiled
+    )
+    resource_name: ProfiledResourceName
+    resource_profile: NodePerformanceProfile

shared/types/models/common.py
View File

@@ -1,18 +1,3 @@
-from typing import Sequence, final
-from pydantic import BaseModel
 from shared.types.common import NewUUID
-from shared.types.models.metadata import ModelMetadata
-from shared.types.models.sources import ModelSource
-
-class ModelId(NewUUID):
-    pass
-
-@final
-class Model(BaseModel):
-    model_id: ModelId
-    model_sources: Sequence[ModelSource]
-    model_metadata: ModelMetadata
+class ModelId(NewUUID): pass

shared/types/models/model.py
View File

@@ -0,0 +1,18 @@
from typing import final, Sequence

from pydantic import BaseModel, TypeAdapter

from shared.types.models.common import ModelId
from shared.types.models.metadata import ModelMetadata
from shared.types.models.sources import ModelSource

# Concerned by the naming here; a "model" could also refer to a running instance of a model.
@final
class ModelInfo(BaseModel):
    model_id: ModelId
    model_sources: Sequence[ModelSource]
    model_metadata: ModelMetadata

ModelIdAdapter: TypeAdapter[ModelId] = TypeAdapter(ModelId)

shared/types/profiling/common.py
View File

@@ -1,4 +1,47 @@
-from pydantic import BaseModel
+from typing import Annotated, Literal, Generic, TypeVar
+from enum import Enum
+from pydantic import BaseModel, Field, TypeAdapter
 
-class NodePerformanceProfile(BaseModel): ...
+class ProfiledResourceName(str, Enum):
+    memory = 'memory'
+    system = 'system'
+
+ProfiledResourceT = TypeVar(name='ProfiledResourceT', bound=ProfiledResourceName)
+
+class BasePerformanceProfile(BaseModel, Generic[ProfiledResourceT]):
+    """
+    A performance snapshot of a single resource (or resource type) monitored by the resource monitor.
+    """
+    pass
+
+class MemoryPerformanceProfile(BasePerformanceProfile[ProfiledResourceName.memory]):
+    resource_name: Literal[ProfiledResourceName.memory] = Field(
+        default=ProfiledResourceName.memory, frozen=True
+    )
+    ram_total: int
+    ram_used: int
+    swap_total: int
+    swap_used: int
+
+class NetworkInterfaceInfo(BaseModel):
+    name: str
+    ip_address: str
+    type: str
+
+class SystemPerformanceProfile(BasePerformanceProfile[ProfiledResourceName.system]):
+    resource_name: Literal[ProfiledResourceName.system] = Field(
+        default=ProfiledResourceName.system, frozen=True
+    )
+    model_id: str
+    chip_id: str
+    memory: int
+    network_interfaces: list[NetworkInterfaceInfo] = Field(default_factory=list)
+
+NodePerformanceProfile = Annotated[
+    MemoryPerformanceProfile | SystemPerformanceProfile,
+    Field(discriminator="resource_name")
+]
+NodePerformanceProfileTypeAdapter: TypeAdapter[NodePerformanceProfile] = TypeAdapter(NodePerformanceProfile)
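
The same discriminated round trip as the chunk types, with made-up numbers:

profile = MemoryPerformanceProfile(
    ram_total=64 * 2**30,
    ram_used=20 * 2**30,
    swap_total=2**30,
    swap_used=0,
)
restored = NodePerformanceProfileTypeAdapter.validate_python(profile.model_dump())
assert isinstance(restored, MemoryPerformanceProfile)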

View File

@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from enum import Enum
-from typing import Annotated, Generic, Literal, TypeVar, Union
+from typing import Annotated, Generic, Literal, TypeVar
 
 import openai.types.chat as openai
 from pydantic import BaseModel, Field, TypeAdapter

View File

@@ -0,0 +1,91 @@
from typing import Annotated, Generic, Literal, TypeVar
from enum import Enum

from pydantic import BaseModel, Field, TypeAdapter

from shared.types.api import ChatTask
from shared.types.worker.shards import ShardMeta
from shared.types.worker.mlx import Host
from shared.openai import FinishReason

## Messages passed TO the runner

class MessageType(str, Enum):
    Setup = 'setup'
    ChatTask = "chat_task"
    Exit = 'exit'

MT = TypeVar(name='MT', bound=MessageType)

class BaseRunnerMessage(BaseModel, Generic[MT]):
    pass

class SetupMessage(BaseRunnerMessage[MessageType.Setup]):
    type: Literal[MessageType.Setup] = Field(
        default=MessageType.Setup, frozen=True
    )
    model_shard_meta: ShardMeta
    hosts: list[Host]

class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
    type: Literal[MessageType.ChatTask] = Field(
        default=MessageType.ChatTask, frozen=True
    )
    task: ChatTask

class ExitMessage(BaseRunnerMessage[MessageType.Exit]):
    type: Literal[MessageType.Exit] = Field(
        default=MessageType.Exit, frozen=True
    )

RunnerMessage = Annotated[
    SetupMessage | ChatTaskMessage | ExitMessage,
    Field(discriminator="type")
]
RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage)

## Responses passed FROM the runner

class RunnerResponseType(str, Enum):
    GenerationResponse = "generation_response"
    FinishedResponse = "finished_response"
    PrintResponse = "print_response"
    ErrorResponse = "error_response"

RRT = TypeVar(name='RRT', bound=RunnerResponseType)

class BaseRunnerResponse(BaseModel, Generic[RRT]):
    pass

class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
    type: Literal[RunnerResponseType.GenerationResponse] = Field(
        default=RunnerResponseType.GenerationResponse, frozen=True
    )
    text: str
    token: int
    # logprobs: Optional[list[float]] = None  # too big. we can change to be top-k
    finish_reason: FinishReason | None = None

class PrintResponse(BaseRunnerResponse[RunnerResponseType.PrintResponse]):
    type: Literal[RunnerResponseType.PrintResponse] = Field(
        default=RunnerResponseType.PrintResponse, frozen=True
    )
    text: str

class FinishedResponse(BaseRunnerResponse[RunnerResponseType.FinishedResponse]):
    type: Literal[RunnerResponseType.FinishedResponse] = Field(
        default=RunnerResponseType.FinishedResponse, frozen=True
    )

class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
    type: Literal[RunnerResponseType.ErrorResponse] = Field(
        default=RunnerResponseType.ErrorResponse, frozen=True
    )
    error_type: str
    error_message: str
    traceback: str | None = None

RunnerResponse = Annotated[
    GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
    Field(discriminator="type")
]
RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
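
A sketch of decoding on the runner side; the JSON-lines transport is an assumption, not something this commit specifies:

raw = '{"type": "exit"}'  # e.g. one line read from the runner's stdin
msg = RunnerMessageTypeAdapter.validate_json(raw)
match msg:
    case SetupMessage():
        ...  # load the shard described by msg.model_shard_meta
    case ChatTaskMessage(task=task):
        ...  # run generation for `task`
    case ExitMessage():
        ...  # shut down cleanly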

shared/types/worker/common.py
View File

@@ -2,7 +2,6 @@ from enum import Enum
 from shared.types.common import NewUUID
 
 class InstanceId(NewUUID):
     pass
@@ -14,4 +13,4 @@ class RunnerId(NewUUID):
 class NodeStatus(str, Enum):
     Idle = "Idle"
     Running = "Running"
-    Paused = "Paused"
\ No newline at end of file
+    Paused = "Paused"

shared/types/worker/downloads.py
View File

@@ -15,7 +15,7 @@ from pydantic import BaseModel, Field, PositiveInt
 from shared.types.common import NodeId
 from shared.types.models.common import ModelId
 from shared.types.models.sources import ModelSource
-from shared.types.worker.shards import ShardData, ShardType
+from shared.types.worker.shards import ShardMeta
 
 class DownloadProgressData(BaseModel):
@@ -80,6 +80,6 @@ DownloadEffectHandler = Callable[
 def download_shard(
     model_id: ModelId,
     model_source: ModelSource,
-    shard_data: ShardData[ShardType],
+    shard_meta: ShardMeta,
     effect_handlers: Sequence[DownloadEffectHandler],
 ) -> None: ...

shared/types/worker/mlx.py
View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel, field_validator

# TODO: Is this the right place for this? Host is consumed by the worker, but typically stored on the master.
class Host(BaseModel):
    host: str
    port: int

    @field_validator('port')
    @classmethod
    def check_port(cls, v: int) -> int:
        if not (0 <= v <= 65535):
            raise ValueError("Port must be between 0 and 65535")
        return v
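
Validation then behaves as below (addresses made up):

from pydantic import ValidationError
from shared.types.worker.mlx import Host

Host(host="192.168.0.10", port=5000)        # ok
try:
    Host(host="192.168.0.10", port=70000)   # rejected by check_port
except ValidationError as exc:
    print(exc)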

View File

@@ -0,0 +1,55 @@
from abc import ABC
from collections.abc import Coroutine
import asyncio

from shared.types.events.events import ResourceProfiled
from shared.types.profiling.common import NodePerformanceProfile, MemoryPerformanceProfile, SystemPerformanceProfile

class EventLog:
    def append(self, event: ResourceProfiled) -> None:
        ...

class ResourceCollector(ABC):
    """
    Collects a single resource (or resource type) that is being monitored by the resource monitor.
    """
    def __init__(self, name: str):
        self.name = name

    async def collect(self) -> NodePerformanceProfile:
        ...

class SystemResourceCollector(ResourceCollector):
    def __init__(self):
        super().__init__('system')

    async def collect(self) -> SystemPerformanceProfile:
        ...

class MemoryResourceCollector(ResourceCollector):
    def __init__(self):
        super().__init__('memory')

    async def collect(self) -> MemoryPerformanceProfile:
        ...

class ResourceMonitor:
    def __init__(self, event_outbox: EventLog):
        self.event_outbox: EventLog = event_outbox
        self.collectors: list[ResourceCollector] = [
            SystemResourceCollector(),
            MemoryResourceCollector(),
        ]

    async def collect(self) -> list[NodePerformanceProfile]:
        # Run all collectors concurrently.
        tasks: list[Coroutine[None, None, NodePerformanceProfile]] = [
            collector.collect() for collector in self.collectors
        ]
        return await asyncio.gather(*tasks)

    async def collect_and_publish(self) -> None:
        profiles = await self.collect()
        for profile in profiles:
            # Assumes each profile can wrap itself into a ResourceProfiled event.
            self.event_outbox.append(profile.to_event())
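
A minimal driver loop; run_monitor, the cadence, and the bare EventLog are illustrative assumptions:

async def run_monitor(monitor: ResourceMonitor) -> None:
    while True:
        await monitor.collect_and_publish()
        await asyncio.sleep(30)  # arbitrary polling interval

# asyncio.run(run_monitor(ResourceMonitor(event_outbox=EventLog())))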

shared/types/worker/runners.py
View File

@@ -1,6 +1,6 @@
 from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Generic, Literal, TypeVar
+from typing import Generic, Literal, TypeVar, Self
 
 from pydantic import BaseModel, model_validator
@@ -8,7 +8,7 @@ from shared.types.common import NodeId
 from shared.types.models.common import ModelId
 from shared.types.worker.common import RunnerId
 from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
-from shared.types.worker.shards import ShardData, ShardType
+from shared.types.worker.shards import BaseModelShardMeta, PartitionStrategyT
 
 class RunnerStateType(str, Enum):
@@ -55,13 +55,17 @@ class RunnerData(BaseModel):
     )
 
-class RunnerPlacement(BaseModel):
+# Runner placement must be consistent in its partitioning strategy across all shards.
+# Using a generic type parameter enforces this constraint at type-checking time.
+class RunnerPlacement(BaseModel, Generic[PartitionStrategyT]):
     model_id: ModelId
-    runner_to_shard: Mapping[RunnerId, ShardData[ShardType]]
+    runner_to_shard: Mapping[RunnerId, BaseModelShardMeta[PartitionStrategyT]]
     node_to_runner: Mapping[NodeId, Sequence[RunnerId]]
 
     @model_validator(mode="after")
-    def validate_runners_exist(self) -> "RunnerPlacement":
+    def validate_runners_exist(self) -> Self:
         for runners in self.node_to_runner.values():
             for runner_id in runners:
                 if runner_id not in self.runner_to_shard:
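
A sketch of what the generic buys at a call boundary; schedule is a hypothetical consumer, not part of this commit:

from shared.types.worker.runners import RunnerPlacement
from shared.types.worker.shards import PartitionStrategy

def schedule(placement: RunnerPlacement[PartitionStrategy.pipeline]) -> None:
    # Once a second strategy exists, a placement that mixes strategies (or
    # passes the wrong one) fails here at type-checking time, not at runtime.
    ...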

shared/types/worker/shards.py
View File

@@ -1,15 +1,47 @@
 from enum import Enum
-from typing import Generic, TypeVar
-from pydantic import BaseModel
+from typing import Generic, TypeVar, Annotated, Literal
+from pydantic import BaseModel, DirectoryPath, Field, TypeAdapter
 
 from shared.types.common import NodeId
 from shared.types.models.common import ModelId
 
+class PartitionStrategy(str, Enum):
+    pipeline = 'pipeline'
+
+PartitionStrategyT = TypeVar(name='PartitionStrategyT', bound=PartitionStrategy)
+
+class BaseModelShardMeta(BaseModel, Generic[PartitionStrategyT]):
+    """
+    Defines a specific shard of the model that is ready to be run on a device.
+    Replaces the previous `Shard` object.
+    """
+    device_rank: int
+    world_size: int
+    model_id: ModelId
+    model_path: DirectoryPath  # pydantic's DirectoryPath ensures that the directory exists.
+
+class PipelineShardMeta(BaseModelShardMeta[PartitionStrategy.pipeline]):
+    """
+    Pipeline parallelism shard meta.
+    """
+    partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
+        default=PartitionStrategy.pipeline, frozen=True
+    )
+    start_layer: Annotated[int, Field(ge=0)]
+    end_layer: Annotated[int, Field(ge=0)]
+
+ShardMeta = Annotated[
+    PipelineShardMeta,
+    Field(discriminator="partition_strategy")
+]
+ShardMetaAdapter: TypeAdapter[ShardMeta] = TypeAdapter(ShardMeta)
 
-class ShardType(str, Enum):
-    PipelineParallel = "PipelineParallel"
-
-ShardTypeT = TypeVar("ShardTypeT", bound=ShardType)
-
-class ShardData(BaseModel, Generic[ShardTypeT]):
-    shard_type: ShardTypeT
+class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
+    """
+    A shard placement is the description of a model distributed across a set of nodes.
+    The Generic[PartitionStrategyT] enforces that the shard assignments all use the same partition strategy.
+    """
+    model_id: ModelId
+    shard_assignments: dict[NodeId, BaseModelShardMeta[PartitionStrategyT]]
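
A hypothetical helper showing how pipeline shard metas compose into a placement; the even layer split is illustrative:

def make_pipeline_placement(
    model_id: ModelId,
    nodes: list[NodeId],
    model_path: str,
    n_layers: int,
) -> ShardPlacement[PartitionStrategy.pipeline]:
    world_size = len(nodes)
    per_node = n_layers // world_size
    return ShardPlacement(
        model_id=model_id,
        shard_assignments={
            node: PipelineShardMeta(
                device_rank=rank,
                world_size=world_size,
                model_id=model_id,
                model_path=model_path,  # DirectoryPath: must exist on disk
                start_layer=rank * per_node,
                # the last shard absorbs any remainder layers
                end_layer=n_layers if rank == world_size - 1 else (rank + 1) * per_node,
            )
            for rank, node in enumerate(nodes)
        },
    )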

shared/utils.py Normal file (+8)
View File

@@ -0,0 +1,8 @@
from typing import Any, Type, TypeVar

T = TypeVar('T')

def ensure_type(obj: Any, expected_type: Type[T]) -> T:
    if not isinstance(obj, expected_type):
        raise TypeError(f"Expected {expected_type}, got {type(obj)}")
    return obj
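
Typical use is narrowing an Any at a boundary:

from typing import Any

payload: Any = {"host": "127.0.0.1", "port": 8000}
port = ensure_type(payload["port"], int)   # 8000, typed as int
host = ensure_type(payload["host"], int)   # raises TypeError: expected int, got str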