mirror of https://github.com/exo-explore/exo.git, synced 2025-12-23 22:27:50 -05:00
Matt's interfaces

Added interfaces for chunks, worker, runner, supervisor, resource monitor, etc.
@@ -1,6 +1,7 @@
 from typing import Protocol

-from shared.types.models.common import Model, ModelId
+from shared.types.models.common import ModelId
+from shared.types.models.model import ModelInfo
 from shared.types.models.sources import ModelSource
 from shared.types.networking.topology import ControlPlaneTopology, DataPlaneTopology
 from shared.types.worker.common import InstanceId
@@ -21,7 +22,7 @@ class ControlPlaneAPI(Protocol):

     def remove_instance(self, instance_id: InstanceId) -> None: ...

-    def get_model_data(self, model_id: ModelId) -> Model: ...
+    def get_model_data(self, model_id: ModelId) -> ModelInfo: ...

     def download_model(self, model_id: ModelId, model_source: ModelSource) -> None: ...
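For readers of the interface change above: a minimal sketch of a class that structurally satisfies the updated Protocol. The in-memory registry is purely hypothetical, not part of this commit.

from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
from shared.types.models.sources import ModelSource

class InMemoryControlPlane:
    """Hypothetical structural implementation; the registry dict is illustrative only."""
    def __init__(self) -> None:
        self._models: dict[ModelId, ModelInfo] = {}

    def get_model_data(self, model_id: ModelId) -> ModelInfo:
        # Returns the descriptor type (ModelInfo) rather than the removed Model class.
        return self._models[model_id]

    def download_model(self, model_id: ModelId, model_source: ModelSource) -> None:
        ...  # kick off a download; signature unchanged by this commit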
@@ -66,6 +66,7 @@ only-include = ["pyproject.toml", "README.md"]

 [tool.basedpyright]
 typeCheckingMode = "strict"
 failOnWarnings = true
+stubPath = "stubs"

 reportAny = "error"
 reportUnknownVariableType = "error"
shared/types/api.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from typing import Literal
+from pydantic import BaseModel
+from openai.types.chat.completion_create_params import CompletionCreateParams
+
+from shared.types.tasks.common import TaskId
+
+class ChatTask(BaseModel):
+    task_id: TaskId
+    kind: Literal["chat"] = "chat"
+    task_data: CompletionCreateParams
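A hedged construction sketch for ChatTask: CompletionCreateParams is an OpenAI TypedDict, so a plain dict should validate under pydantic. How TaskId instances are minted is not shown in this diff; TaskId('nicerid') below mirrors the commented example in chunks.py rather than a confirmed constructor.

from shared.types.api import ChatTask
from shared.types.tasks.common import TaskId

task = ChatTask(
    task_id=TaskId('nicerid'),  # constructor form assumed, per the chunks.py example
    task_data={
        "model": "llama-3.1",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
assert task.kind == "chat"  # defaulted discriminator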
shared/types/events/chunks.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+from typing import Any, Literal, TypeVar, Generic, Annotated
+from collections.abc import AsyncGenerator
+from enum import Enum
+from pydantic import BaseModel, Field, TypeAdapter
+
+from shared.types.tasks.common import TaskId
+from shared.types.models.common import ModelId
+from shared.openai import FinishReason
+
+
+class ChunkType(str, Enum):
+    token = 'token'
+    image = 'image'
+
+
+ChunkT = TypeVar('ChunkT', bound=ChunkType)
+
+
+class BaseChunk(BaseModel, Generic[ChunkT]):
+    task_id: TaskId
+    idx: int
+    model: ModelId
+
+###
+
+class TokenChunkData(BaseModel):
+    text: str
+    token_id: int
+    finish_reason: FinishReason | None = None
+
+
+class ImageChunkData(BaseModel):
+    data: bytes
+
+###
+
+class TokenChunk(BaseChunk[ChunkType.token]):
+    chunk_data: TokenChunkData
+    chunk_type: Literal[ChunkType.token] = Field(
+        default=ChunkType.token, frozen=True
+    )
+
+
+class ImageChunk(BaseChunk[ChunkType.image]):
+    chunk_data: ImageChunkData
+    chunk_type: Literal[ChunkType.image] = Field(
+        default=ChunkType.image, frozen=True
+    )
+
+###
+
+GenerationChunk = Annotated[
+    TokenChunk | ImageChunk,
+    Field(discriminator="chunk_type")
+]
+GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk)
+
+# my_chunk: dict[str, Any] = TokenChunk(
+#     task_id=TaskId('nicerid'),
+#     idx=0,
+#     chunk_data=TokenChunkData(
+#         text='hello',
+#         token_id=12,
+#     ),
+#     chunk_type=ChunkType.token,
+#     model='llama-3.1',
+# ).model_dump()
+# print(my_chunk)
+# restored = GenerationChunkTypeAdapter.validate_python(my_chunk)
+# print(restored)
+
+#### OpenAI API Interfaces ###
+
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+OpenAIResponse = ChatCompletion | ChatCompletionChunk  ## Currently we only support chat completions
+
+
+def send_task(task: Any) -> AsyncGenerator[GenerationChunk]:
+    """
+    This is the 'command' - turns the task into an event and pushes to the event queue.
+    Tokens are then read off the event queue and pushed back to the api via an AsyncGenerator.
+    """
+    ...
+
+
+def parse_chunk_to_openai_response(chunk: GenerationChunk) -> OpenAIResponse:
+    ...
+
+
+async def handle_task(task: Any) -> AsyncGenerator[OpenAIResponse]:
+    ## In our api call function, we will do:
+    generator: AsyncGenerator[GenerationChunk] = send_task(task)
+
+    async for chunk in generator:
+        yield parse_chunk_to_openai_response(chunk)
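The send_task docstring above describes the command side of the flow: the task becomes an event, and tokens are later read off the event queue and surfaced through an AsyncGenerator. A minimal sketch of that shape, assuming a per-task asyncio.Queue stands in for the event queue; the queue registry and the end-of-stream sentinel are hypothetical, not part of this commit.

import asyncio
from collections.abc import AsyncGenerator
from typing import Any

from shared.types.events.chunks import GenerationChunk

# Hypothetical registry mapping task ids to chunk queues; None signals end-of-stream.
_chunk_queues: dict[str, asyncio.Queue[GenerationChunk | None]] = {}

async def send_task_sketch(task: Any) -> AsyncGenerator[GenerationChunk, None]:
    # Pushing the task onto the real event queue is elided; here we only drain chunks.
    queue = _chunk_queues[str(task.task_id)]
    while True:
        chunk = await queue.get()
        if chunk is None:  # sentinel: generation finished
            return
        yield chunk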
@@ -16,8 +16,8 @@ from pydantic import BaseModel, Field, TypeAdapter, model_validator
 from shared.types.common import NewUUID, NodeId


-class EventId(NewUUID):
-    pass
+class EventId(NewUUID): pass
+class TimerId(NewUUID): pass


 class MLXEventTypes(str, Enum):
@@ -67,6 +67,9 @@ class TimerEventTypes(str, Enum):
     TimerCreated = "TimerCreated"
     TimerFired = "TimerFired"

+class ResourceEventTypes(str, Enum):
+    ResourceProfiled = "ResourceProfiled"
+

 EventTypes = Union[
     TaskEventTypes,
@@ -78,6 +81,7 @@ EventTypes = Union[
     DataPlaneEventTypes,
     TimerEventTypes,
     MLXEventTypes,
+    ResourceEventTypes,
 ]

 EventTypeT = TypeVar("EventTypeT", bound=EventTypes)
@@ -4,7 +4,8 @@ from typing import Any, Literal, Tuple

 from pydantic import BaseModel

-from shared.types.common import NewUUID, NodeId
+from shared.types.common import NodeId
+from shared.types.events.common import TimerId
 from shared.types.events.common import (
     ControlPlaneEventTypes,
     DataPlaneEventTypes,
@@ -16,6 +17,7 @@ from shared.types.events.common import (
     StreamingEventTypes,
     TaskEventTypes,
     TimerEventTypes,
+    ResourceEventTypes,
 )
 from shared.types.networking.control_plane import (
     ControlPlaneEdgeId,
@@ -41,14 +43,7 @@ from shared.types.tasks.common import (
 from shared.types.worker.common import InstanceId, NodeStatus
 from shared.types.worker.instances import InstanceData, InstanceStatus
 from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
-
-
-class RequestId(NewUUID):
-    pass
-
-
-class TimerId(NewUUID):
-    pass
+from shared.types.profiling.common import ProfiledResourceName


 class TimerData(BaseModel):
@@ -205,3 +200,10 @@ class TimerScheduled(Event[TimerEventTypes.TimerCreated]):
 class TimerFired(Event[TimerEventTypes.TimerFired]):
     event_type: Literal[TimerEventTypes.TimerFired] = TimerEventTypes.TimerFired
     timer_data: TimerData
+
+
+class ResourceProfiled(Event[ResourceEventTypes.ResourceProfiled]):
+    event_type: Literal[ResourceEventTypes.ResourceProfiled] = (
+        ResourceEventTypes.ResourceProfiled
+    )
+    resource_name: ProfiledResourceName
+    resource_profile: NodePerformanceProfile
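A hedged construction sketch for the new ResourceProfiled event. The Event base class and its common fields (ids, timestamps) are not visible in this diff, so only the fields added here are shown; any required base fields would need to be supplied as well.

from shared.types.events.events import ResourceProfiled
from shared.types.profiling.common import MemoryPerformanceProfile, ProfiledResourceName

event = ResourceProfiled(
    resource_name=ProfiledResourceName.memory,
    resource_profile=MemoryPerformanceProfile(
        ram_total=64 * 1024**3,
        ram_used=12 * 1024**3,
        swap_total=2 * 1024**3,
        swap_used=0,
    ),
)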
@@ -1,18 +1,3 @@
-from typing import Sequence, final
-
-from pydantic import BaseModel
-
 from shared.types.common import NewUUID
-from shared.types.models.metadata import ModelMetadata
-from shared.types.models.sources import ModelSource
-
-
-class ModelId(NewUUID):
-    pass
-
-
-@final
-class Model(BaseModel):
-    model_id: ModelId
-    model_sources: Sequence[ModelSource]
-    model_metadata: ModelMetadata
+
+class ModelId(NewUUID): pass
shared/types/models/model.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from typing import final, Sequence
+
+from pydantic import BaseModel, TypeAdapter
+
+from shared.types.models.common import ModelId
+from shared.types.models.metadata import ModelMetadata
+from shared.types.models.sources import ModelSource
+
+
+# Concerned by the naming here; model could also be an instance of a model.
+@final
+class ModelInfo(BaseModel):
+    model_id: ModelId
+    model_sources: Sequence[ModelSource]
+    model_metadata: ModelMetadata
+
+
+ModelIdAdapter: TypeAdapter[ModelId] = TypeAdapter(ModelId)
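A short validation sketch for ModelIdAdapter, assuming NewUUID round-trips through a UUID string; the exact NewUUID wire format is not shown in this diff.

from shared.types.models.model import ModelIdAdapter

# Validate an incoming UUID string into the ModelId newtype.
model_id = ModelIdAdapter.validate_python("8b1d6f50-3f1a-4c6e-9a2b-7d4e5f6a1b2c")
print(ModelIdAdapter.dump_python(model_id))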
@@ -1,4 +1,47 @@
-from pydantic import BaseModel
-
-
-class NodePerformanceProfile(BaseModel): ...
+from typing import Annotated, Literal, Coroutine, Generic, TypeVar
+from enum import Enum
+from abc import ABC
+from pydantic import BaseModel, Field, TypeAdapter
+
+
+class ProfiledResourceName(str, Enum):
+    memory = 'memory'
+    system = 'system'
+
+
+ProfiledResourceT = TypeVar(name='ProfiledResourceT', bound=ProfiledResourceName)
+
+
+class BasePerformanceProfile(BaseModel, Generic[ProfiledResourceT]):
+    """
+    Details a single resource (or resource type) that is being monitored by the resource monitor.
+    """
+    pass
+
+
+class MemoryPerformanceProfile(BasePerformanceProfile[ProfiledResourceName.memory]):
+    resource_name: Literal[ProfiledResourceName.memory] = Field(
+        default=ProfiledResourceName.memory, frozen=True
+    )
+    ram_total: int
+    ram_used: int
+    swap_total: int
+    swap_used: int
+
+
+class NetworkInterfaceInfo(BaseModel):
+    name: str
+    ip_address: str
+    type: str
+
+
+class SystemPerformanceProfile(BasePerformanceProfile[ProfiledResourceName.system]):
+    resource_name: Literal[ProfiledResourceName.system] = Field(
+        default=ProfiledResourceName.system, frozen=True
+    )
+    model_id: str
+    chip_id: str
+    memory: int
+    network_interfaces: list[NetworkInterfaceInfo] = Field(default_factory=list)
+
+
+NodePerformanceProfile = Annotated[
+    MemoryPerformanceProfile | SystemPerformanceProfile,
+    Field(discriminator="resource_name")
+]
+
+NodePerformanceProfileTypeAdapter: TypeAdapter[NodePerformanceProfile] = TypeAdapter(NodePerformanceProfile)
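A round-trip sketch for the new discriminated union: dumping a concrete profile and re-validating it through NodePerformanceProfileTypeAdapter should recover the same subtype via the resource_name tag, following the same pattern as the commented-out chunk example in chunks.py above.

from shared.types.profiling.common import (
    MemoryPerformanceProfile,
    NodePerformanceProfileTypeAdapter,
)

profile = MemoryPerformanceProfile(
    ram_total=64 * 1024**3,
    ram_used=12 * 1024**3,
    swap_total=2 * 1024**3,
    swap_used=0,
)
payload = profile.model_dump()  # carries resource_name='memory' as the discriminator
restored = NodePerformanceProfileTypeAdapter.validate_python(payload)
assert isinstance(restored, MemoryPerformanceProfile)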
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from enum import Enum
-from typing import Annotated, Generic, Literal, TypeVar, Union
+from typing import Annotated, Generic, Literal, TypeVar

 import openai.types.chat as openai
 from pydantic import BaseModel, Field, TypeAdapter
shared/types/worker/commands_runner.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+from typing import Annotated, Generic, Literal, TypeVar
+from enum import Enum
+from pydantic import BaseModel, Field, TypeAdapter
+
+from shared.types.api import ChatTask
+from shared.types.worker.shards import ShardMeta
+from shared.types.worker.mlx import Host
+from shared.openai import FinishReason
+
+
+## Messages passed TO the runner
+
+class MessageType(str, Enum):
+    Setup = 'setup'
+    ChatTask = "chat_task"
+    Exit = 'exit'
+
+
+MT = TypeVar(name='MT', bound=MessageType)
+
+
+class BaseRunnerMessage(BaseModel, Generic[MT]):
+    pass
+
+
+class SetupMessage(BaseRunnerMessage[MessageType.Setup]):
+    type: Literal[MessageType.Setup] = Field(
+        default=MessageType.Setup, frozen=True
+    )
+    model_shard_meta: ShardMeta
+    hosts: list[Host]
+
+
+class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
+    type: Literal[MessageType.ChatTask] = Field(
+        default=MessageType.ChatTask, frozen=True
+    )
+    task: ChatTask
+
+
+class ExitMessage(BaseRunnerMessage[MessageType.Exit]):
+    type: Literal[MessageType.Exit] = Field(
+        default=MessageType.Exit, frozen=True
+    )
+
+
+RunnerMessage = Annotated[
+    SetupMessage | ChatTaskMessage | ExitMessage,
+    Field(discriminator="type")
+]
+RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage)
+
+
+## Responses passed FROM the runner
+
+class RunnerResponseType(str, Enum):
+    GenerationResponse = "generation_response"
+    FinishedResponse = "finished_response"
+    PrintResponse = "print_response"
+    ErrorResponse = "error_response"
+
+
+RRT = TypeVar(name='RRT', bound=RunnerResponseType)
+
+
+class BaseRunnerResponse(BaseModel, Generic[RRT]):
+    pass
+
+
+class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
+    type: Literal[RunnerResponseType.GenerationResponse] = Field(
+        default=RunnerResponseType.GenerationResponse, frozen=True
+    )
+    text: str
+    token: int
+    # logprobs: Optional[list[float]] = None  # too big. we can change to be top-k
+    finish_reason: FinishReason | None = None
+
+
+class PrintResponse(BaseRunnerResponse[RunnerResponseType.PrintResponse]):
+    type: Literal[RunnerResponseType.PrintResponse] = Field(
+        default=RunnerResponseType.PrintResponse, frozen=True
+    )
+    text: str
+
+
+class FinishedResponse(BaseRunnerResponse[RunnerResponseType.FinishedResponse]):
+    type: Literal[RunnerResponseType.FinishedResponse] = Field(
+        default=RunnerResponseType.FinishedResponse, frozen=True
+    )
+
+
+class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
+    type: Literal[RunnerResponseType.ErrorResponse] = Field(
+        default=RunnerResponseType.ErrorResponse, frozen=True
+    )
+    error_type: str
+    error_message: str
+    traceback: str | None = None
+
+
+RunnerResponse = Annotated[
+    GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
+    Field(discriminator="type")
+]
+RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
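A hedged sketch of the receiving side these message types imply: the runner parses each incoming payload with RunnerMessageTypeAdapter and dispatches on the discriminated type. The line-based transport is hypothetical; only the types come from this file.

from shared.types.worker.commands_runner import (
    ChatTaskMessage,
    ExitMessage,
    RunnerMessage,
    RunnerMessageTypeAdapter,
    SetupMessage,
)

def handle_message(raw: str) -> bool:
    """Returns False when the runner should exit. Transport framing is assumed."""
    msg: RunnerMessage = RunnerMessageTypeAdapter.validate_json(raw)
    match msg:
        case SetupMessage():
            ...  # load the shard described by msg.model_shard_meta, connect msg.hosts
        case ChatTaskMessage():
            ...  # begin generating for msg.task, emitting GenerationResponse payloads
        case ExitMessage():
            return False
    return True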
@@ -2,7 +2,6 @@ from enum import Enum

 from shared.types.common import NewUUID
-


 class InstanceId(NewUUID):
     pass
@@ -14,4 +13,4 @@ class RunnerId(NewUUID):
 class NodeStatus(str, Enum):
     Idle = "Idle"
     Running = "Running"
-    Paused = "Paused"
+    Paused = "Paused"
@@ -15,7 +15,7 @@ from pydantic import BaseModel, Field, PositiveInt
 from shared.types.common import NodeId
 from shared.types.models.common import ModelId
 from shared.types.models.sources import ModelSource
-from shared.types.worker.shards import ShardData, ShardType
+from shared.types.worker.shards import ShardMeta


 class DownloadProgressData(BaseModel):
@@ -80,6 +80,6 @@ DownloadEffectHandler = Callable[
 def download_shard(
     model_id: ModelId,
     model_source: ModelSource,
-    shard_data: ShardData[ShardType],
+    shard_meta: ShardMeta,
     effect_handlers: Sequence[DownloadEffectHandler],
 ) -> None: ...
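A hedged call sketch for the updated download_shard signature. The ShardMeta and ModelSource values are assumed to come from elsewhere, and the import path for download_shard/DownloadEffectHandler is inferred from this hunk rather than shown in the diff.

from collections.abc import Sequence

from shared.types.models.sources import ModelSource
from shared.types.worker.downloads import DownloadEffectHandler, download_shard  # inferred path
from shared.types.worker.shards import ShardMeta

def fetch_shard(meta: ShardMeta, source: ModelSource,
                handlers: Sequence[DownloadEffectHandler] = ()) -> None:
    # The shard meta now carries its own model_id, so the call derives it from the meta.
    download_shard(
        model_id=meta.model_id,
        model_source=source,
        shard_meta=meta,
        effect_handlers=handlers,
    )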
shared/types/worker/mlx.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from pydantic import BaseModel, field_validator
+
+
+# TODO: Is this the right place for this? Host is consumed by worker, but typically stored in the master
+class Host(BaseModel):
+    host: str
+    port: int
+
+    @field_validator('port')
+    def check_port(cls, v: int) -> int:
+        if not (0 <= v <= 65535):
+            raise ValueError("Port must be between 0 and 65535")
+        return v
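A quick check of the port validator above:

from pydantic import ValidationError

from shared.types.worker.mlx import Host

Host(host="10.0.0.2", port=5000)  # validates

try:
    Host(host="10.0.0.2", port=70000)  # out of range
except ValidationError as exc:
    print(exc)  # includes "Port must be between 0 and 65535"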
shared/types/worker/resource_monitor.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from abc import ABC
+from collections.abc import Coroutine
+
+import asyncio
+
+from shared.types.events.events import ResourceProfiledEvent
+from shared.types.profiling.common import NodePerformanceProfile, MemoryPerformanceProfile, SystemPerformanceProfile
+
+
+class EventLog:
+    def append(self, event: ResourceProfiledEvent) -> None:
+        ...
+
+
+class ResourceCollector(ABC):
+    """
+    Details a single resource (or resource type) that is being monitored by the resource monitor.
+    """
+    def __init__(self, name: str):
+        self.name = name
+
+    async def collect(self) -> NodePerformanceProfile:
+        ...
+
+
+class SystemResourceCollector(ResourceCollector):
+    def __init__(self):
+        super().__init__('system')
+
+    async def collect(self) -> SystemPerformanceProfile:
+        ...
+
+
+class MemoryResourceCollector(ResourceCollector):
+    def __init__(self):
+        super().__init__('memory')
+
+    async def collect(self) -> MemoryPerformanceProfile:
+        ...
+
+
+class ResourceMonitor:
+    def __init__(self, event_outbox: EventLog):
+        self.event_outbox: EventLog = event_outbox
+
+        self.collectors: list[ResourceCollector] = [
+            SystemResourceCollector(),
+            MemoryResourceCollector(),
+        ]
+
+    async def collect(self) -> list[NodePerformanceProfile]:
+        tasks: list[Coroutine[None, None, NodePerformanceProfile]] = [
+            collector.collect() for collector in self.collectors
+        ]
+        return await asyncio.gather(*tasks)
+
+    async def collect_and_publish(self) -> None:
+        profiles = await self.collect()
+        for profile in profiles:
+            self.event_outbox.append(profile.to_event())
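A hedged polling-loop sketch around ResourceMonitor. The five-second interval and the bare EventLog are illustrative; note also that collect_and_publish assumes each profile exposes a to_event() helper that is not defined anywhere in this diff.

import asyncio

from shared.types.worker.resource_monitor import EventLog, ResourceMonitor

async def run_monitor(interval_s: float = 5.0) -> None:
    monitor = ResourceMonitor(event_outbox=EventLog())
    while True:
        await monitor.collect_and_publish()
        await asyncio.sleep(interval_s)

# asyncio.run(run_monitor())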
@@ -1,6 +1,6 @@
 from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Generic, Literal, TypeVar
+from typing import Generic, Literal, TypeVar, Self

 from pydantic import BaseModel, model_validator

@@ -8,7 +8,7 @@ from shared.types.common import NodeId
 from shared.types.models.common import ModelId
 from shared.types.worker.common import RunnerId
 from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
-from shared.types.worker.shards import ShardData, ShardType
+from shared.types.worker.shards import BaseModelShardMeta, PartitionStrategyT


 class RunnerStateType(str, Enum):
@@ -55,13 +55,17 @@ class RunnerData(BaseModel):
     )


-class RunnerPlacement(BaseModel):
+# Runner placement must be consistent in its partitioning strategy across all shards.
+# Using a generic type parameter enforces this constraint at type-checking time.
+class RunnerPlacement(BaseModel, Generic[PartitionStrategyT]):
     model_id: ModelId
-    runner_to_shard: Mapping[RunnerId, ShardData[ShardType]]
+    runner_to_shard: Mapping[RunnerId, BaseModelShardMeta[PartitionStrategyT]]
     node_to_runner: Mapping[NodeId, Sequence[RunnerId]]

     @model_validator(mode="after")
-    def validate_runners_exist(self) -> "RunnerPlacement":
+    def validate_runners_exist(self) -> Self:
         for runners in self.node_to_runner.values():
             for runner_id in runners:
                 if runner_id not in self.runner_to_shard:
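A sketch of what the new generic buys at type-checking time: code can demand a placement whose shards all share one partitioning strategy. The string annotation mirrors the enum-member subscripting used throughout this commit; the function itself is hypothetical.

from shared.types.worker.runners import RunnerPlacement
from shared.types.worker.shards import PartitionStrategy

def launch_pipeline(
    placement: "RunnerPlacement[PartitionStrategy.pipeline]",
) -> None:
    # Accepts only placements whose runner_to_shard values are pipeline shards;
    # mixing strategies within one placement fails to type-check.
    ...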
@@ -1,15 +1,47 @@
 from enum import Enum
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Annotated, Literal

-from pydantic import BaseModel
+from pydantic import BaseModel, DirectoryPath, Field, TypeAdapter

 from shared.types.common import NodeId
 from shared.types.models.common import ModelId

+
+class PartitionStrategy(str, Enum):
+    pipeline = 'pipeline'
+
+
+PartitionStrategyT = TypeVar(name='PartitionStrategyT', bound=PartitionStrategy)
+
+
+class BaseModelShardMeta(BaseModel, Generic[PartitionStrategyT]):
+    """
+    Defines a specific shard of the model that is ready to be run on a device.
+    Replaces previous `Shard` object.
+    """
+    device_rank: int
+    world_size: int
+    model_id: ModelId
+    model_path: DirectoryPath  # pydantic DirectoryPath ensures that the directory exists.
+
+
+class PipelineShardMeta(BaseModelShardMeta[PartitionStrategy.pipeline]):
+    """
+    Pipeline parallelism shard meta.
+    """
+    partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
+        default=PartitionStrategy.pipeline, frozen=True
+    )
+    start_layer: Annotated[int, Field(ge=0)]
+    end_layer: Annotated[int, Field(ge=0)]
+
+
+ShardMeta = Annotated[
+    PipelineShardMeta,
+    Field(discriminator="partition_strategy")
+]
+ShardMetaAdapter: TypeAdapter[ShardMeta] = TypeAdapter(ShardMeta)

-class ShardType(str, Enum):
-    PipelineParallel = "PipelineParallel"
-
-
-ShardTypeT = TypeVar("ShardTypeT", bound=ShardType)
-
-
-class ShardData(BaseModel, Generic[ShardTypeT]):
-    shard_type: ShardTypeT
+
+class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
+    """
+    A shard placement is the description of a model distributed across a set of nodes.
+    The Generic[PartitionStrategyT] enforces that the shard assignments all use the same partition strategy.
+    """
+    model_id: ModelId
+    shard_assignments: dict[NodeId, BaseModelShardMeta[PartitionStrategyT]]
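A hedged sketch of a two-way pipeline split using the new shard metadata. The model path must be an existing directory for DirectoryPath to validate, and ModelId() is assumed to mint a fresh NewUUID; both are illustrative.

from shared.types.models.common import ModelId
from shared.types.worker.shards import PipelineShardMeta, ShardMetaAdapter

model_id = ModelId()  # assumed: NewUUID defaults to a fresh UUID

# Layers 0-15 on rank 0, layers 16-31 on rank 1 of a 32-layer model.
shard0 = PipelineShardMeta(
    device_rank=0, world_size=2, model_id=model_id,
    model_path="/tmp/model",  # must exist on disk
    start_layer=0, end_layer=15,
)
shard1 = PipelineShardMeta(
    device_rank=1, world_size=2, model_id=model_id,
    model_path="/tmp/model",
    start_layer=16, end_layer=31,
)

# Round-trip through the discriminated adapter, as with the other unions in this commit.
restored = ShardMetaAdapter.validate_python(shard0.model_dump())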
shared/utils.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+from typing import Any, Type, TypeVar
+
+T = TypeVar('T')
+
+def ensure_type(obj: Any, expected_type: Type[T]) -> T:
+    if not isinstance(obj, expected_type):
+        raise TypeError(f"Expected {expected_type}, got {type(obj)}")
+    return obj
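Usage sketch for ensure_type: it narrows an Any to a concrete type or raises.

from shared.utils import ensure_type

value: object = {"model": "llama-3.1"}

config = ensure_type(value, dict)  # typed as dict from here on
try:
    ensure_type(value, int)
except TypeError as exc:
    print(exc)  # Expected <class 'int'>, got <class 'dict'>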