fix: Many Fixes

This commit is contained in:
Arbion Halili
2025-07-16 13:35:31 +01:00
parent d9b9aa7ad2
commit 520b1122a3
26 changed files with 698 additions and 594 deletions

View File

@@ -0,0 +1,114 @@
from typing import Protocol, cast, override
import mlx.core as mx
import mlx.nn as nn
from shared.types.worker.shards import PipelineShardMetadata
class IdentityLayer(nn.Module):
    """A no-op layer that returns its input unchanged.

    Used to stub out layers belonging to other pipeline shards so every
    rank can run the same layer list end to end.
    """

    @override
    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        # Extra positional/keyword arguments are accepted and ignored.
        return x
class _LayerCallable(Protocol):
    """Structural type for any layer usable by the pipeline wrappers.

    A conforming object takes one positional ``mx.array`` and returns an
    ``mx.array``; arbitrary extra *args/**kwargs are allowed so that most
    ``mlx.nn.Module`` subclasses match without explicit registration.
    """

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class PipelineFirstLayer(nn.Module):
    """Wraps the first layer of this rank's pipeline shard.

    Every rank except rank 0 first receives its input activations from the
    previous rank, then runs the wrapped layer.
    """

    def __init__(self, original_layer: _LayerCallable, r: int, s: int):
        super().__init__()
        # r: this device's pipeline rank; s: pipeline world size.
        self.original_layer: _LayerCallable = original_layer
        self.r: int = r
        self.s: int = s

    @override
    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        if self.r != 0:
            # Block until the previous stage has sent this stage's input.
            x = mx.distributed.recv_like(x, self.r - 1)
        return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(nn.Module):
    """Wraps the last layer of this rank's pipeline shard.

    After computing the wrapped layer, every rank except the final one sends
    its output to the next rank; all ranks then join an all_gather so each
    ends up holding the final activations.
    """

    def __init__(self, original_layer: _LayerCallable, r: int, s: int):
        super().__init__()
        # r: this device's pipeline rank; s: pipeline world size.
        self.original_layer: _LayerCallable = original_layer
        self.r: int = r
        self.s: int = s

    @override
    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        output: mx.array = self.original_layer(x, *args, **kwargs)
        if self.r != self.s - 1:
            output = mx.distributed.send(output, (self.r + 1) % self.s)
        # NOTE(review): all_gather is a collective, so it is assumed to run on
        # every rank (outside the `if`); the slice keeps only the trailing rows
        # matching this rank's own batch size — confirm against the original
        # indentation, which was lost in this view.
        output = mx.distributed.all_gather(output)[-output.shape[0]:]  # pyright: ignore[reportUnknownMemberType]
        return output
def inner_model(model: nn.Module) -> nn.Module:
    """Return the decoder stack nested inside *model*.

    Conventional attribute names are probed in order: ``model`` first, then
    ``transformer``.

    Raises:
        ValueError: if neither attribute holds an ``nn.Module``.
    """
    for attr_name in ("model", "transformer"):
        candidate = getattr(model, attr_name, None)
        if isinstance(candidate, nn.Module):
            return candidate
    raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def auto_parallel(
    model: nn.Module, model_shard_meta: PipelineShardMetadata
) -> nn.Module:
    """
    Automatically parallelize a model across multiple devices.

    Layers outside ``[start_layer, end_layer)`` are replaced with no-op
    ``IdentityLayer`` instances; the two boundary layers are wrapped so that
    activations are received from the previous rank and forwarded to the
    next one.

    Args:
        model: The model to parallelize (must have a 'layers' or 'h' property)
        model_shard_meta: The metadata for the model shard

    Returns:
        The same model instance, mutated in place with its shard wired up.

    Raises:
        ValueError: if the shard's layer range is empty or out of bounds.
    """
    inner_model_instance: nn.Module = inner_model(model)

    # Handle both model.layers and model.h cases
    layers: list[_LayerCallable]
    if hasattr(inner_model_instance, "layers"):
        layers = cast(list[_LayerCallable], inner_model_instance.layers)
    else:
        layers = cast(list[_LayerCallable], inner_model_instance.h)

    start, end = model_shard_meta.start_layer, model_shard_meta.end_layer
    # Validate the shard range up front instead of failing mid-mutation.
    if not (0 <= start < end <= len(layers)):
        raise ValueError(
            f"Invalid shard range [{start}, {end}) for {len(layers)} layers"
        )

    # Slice-assignment mutates the model's own layer list in place.
    layers[:start] = [IdentityLayer() for _ in range(start)]
    layers[end:] = [IdentityLayer() for _ in range(len(layers) - end)]
    layers[start] = PipelineFirstLayer(
        layers[start],
        model_shard_meta.device_rank,
        model_shard_meta.world_size,
    )
    # For a single-layer shard this wraps the PipelineFirstLayer, so the
    # recv -> compute -> send composition still holds.
    layers[end - 1] = PipelineLastLayer(
        layers[end - 1],
        model_shard_meta.device_rank,
        model_shard_meta.world_size,
    )
    return model

View File

@@ -21,15 +21,20 @@ from shared.types.worker.shards import ShardMeta
from worker.runner.communication import runner_print
def mx_barrier():
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu)))) # type: ignore
def mx_barrier():
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu))
)
)
class HostList(RootModel[list[str]]):
    """Pydantic root model serialising as a plain JSON list of host strings."""

    @classmethod
    def from_hosts(cls, hosts: list[Host]) -> "HostList":
        # Stringify each Host so the root serialises as JSON strings.
        return cls(root=[str(h) for h in hosts])
def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
"""
Initialize the MLX distributed (runs in thread pool)
@@ -37,10 +42,10 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
runner_print(f"Starting initialization for rank {rank}")
# Setup distributed environment
hostfile = f"./hosts_{rank}.json" # TODO: this needs to be unique?
hostfile = f"./hosts_{rank}.json" # TODO: this needs to be unique?
hosts_json = HostList.from_hosts(hosts).model_dump_json()
runner_print(f'rank {rank} hostfile: {hostfile} hosts: {hosts_json}')
runner_print(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
@@ -55,6 +60,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
return group
def initialize_mlx(
model_shard_meta: ShardMeta,
hosts: list[Host],
@@ -71,8 +77,9 @@ def initialize_mlx(
return model, tokenizer, sampler
def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWrapper]:
runner_print(f'loading model from {model_shard_meta.model_path}')
runner_print(f"loading model from {model_shard_meta.model_path}")
model, config = load_model(model_shard_meta.model_path, lazy=True, strict=False)
@@ -102,9 +109,11 @@ async def apply_chat_template(
for message in messages_dicts:
filtered_message = {k: v for k, v in message.items() if v is not None}
# Verify we have exactly the expected keys
assert set(filtered_message.keys()) == {'role', 'content'}, f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}"
assert set(filtered_message.keys()) == {"role", "content"}, (
f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}"
)
formatted_messages.append(filtered_message)
messages_dicts = formatted_messages
prompt: str = await loop.run_in_executor(
@@ -113,7 +122,7 @@ async def apply_chat_template(
messages_dicts,
tokenize=False,
add_generation_prompt=True,
)
),
)
return prompt
return prompt

View File

@@ -26,10 +26,22 @@ class MasterInvalidCommandReceivedLogEntry(
command_name: str
class MasterCommandRunnerNotRunningLogEntry: ...
class MasterCommandRunnerNotRunningLogEntry(
    LogEntry[Literal["master_command_runner_not_running"]]
):
    """Logged when a command is submitted while the command-runner task is not alive."""

    # Route this entry to the cluster log stream.
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_command_runner_not_running"] = (
        "master_command_runner_not_running"
    )
    message: str = "Command Runner Not Running"
class MasterStateManagerStoppedLogEntry: ...
class MasterStateManagerStoppedLogEntry(
    LogEntry[Literal["master_state_manager_stopped"]]
):
    """Logged when the master's state-manager loop has been stopped."""

    # Route this entry to the cluster log stream.
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_state_manager_stopped"] = "master_state_manager_stopped"
    message: str = "State Manager Stopped"
class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]):

View File

@@ -1,23 +1,16 @@
from asyncio import CancelledError, Lock, Task, create_task
from asyncio import Queue as AsyncQueue
from contextlib import asynccontextmanager
from logging import Logger, LogRecord
from queue import Queue as PQueue
from typing import Callable, Sequence
from typing import Literal
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from fastapi import FastAPI
from master.commands import ExternalCommand
from master.env import MasterEnvironmentSchema
from master.logging import (
MasterCommandRunnerNotRunningLogEntry,
MasterStateManagerStoppedLogEntry,
MasterUninitializedLogEntry,
)
from master.router import QueueMapping
from master.state_manager.sync import SyncStateManagerMapping
from shared.constants import EXO_MASTER_STATE
from shared.event_loops.main import NodeEventLoopProtocol
from shared.logger import (
FilterLogByType,
LogEntryType,
@@ -27,11 +20,7 @@ from shared.logger import (
log,
)
from shared.types.events.common import (
Apply,
EventCategory,
EventFromEventLog,
EventPublisher,
State,
EventCategoryEnum,
)
from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
@@ -63,93 +52,20 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
return data
# Safety on Apply.
def safely_apply[T: EventCategory](
state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
) -> State[T]:
sorted_events = sorted(events, key=lambda event: event.idx_in_log)
state = state.model_copy()
for event in sorted_events:
if event.idx_in_log <= state.last_event_applied_idx:
continue
state.last_event_applied_idx = event.idx_in_log
state = apply_fn(state, event)
return state
# What The Master Cares About
MasterEventCategories = (
Literal[EventCategoryEnum.MutatesControlPlaneState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
| Literal[EventCategoryEnum.MutatesDataPlaneState]
)
class MasterEventLoop:
"""Thread-safe manager for MasterState with independent event loop."""
def __init__(
self,
initial_state: MasterState,
push_events_to_queue: Callable[[QueueMapping], None],
event_publisher: EventPublisher[EventCategory],
state_managers: SyncStateManagerMapping,
logger: Logger,
):
self._state = initial_state
self._state_lock = Lock()
self._event_queues: QueueMapping
self._command_runner: ...
self._command_run_task: Task[None] | None = None
self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue()
self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue()
self._state_managers: SyncStateManagerMapping
self._state_global_lock: Lock = Lock()
self._push_events_to_queue: Callable[[QueueMapping], None]
self._event_fetch_task: Task[None] | None = None
self._logger = logger
@property
def _is_command_runner_running(self) -> bool:
return self._command_run_task is not None and not self._command_run_task.done()
@property
def _is_event_fetcher_running(self) -> bool:
return self._event_fetch_task is not None and not self._event_fetch_task.done()
async def send_command(
self, command: ExternalCommand
) -> Response | StreamingResponse:
"""Send a command to the background event loop."""
if self._is_command_runner_running:
await self._command_queue.put(command)
return await self._response_queue.get()
else:
log(self._logger, MasterCommandRunnerNotRunningLogEntry())
raise RuntimeError("Command Runner Is Not Running")
async def start(self) -> None:
"""Start the background event loop."""
async def fetch_and_apply_events() -> None:
while True:
async with self._state_global_lock:
for state in self._state_managers.values():
self._push_events_to_queue(self._event_queues)
safely_apply(
state, apply_fn, self._event_queues[state.event_category]
)
self._event_fetch_task = create_task(fetch_and_apply_events())
self._command_run_task = create_task(self._command_runner())
async def stop(self) -> None:
"""Stop the background event loop and persist state."""
if not self._is_command_runner_running or not self._is_event_fetcher_running:
raise RuntimeError("Command Runner Is Not Running")
assert self._command_run_task is not None and self._event_fetch_task is not None
for service in [self._event_fetch_task, self._command_run_task]:
service.cancel()
try:
await service
except CancelledError:
pass
log(self._logger, MasterStateManagerStoppedLogEntry())
# Takes Care Of All States And Events Related To The Master
class MasterEventLoopProtocol(NodeEventLoopProtocol[MasterEventCategories]): ...
@asynccontextmanager
@@ -182,7 +98,7 @@ async def lifespan(app: FastAPI):
cluster_queue,
)
# TODO: Add handlers
# TODO: Add Handlers For Pushing Logs To Remote Services
telemetry_listener = create_queue_listener(telemetry_queue, [])
metrics_listener = create_queue_listener(metrics_queue, [])
cluster_listener = create_queue_listener(cluster_queue, [])
@@ -191,15 +107,13 @@ async def lifespan(app: FastAPI):
metrics_listener.start()
cluster_listener.start()
initial_state = get_master_state(logger)
app.state.master_event_loop = MasterEventLoop(
initial_state, None, None, None, logger
)
await app.state.master_event_loop.start()
# initial_state = get_master_state(logger)
# app.state.master_event_loop = MasterEventLoop()
# await app.state.master_event_loop.start()
yield
await app.state.master_event_loop.stop()
# await app.state.master_event_loop.stop()
app = FastAPI(lifespan=lifespan)

View File

@@ -1,90 +0,0 @@
from asyncio import Queue, gather
from logging import Logger
from typing import Literal, Protocol, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import (
EventCategories,
EventCategory,
EventCategoryEnum,
EventFromEventLog,
narrow_event_from_event_log_type,
)
class QueueMapping(TypedDict):
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
class EventRouterProtocol(Protocol):
queue_map: QueueMapping
start_idx: int
def sync_queues(self) -> None: ...
class EventRouter(EventRouterProtocol):
"""Routes events to appropriate services based on event categories."""
queue_map: QueueMapping
start_idx: int
logger: Logger
async def _get_queue_by_category[T: EventCategory](
self, category: T
) -> Queue[EventFromEventLog[T]]:
"""Get the queue for a given category."""
category_str: str = category.value
queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str]
return queue
async def _process_events[T: EventCategory](self, category: T) -> None:
"""Process events for a given domain."""
queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category)
events_to_process: list[EventFromEventLog[T]] = []
while not queue.empty():
events_to_process.append(await queue.get())
for event_to_process in events_to_process:
await self.queue_map[category.value].put(event_to_process)
return None
async def _submit_events[T: EventCategory | EventCategories](
self, events: list[EventFromEventLog[T]]
) -> None:
"""Route multiple events to their appropriate services."""
for event in events:
if isinstance(event.event.event_category, EventCategory):
q1: Queue[EventFromEventLog[T]] = self.queue_map[
event.event.event_category.value
]
await q1.put(event)
elif isinstance(event.event.event_category, EventCategories):
for category in event.event.event_category:
narrow_event = narrow_event_from_event_log_type(event, category)
q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value]
await q2.put(narrow_event)
await gather(*[self._process_events(domain) for domain in EventCategoryEnum])

View File

@@ -10,8 +10,8 @@ from master.logging import (
StateUpdateLoopStartedLogEntry,
StateUpdateLoopStoppedLogEntry,
)
from master.router import check_keys_in_map_match_enum_values
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.constants import get_error_reporting_message
from shared.logger import log
from shared.types.events.common import (
Apply,
@@ -74,7 +74,7 @@ class AsyncStateManager[EventCategoryT: EventCategory](Protocol):
raise RuntimeError("State Update Loop Not Running")
assert self._task is not None, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
"BUG: is_running is True but _task is None, this should never happen!"
)
self._task.cancel()

View File

@@ -21,7 +21,8 @@ def get_caller_module_name() -> str:
return mod.__name__
EXO_ERROR_REPORTING_MESSAGE = lambda: (
f"THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n"
f"The module that raised the error was: {get_caller_module_name()}"
)
def get_error_reporting_message() -> str:
    """Return the standard bug-report banner, naming the module that raised.

    The first line is a fixed reporting notice; the second interpolates the
    caller's module via ``get_caller_module_name()``.
    """
    return (
        # Plain string: the first line has no placeholders, so no f-prefix.
        "THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n"
        f"The module that raised the error was: {get_caller_module_name()}"
    )

View File

@@ -2,8 +2,15 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter
from shared.types.common import NewUUID
class ExternalCommandId(NewUUID):
    """Distinct UUID newtype identifying a single external command."""

    pass
class BaseExternalCommand[T: str](BaseModel):
    """Base model for externally-issued commands, tagged by a literal type string."""

    # Unique id correlating the command with its response.
    command_id: ExternalCommandId
    # Discriminator tag for the concrete command type.
    command_type: T

121
shared/event_loops/main.py Normal file
View File

@@ -0,0 +1,121 @@
from asyncio import Lock, Task
from asyncio import Queue as AsyncQueue
from collections.abc import MutableMapping
from logging import Logger
from typing import Any, Hashable, Mapping, Protocol, Sequence
from fastapi.responses import Response, StreamingResponse
from shared.event_loops.commands import ExternalCommand
from shared.types.events.common import Apply, EventCategory, EventFromEventLog, State
class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
__slots__ = ("_store",)
required_keys: frozenset[K] = frozenset()
def __init__(self, data: Mapping[K, V]):
missing = self.required_keys - data.keys()
extra = data.keys() - self.required_keys
if missing or extra:
raise ValueError(f"missing={missing!r}, extra={extra!r}")
self._store: dict[K, V] = dict(data)
def __getitem__(self, k: K) -> V:
return self._store[k]
def __setitem__(self, k: K, v: V) -> None:
self._store[k] = v
def __delitem__(self, k: K) -> None:
del self._store[k]
def __iter__(self):
return iter(self._store)
def __len__(self) -> int:
return len(self._store)
# Safety on Apply.
def safely_apply[T: EventCategory](
state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
) -> State[T]:
sorted_events = sorted(events, key=lambda event: event.idx_in_log)
state = state.model_copy()
for event in sorted_events:
if event.idx_in_log <= state.last_event_applied_idx:
continue
state.last_event_applied_idx = event.idx_in_log
state = apply_fn(state, event)
return state
class NodeCommandLoopProtocol(Protocol):
    """Structural interface for a node's external-command loop."""

    # Background task consuming commands; None until started.
    _command_runner: Task[Any] | None = None
    _command_queue: AsyncQueue[ExternalCommand]
    _response_queue: AsyncQueue[Response | StreamingResponse]
    _logger: Logger

    @property
    def is_command_runner_running(self) -> bool:
        # Running means a task exists and has not finished or been cancelled.
        return self._command_runner is not None and not self._command_runner.done()

    async def start_command_runner(self) -> None: ...
    async def stop_command_runner(self) -> None: ...
    async def push_command(self, command: ExternalCommand) -> None: ...
    async def pop_response(self) -> Response | StreamingResponse: ...
    async def _handle_command(self, command: ExternalCommand) -> None: ...
class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol):
    """Structural interface for fetching events into per-category queues."""

    # Background task pulling events; None until started.
    _event_fetcher: Task[Any] | None = None
    _event_queues: ExhaustiveMapping[
        EventCategoryT, AsyncQueue[EventFromEventLog[EventCategory]]
    ]
    _logger: Logger

    @property
    async def is_event_fetcher_running(self) -> bool:
        # NOTE(review): an async property — the getter returns a coroutine that
        # callers must await (NodeStateManagerProtocol does). Unusual pattern;
        # a plain async method would be clearer, but callers depend on this.
        return self._event_fetcher is not None and not self._event_fetcher.done()

    async def start_event_fetcher(self) -> None: ...
    async def stop_event_fetcher(self) -> None: ...
class NodeStateStorageProtocol[EventCategoryT: EventCategory](Protocol):
    """Structural interface for per-category state storage behind a lock."""

    _state_managers: ExhaustiveMapping[EventCategoryT, State[EventCategoryT]]
    _state_lock: Lock
    _logger: Logger

    async def _read_state(
        self, event_category: EventCategoryT
    ) -> State[EventCategoryT]: ...
class NodeStateManagerProtocol[EventCategoryT: EventCategory](
NodeEventGetterProtocol[EventCategoryT], NodeStateStorageProtocol[EventCategoryT]
):
_state_manager: Task[Any] | None = None
_logger: Logger
@property
async def is_state_manager_running(self) -> bool:
is_task_running = (
self._state_manager is not None and not self._state_manager.done()
)
return (
is_task_running
and await self.is_event_fetcher_running
and await self.is_state_manager_running
)
async def start_state_manager(self) -> None: ...
async def stop_state_manager(self) -> None: ...
async def _apply_queued_events(self) -> None: ...
# Full node event loop: external-command handling plus managed state updates.
class NodeEventLoopProtocol[EventCategoryT: EventCategory](
    NodeCommandLoopProtocol, NodeStateManagerProtocol[EventCategoryT]
): ...

View File

@@ -0,0 +1,78 @@
from asyncio.queues import Queue
from typing import Sequence, cast, get_args
from shared.event_loops.main import ExhaustiveMapping
from shared.types.events.common import (
EventCategories,
EventCategory,
EventCategoryEnum,
EventFromEventLog,
narrow_event_from_event_log_type,
)
"""
from asyncio import gather
from logging import Logger
from typing import Literal, Protocol, Sequence, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import EventCategoryEnum
"""
"""
class EventQueues(TypedDict):
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
check_keys_in_map_match_enum_values(EventQueues, EventCategoryEnum)
"""
async def route_events[UnionOfRelevantEvents: EventCategory](
queue_map: ExhaustiveMapping[
UnionOfRelevantEvents, Queue[EventFromEventLog[EventCategory]]
],
events: Sequence[EventFromEventLog[EventCategory | EventCategories]],
) -> None:
"""Route an event to the appropriate queue."""
tuple_of_categories: tuple[EventCategoryEnum, ...] = get_args(UnionOfRelevantEvents)
print(tuple_of_categories)
for event in events:
if isinstance(event.event.event_category, EventCategoryEnum):
category: EventCategory = event.event.event_category
if category not in tuple_of_categories:
continue
narrowed_event = narrow_event_from_event_log_type(event, category)
q1: Queue[EventFromEventLog[EventCategory]] = queue_map[
cast(UnionOfRelevantEvents, category)
] # TODO: make casting unnecessary
await q1.put(narrowed_event)
else:
for category in event.event.event_category:
if category not in tuple_of_categories:
continue
narrow_event = narrow_event_from_event_log_type(event, category)
q2 = queue_map[
cast(UnionOfRelevantEvents, category)
] # TODO: make casting unnecessary
await q2.put(narrow_event)

View File

@@ -4,7 +4,7 @@ from collections.abc import Sequence, Set
from queue import Queue
from typing import Annotated
from pydantic import Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from rich.logging import RichHandler
from master.logging import MasterLogEntries
@@ -28,12 +28,6 @@ class FilterLogByType(logging.Filter):
return True
class LogEntryType(str, Enum):
telemetry = "telemetry"
metrics = "metrics"
cluster = "cluster"
class LogEntry(BaseModel):
event_type: Set[LogEntryType]

View File

@@ -1,93 +0,0 @@
from typing import Protocol, cast, override
import mlx.core as mx
import mlx.nn as nn
from shared.types.worker.shards import PipelineShardMeta
class IdentityLayer(nn.Module):
@override
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class PipelineFirstLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
super().__init__()
self.original_layer: _LayerCallable = original_layer
self.r: int = r
self.s: int = s
@override
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1))
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
super().__init__()
self.original_layer: _LayerCallable = original_layer
self.r: int = r
self.s: int = s
@override
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(output, (self.r + 1) % self.s)
output = mx.distributed.all_gather(output)[-output.shape[0]:] # pyright: ignore[reportUnknownMemberType]
return output
def inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, 'model', None)
if isinstance(inner, nn.Module):
return inner
inner = getattr(model, 'transformer', None)
if isinstance(inner, nn.Module):
return inner
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
# def auto_parallel(model: nn.Module, rank: int, size: int, start_layer: int, end_layer: int) -> nn.Module:
def auto_parallel(model: nn.Module, model_shard_meta: PipelineShardMeta) -> nn.Module:
"""
Automatically parallelize a model across multiple devices.
Args:
model: The model to parallelize (must have a 'layers' or 'h' property)
model_shard_meta: The metadata for the model shard
Returns:
The parallelized model
"""
inner_model_instance: nn.Module = inner_model(model)
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
if hasattr(inner_model_instance, 'layers'):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
else:
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers[:model_shard_meta.start_layer] = [IdentityLayer() for _ in range(model_shard_meta.start_layer)]
layers[model_shard_meta.end_layer:] = [IdentityLayer() for _ in range(len(layers) - model_shard_meta.end_layer)]
layers[model_shard_meta.start_layer] = PipelineFirstLayer(layers[model_shard_meta.start_layer], model_shard_meta.device_rank, model_shard_meta.world_size)
layers[model_shard_meta.end_layer - 1] = PipelineLastLayer(layers[model_shard_meta.end_layer - 1], model_shard_meta.device_rank, model_shard_meta.world_size)
# At this point `layers` *must* be a concrete list.
assert isinstance(layers, list), "Expected a list of layers after auto-parallel initialisation"
return model

View File

@@ -1,5 +1,6 @@
from enum import Enum, StrEnum
from typing import (
Any,
Callable,
FrozenSet,
Literal,
@@ -205,9 +206,7 @@ class StateAndEvent[EventCategoryT: EventCategory](NamedTuple):
type EffectHandler[EventCategoryT: EventCategory] = Callable[
[StateAndEvent[EventCategoryT], State[EventCategoryT]], None
]
type EventPublisher[EventCategoryT: EventCategory] = Callable[
[Event[EventCategoryT]], None
]
type EventPublisher = Callable[[Event[Any]], None]
# A component that can publish events
@@ -224,7 +223,7 @@ class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol):
# A component that can get the effect handler for a saga
def get_saga_effect_handler[EventCategoryT: EventCategory](
saga: Saga[EventCategoryT], event_publisher: EventPublisher[EventCategoryT]
saga: Saga[EventCategoryT], event_publisher: EventPublisher
) -> EffectHandler[EventCategoryT]:
def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None:
trigger_state, trigger_event = state_and_event
@@ -236,7 +235,7 @@ def get_saga_effect_handler[EventCategoryT: EventCategory](
def get_effects_from_sagas[EventCategoryT: EventCategory](
sagas: Sequence[Saga[EventCategoryT]],
event_publisher: EventPublisher[EventCategoryT],
event_publisher: EventPublisher,
) -> Sequence[EffectHandler[EventCategoryT]]:
return [get_saga_effect_handler(saga, event_publisher) for saga in sagas]

View File

@@ -41,10 +41,10 @@ from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType
MLXEvent = Event[
frozenset(
{
(
EventCategoryEnum.MutatesTaskState,
EventCategoryEnum.MutatesControlPlaneState,
}
)
)
]
TaskEvent = Event[EventCategoryEnum.MutatesTaskState]

View File

@@ -3,7 +3,7 @@ from typing import Annotated, Any, Mapping, Type, get_args
from pydantic import Field, TypeAdapter
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.constants import get_error_reporting_message
from shared.types.events.common import (
ControlPlaneEventTypes,
DataPlaneEventTypes,
@@ -50,7 +50,6 @@ class EventTypeNames(StrEnum):
check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames)
"""
EventRegistry: Mapping[EventTypes, Type[Any]] = {
TaskEventTypes.TaskCreated: TaskCreated,
TaskEventTypes.TaskStateUpdated: TaskStateUpdated,
@@ -78,7 +77,7 @@ def check_registry_has_all_event_types() -> None:
missing_event_types = set(event_types) - set(EventRegistry.keys())
assert not missing_event_types, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"There's an event missing from the registry: {missing_event_types}"
)
@@ -91,14 +90,14 @@ def check_union_of_all_events_is_consistent_with_registry(
missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union
assert not missing_from_union, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"Event classes in registry are missing from all_events union: {missing_from_union}"
)
extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry
assert not extra_in_union, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"Event classes in all_events union are missing from registry: {extra_in_union}"
)

View File

@@ -2,7 +2,7 @@ from enum import Enum, StrEnum
from types import UnionType
from typing import Any, LiteralString, Sequence, Set, Type, get_args
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.constants import get_error_reporting_message
def check_event_type_union_is_consistent_with_registry(
@@ -20,7 +20,7 @@ def check_event_type_union_is_consistent_with_registry(
for tag_of_event_type in event_types_inferred_from_registry:
event_type = type(tag_of_event_type)
assert event_type in event_types_inferred_from_union, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"There's a mismatch between the registry of event types and the union of possible event types."
f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}."
)
@@ -36,7 +36,7 @@ def check_event_categories_are_defined_for_all_event_types(
]
tag_of_event_categories: list[str] = list(event_categories.__members__.values())
assert tag_of_event_categories == expected_category_tags, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"The values of the enum EventCategories are not named after the event type enums."
f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}"
f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}"
@@ -61,7 +61,7 @@ def assert_literal_union_covers_enum[TEnum: StrEnum](
literal_values: Set[Any] = _flatten(literal_union)
assert enum_values == literal_values, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
f"{get_error_reporting_message()}"
f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n"
f"These are the missing values: {enum_values - literal_values}\n"
f"These are the extra values: {literal_values - enum_values}\n"

View File

@@ -2,14 +2,14 @@ from collections.abc import Mapping
from shared.types.common import NodeId
from shared.types.events.common import (
EventCategory,
EventCategoryEnum,
State,
)
from shared.types.states.shared import SharedState
from shared.types.worker.common import NodeStatus
class NodeStatusState(State[EventCategory.MutatesControlPlaneState]):
class NodeStatusState(State[EventCategoryEnum.MutatesControlPlaneState]):
node_status: Mapping[NodeId, NodeStatus]

View File

@@ -29,7 +29,7 @@ class ChatCompletionNonStreamingTask(TaskParams[TaskType.ChatCompletionNonStream
task_type: Literal[TaskType.ChatCompletionNonStreaming] = (
TaskType.ChatCompletionNonStreaming
)
task_data: openai.completion_create_params.CompletionCreateParams
task_data: openai.completion_create_params.CompletionCreateParamsNonStreaming
@final
@@ -37,7 +37,7 @@ class ChatCompletionStreamingTask(TaskParams[TaskType.ChatCompletionStreaming]):
task_type: Literal[TaskType.ChatCompletionStreaming] = (
TaskType.ChatCompletionStreaming
)
task_data: openai.completion_create_params.CompletionCreateParams
task_data: openai.completion_create_params.CompletionCreateParamsStreaming
@final
@@ -83,7 +83,7 @@ class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel)
class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel):
task_type: TaskTypeT
task_params: TaskParams[TaskTypeT]
task_stats: TaskState[TaskStatusTypeT, TaskTypeT]
task_state: TaskState[TaskStatusTypeT, TaskTypeT]
on_instance: InstanceId

274
uv.lock generated
View File

@@ -44,11 +44,31 @@ wheels = [
[[package]]
name = "certifi"
version = "2025.6.15"
version = "2025.7.14"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/73/f7/f14b46d4bcd21092d7d3ccef689615220d8a08fb25e564b65d20738e672e/certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b", size = 158753, upload-time = "2025-06-15T02:45:51.329Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b3/76/52c535bcebe74590f296d6c77c86dabf761c41980e1347a2422e4aa2ae41/certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995", size = 163981, upload-time = "2025-07-14T03:29:28.449Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650, upload-time = "2025-06-15T02:45:49.977Z" },
{ url = "https://files.pythonhosted.org/packages/4f/52/34c6cf5bb9285074dc3531c437b3919e825d976fde097a7a73f79e726d03/certifi-2025.7.14-py3-none-any.whl", hash = "sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2", size = 162722, upload-time = "2025-07-14T03:29:26.863Z" },
]
[[package]]
name = "charset-normalizer"
version = "3.4.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ea/12/a93df3366ed32db1d907d7593a94f1fe6293903e3e92967bebd6950ed12c/charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0", size = 199622, upload-time = "2025-05-02T08:32:56.363Z" },
{ url = "https://files.pythonhosted.org/packages/04/93/bf204e6f344c39d9937d3c13c8cd5bbfc266472e51fc8c07cb7f64fcd2de/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf", size = 143435, upload-time = "2025-05-02T08:32:58.551Z" },
{ url = "https://files.pythonhosted.org/packages/22/2a/ea8a2095b0bafa6c5b5a55ffdc2f924455233ee7b91c69b7edfcc9e02284/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e", size = 153653, upload-time = "2025-05-02T08:33:00.342Z" },
{ url = "https://files.pythonhosted.org/packages/b6/57/1b090ff183d13cef485dfbe272e2fe57622a76694061353c59da52c9a659/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1", size = 146231, upload-time = "2025-05-02T08:33:02.081Z" },
{ url = "https://files.pythonhosted.org/packages/e2/28/ffc026b26f441fc67bd21ab7f03b313ab3fe46714a14b516f931abe1a2d8/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c", size = 148243, upload-time = "2025-05-02T08:33:04.063Z" },
{ url = "https://files.pythonhosted.org/packages/c0/0f/9abe9bd191629c33e69e47c6ef45ef99773320e9ad8e9cb08b8ab4a8d4cb/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691", size = 150442, upload-time = "2025-05-02T08:33:06.418Z" },
{ url = "https://files.pythonhosted.org/packages/67/7c/a123bbcedca91d5916c056407f89a7f5e8fdfce12ba825d7d6b9954a1a3c/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0", size = 145147, upload-time = "2025-05-02T08:33:08.183Z" },
{ url = "https://files.pythonhosted.org/packages/ec/fe/1ac556fa4899d967b83e9893788e86b6af4d83e4726511eaaad035e36595/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b", size = 153057, upload-time = "2025-05-02T08:33:09.986Z" },
{ url = "https://files.pythonhosted.org/packages/2b/ff/acfc0b0a70b19e3e54febdd5301a98b72fa07635e56f24f60502e954c461/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff", size = 156454, upload-time = "2025-05-02T08:33:11.814Z" },
{ url = "https://files.pythonhosted.org/packages/92/08/95b458ce9c740d0645feb0e96cea1f5ec946ea9c580a94adfe0b617f3573/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b", size = 154174, upload-time = "2025-05-02T08:33:13.707Z" },
{ url = "https://files.pythonhosted.org/packages/78/be/8392efc43487ac051eee6c36d5fbd63032d78f7728cb37aebcc98191f1ff/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148", size = 149166, upload-time = "2025-05-02T08:33:15.458Z" },
{ url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
]
[[package]]
@@ -173,6 +193,20 @@ requires-dist = [
{ name = "mlx-lm", specifier = ">=0.25.3" },
]
[[package]]
name = "fastapi"
version = "0.116.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "starlette", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" },
]
[[package]]
name = "filelock"
version = "3.18.0"
@@ -184,11 +218,11 @@ wheels = [
[[package]]
name = "fsspec"
version = "2025.5.1"
version = "2025.7.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/00/f7/27f15d41f0ed38e8fcc488584b57e902b331da7f7c6dcda53721b15838fc/fsspec-2025.5.1.tar.gz", hash = "sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475", size = 303033, upload-time = "2025-05-24T12:03:23.792Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432, upload-time = "2025-07-15T16:05:21.19Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/61/78c7b3851add1481b048b5fdc29067397a1784e2910592bc81bb3f608635/fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462", size = 199052, upload-time = "2025-05-24T12:03:21.66Z" },
{ url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597, upload-time = "2025-07-15T16:05:19.529Z" },
]
[[package]]
@@ -244,7 +278,7 @@ wheels = [
[[package]]
name = "huggingface-hub"
version = "0.33.2"
version = "0.33.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -256,69 +290,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, upload-time = "2025-07-02T06:26:05.156Z" }
sdist = { url = "https://files.pythonhosted.org/packages/4b/9e/9366b7349fc125dd68b9d384a0fea84d67b7497753fe92c71b67e13f47c4/huggingface_hub-0.33.4.tar.gz", hash = "sha256:6af13478deae120e765bfd92adad0ae1aec1ad8c439b46f23058ad5956cbca0a", size = 426674, upload-time = "2025-07-11T12:32:48.694Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" },
]
[[package]]
name = "idna"
version = "3.10"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
]
[[package]]
name = "fastapi"
version = "0.116.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "starlette", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/20/38/e1da78736143fd885c36213a3ccc493c384ae8fea6a0f0bc272ef42ebea8/fastapi-0.116.0.tar.gz", hash = "sha256:80dc0794627af0390353a6d1171618276616310d37d24faba6648398e57d687a", size = 296518, upload-time = "2025-07-07T15:09:27.82Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2f/68/d80347fe2360445b5f58cf290e588a4729746e7501080947e6cdae114b1f/fastapi-0.116.0-py3-none-any.whl", hash = "sha256:fdcc9ed272eaef038952923bef2b735c02372402d1203ee1210af4eea7a78d2b", size = 95625, upload-time = "2025-07-07T15:09:26.348Z" },
]
[[package]]
name = "h11"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
]
[[package]]
name = "httpcore"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
{ url = "https://files.pythonhosted.org/packages/46/7b/98daa50a2db034cab6cd23a3de04fa2358cb691593d28e9130203eb7a805/huggingface_hub-0.33.4-py3-none-any.whl", hash = "sha256:09f9f4e7ca62547c70f8b82767eefadd2667f4e116acba2e3e62a5a81815a7bb", size = 515339, upload-time = "2025-07-11T12:32:46.346Z" },
]
[[package]]
@@ -339,6 +313,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "jinja2"
version = "3.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
]
[[package]]
name = "jiter"
version = "0.10.0"
@@ -446,7 +432,7 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.25.3"
version = "0.26.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -454,11 +440,11 @@ dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", extra = ["sentencepiece"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ec/bc/0c3f69a8ff78fc8152985be99b2f83dc7e902b9b96ff5260c6a4958c10f1/mlx_lm-0.25.3.tar.gz", hash = "sha256:40ea0a2849abd804a40a3e388627ae5327918a8656287022610150fd453a2242", size = 154221, upload-time = "2025-07-01T03:04:07.056Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8d/aa/a2f02e67736a2bf57acefb3a1a342005586f1be8d7b2fb37ca5f3d4f3049/mlx_lm-0.26.0.tar.gz", hash = "sha256:78980ad994baf976779cc1c34c0d55c1c6b63dffef4899d67fec240d0c443b52", size = 159064, upload-time = "2025-07-08T20:21:31.393Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/58/ce/3484a973943572461765977231e3b9b68876a8d7e16c3e6110b81c180a89/mlx_lm-0.25.3-py3-none-any.whl", hash = "sha256:56a84f1ae4a3581b13c84c4d8edaa6704b971b40090b725dfc3b719b522ccc2b", size = 203913, upload-time = "2025-07-01T03:04:05.928Z" },
{ url = "https://files.pythonhosted.org/packages/08/e7/d0e576397b61bf90a0bb27819443f723258acd8dd1207684fdef29243ce4/mlx_lm-0.26.0-py3-none-any.whl", hash = "sha256:b00294c26242cd50db4b6e3ec3a2baf1cfdf8ca49a5e6057dce14642fabe0d21", size = 217671, upload-time = "2025-07-08T20:21:29.448Z" },
]
[[package]]
@@ -496,7 +482,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.93.0"
version = "1.96.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -508,9 +494,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e4/d7/e91c6a9cf71726420cddf539852ee4c29176ebb716a702d9118d0409fd8e/openai-1.93.0.tar.gz", hash = "sha256:988f31ade95e1ff0585af11cc5a64510225e4f5cd392698c675d0a9265b8e337", size = 486573, upload-time = "2025-06-27T21:21:39.421Z" }
sdist = { url = "https://files.pythonhosted.org/packages/2f/b5/18fd5e1b6b6c7dca52d60307b3637f9e9e3206a8041a9c8028985dbc6260/openai-1.96.1.tar.gz", hash = "sha256:6d505b5cc550e036bfa3fe99d6cff565b11491d12378d4c353f92ef72b0a408a", size = 489065, upload-time = "2025-07-15T21:39:37.215Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/46/a10d9df4673df56f71201d129ba1cb19eaff3366d08c8664d61a7df52e65/openai-1.93.0-py3-none-any.whl", hash = "sha256:3d746fe5498f0dd72e0d9ab706f26c91c0f646bf7459e5629af8ba7c9dbdf090", size = 755038, upload-time = "2025-06-27T21:21:37.532Z" },
{ url = "https://files.pythonhosted.org/packages/4f/57/325bbdbdc27b47309be35cb4e0eb8980b0c1bc997194c797c3691d88ae41/openai-1.96.1-py3-none-any.whl", hash = "sha256:0afaab2019bae8e145e7a1baf6953167084f019dd15042c65edd117398c1eb1c", size = 757454, upload-time = "2025-07-15T21:39:34.517Z" },
]
[[package]]
@@ -617,14 +603,14 @@ wheels = [
[[package]]
name = "pytest-asyncio"
version = "1.0.0"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960, upload-time = "2025-05-26T04:54:40.484Z" }
sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" },
{ url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" },
]
[[package]]
@@ -693,24 +679,43 @@ wheels = [
[[package]]
name = "ruff"
version = "0.12.2"
version = "0.12.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/3d/d9a195676f25d00dbfcf3cf95fdd4c685c497fcfa7e862a44ac5e4e96480/ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e", size = 4432239, upload-time = "2025-07-03T16:40:19.566Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/74/b6/2098d0126d2d3318fd5bec3ad40d06c25d377d95749f7a0c5af17129b3b1/ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be", size = 10369761, upload-time = "2025-07-03T16:39:38.847Z" },
{ url = "https://files.pythonhosted.org/packages/b1/4b/5da0142033dbe155dc598cfb99262d8ee2449d76920ea92c4eeb9547c208/ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e", size = 11155659, upload-time = "2025-07-03T16:39:42.294Z" },
{ url = "https://files.pythonhosted.org/packages/3e/21/967b82550a503d7c5c5c127d11c935344b35e8c521f52915fc858fb3e473/ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc", size = 10537769, upload-time = "2025-07-03T16:39:44.75Z" },
{ url = "https://files.pythonhosted.org/packages/33/91/00cff7102e2ec71a4890fb7ba1803f2cdb122d82787c7d7cf8041fe8cbc1/ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922", size = 10717602, upload-time = "2025-07-03T16:39:47.652Z" },
{ url = "https://files.pythonhosted.org/packages/9b/eb/928814daec4e1ba9115858adcda44a637fb9010618721937491e4e2283b8/ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b", size = 10198772, upload-time = "2025-07-03T16:39:49.641Z" },
{ url = "https://files.pythonhosted.org/packages/50/fa/f15089bc20c40f4f72334f9145dde55ab2b680e51afb3b55422effbf2fb6/ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d", size = 11845173, upload-time = "2025-07-03T16:39:52.069Z" },
{ url = "https://files.pythonhosted.org/packages/43/9f/1f6f98f39f2b9302acc161a4a2187b1e3a97634fe918a8e731e591841cf4/ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1", size = 12553002, upload-time = "2025-07-03T16:39:54.551Z" },
{ url = "https://files.pythonhosted.org/packages/d8/70/08991ac46e38ddd231c8f4fd05ef189b1b94be8883e8c0c146a025c20a19/ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4", size = 12171330, upload-time = "2025-07-03T16:39:57.55Z" },
{ url = "https://files.pythonhosted.org/packages/88/a9/5a55266fec474acfd0a1c73285f19dd22461d95a538f29bba02edd07a5d9/ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9", size = 11774717, upload-time = "2025-07-03T16:39:59.78Z" },
{ url = "https://files.pythonhosted.org/packages/87/e5/0c270e458fc73c46c0d0f7cf970bb14786e5fdb88c87b5e423a4bd65232b/ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da", size = 11646659, upload-time = "2025-07-03T16:40:01.934Z" },
{ url = "https://files.pythonhosted.org/packages/b7/b6/45ab96070c9752af37f0be364d849ed70e9ccede07675b0ec4e3ef76b63b/ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce", size = 10604012, upload-time = "2025-07-03T16:40:04.363Z" },
{ url = "https://files.pythonhosted.org/packages/86/91/26a6e6a424eb147cc7627eebae095cfa0b4b337a7c1c413c447c9ebb72fd/ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d", size = 10176799, upload-time = "2025-07-03T16:40:06.514Z" },
{ url = "https://files.pythonhosted.org/packages/f5/0c/9f344583465a61c8918a7cda604226e77b2c548daf8ef7c2bfccf2b37200/ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04", size = 11241507, upload-time = "2025-07-03T16:40:08.708Z" },
{ url = "https://files.pythonhosted.org/packages/1c/b7/99c34ded8fb5f86c0280278fa89a0066c3760edc326e935ce0b1550d315d/ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342", size = 11717609, upload-time = "2025-07-03T16:40:10.836Z" },
{ url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" },
{ url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" },
{ url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" },
{ url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" },
{ url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" },
{ url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" },
{ url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" },
{ url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" },
{ url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" },
{ url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" },
{ url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" },
{ url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" },
{ url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" },
{ url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" },
]
[[package]]
name = "rustworkx"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/c4/6d6ef39e57610d54c5f106dc3dece9eebce8b9d52d561ae092e3aede1b66/rustworkx-0.16.0.tar.gz", hash = "sha256:9f0dcb83f38d5ca2c3a683eb9b6951c8aec3262fbfe5141946a7ee5ba37e0bb6", size = 349524, upload-time = "2025-01-24T01:22:34.686Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f8/70/36f5916aee41ffe4f604ad75742eb1bb1b849fb568e010555f9d159cd93e/rustworkx-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:476a6c67b0142acd941691943750cc6737a48372304489969c2b62d30aaf4c27", size = 2141999, upload-time = "2025-01-24T01:21:50.3Z" },
{ url = "https://files.pythonhosted.org/packages/94/47/7e7c37fb73efcc87be6414b235534605c4008a4cdbd92a61db23b878eecd/rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bef2ef42870f806af93979b457e240f6dfa4f867ca33965c620f3a804409ed3a", size = 1940309, upload-time = "2025-01-24T01:21:52.053Z" },
{ url = "https://files.pythonhosted.org/packages/c6/42/a6d6b3137be55ef1d887becdf6b64b0917c7d437bd483065a88500a55603/rustworkx-0.16.0-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0db3a73bf68b3e66c08322a2fc95d3aa663d037d9b4e49c3509da4898d3529cc", size = 2195350, upload-time = "2025-01-24T01:21:53.785Z" },
{ url = "https://files.pythonhosted.org/packages/59/d2/1bc99df831c132c4b7420a85ce9150e065f4c993798f31b6a4229f238398/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f12a13d7486234fa2a84746d5e41f436bf9df43548043e7a232f48804ff8c61", size = 1971689, upload-time = "2025-01-24T17:09:26.338Z" },
{ url = "https://files.pythonhosted.org/packages/b5/3b/1125e7eb834f4408bcec3cee79947efd504c715fb7ab1876f8cd4bbca497/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89efd5c3a4653ddacc55ca39f28b261d43deec7d678f8f8fc6b76b5087f1dfea", size = 3297342, upload-time = "2025-01-24T03:18:48.885Z" },
{ url = "https://files.pythonhosted.org/packages/4f/e2/e21187b255c6211d71db0d08a44fc16771038b2af41712d66c408d9bec16/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec0c12aac8c54910ace20ac6ada4b890cd39f95f69100514715f8ad7af9041e4", size = 2110107, upload-time = "2025-01-24T01:21:58.884Z" },
{ url = "https://files.pythonhosted.org/packages/3c/79/e3fcff21f31253ea85ef196bf2fcabad7802b11468f7d3a5d592cd0ac789/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d650e39fc1a1534335f7517358ebfc3478bb235428463cfcd7c5750d50377b33", size = 2007544, upload-time = "2025-01-26T04:16:53.807Z" },
{ url = "https://files.pythonhosted.org/packages/67/04/741ed09c2b0dc0f360f85270c1179ed433785372ac9ab6ab26d3dd3ae02d/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:293180b83509ee9bff4c3af7ccc1024f6528d61b65d0cb7320bd31924f10cb71", size = 2172787, upload-time = "2025-01-24T01:22:01.282Z" },
]
[[package]]
@@ -733,12 +738,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/40/ad/2b113098e69c985a3d8fbda4b902778eae4a35b7d5188859b4a63d30c161/safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04", size = 643147, upload-time = "2025-02-26T09:15:11.185Z" },
]
[[package]]
name = "sentencepiece"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/d2/b9c7ca067c26d8ff085d252c89b5f69609ca93fb85a00ede95f4857865d4/sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843", size = 2632106, upload-time = "2024-02-19T17:06:47.428Z" }
[[package]]
name = "sniffio"
version = "1.3.1"
@@ -748,6 +747,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "starlette"
version = "0.47.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072, upload-time = "2025-06-21T04:03:17.337Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" },
]
[[package]]
name = "tokenizers"
version = "0.21.2"
@@ -782,7 +793,7 @@ wheels = [
[[package]]
name = "transformers"
version = "4.53.1"
version = "4.53.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -796,64 +807,9 @@ dependencies = [
{ name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9f/2c/68a0024c311db41bb92d4ec17d22e90b7406a4d28aa18d87662f2bbebcd9/transformers-4.53.1.tar.gz", hash = "sha256:da5a9f66ad480bc2a7f75bc32eaf735fd20ac56af4325ca4ce994021ceb37710", size = 9192189, upload-time = "2025-07-04T08:28:40.571Z" }
sdist = { url = "https://files.pythonhosted.org/packages/4c/67/80f51466ec447028fd84469b208eb742533ce06cc8fad2e3181380199e5c/transformers-4.53.2.tar.gz", hash = "sha256:6c3ed95edfb1cba71c4245758f1b4878c93bf8cde77d076307dacb2cbbd72be2", size = 9201233, upload-time = "2025-07-11T12:39:08.742Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8d/10/8cef2288810a3210659eb3a20711e8387cc35a881a7762ae387806e2d651/transformers-4.53.1-py3-none-any.whl", hash = "sha256:c84f3c3e41c71fdf2c60c8a893e1cd31191b0cb463385f4c276302d2052d837b", size = 10825681, upload-time = "2025-07-04T08:28:37.318Z" },
]
[package.optional-dependencies]
sentencepiece = [
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "rustworkx"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/c4/6d6ef39e57610d54c5f106dc3dece9eebce8b9d52d561ae092e3aede1b66/rustworkx-0.16.0.tar.gz", hash = "sha256:9f0dcb83f38d5ca2c3a683eb9b6951c8aec3262fbfe5141946a7ee5ba37e0bb6", size = 349524, upload-time = "2025-01-24T01:22:34.686Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f8/70/36f5916aee41ffe4f604ad75742eb1bb1b849fb568e010555f9d159cd93e/rustworkx-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:476a6c67b0142acd941691943750cc6737a48372304489969c2b62d30aaf4c27", size = 2141999, upload-time = "2025-01-24T01:21:50.3Z" },
{ url = "https://files.pythonhosted.org/packages/94/47/7e7c37fb73efcc87be6414b235534605c4008a4cdbd92a61db23b878eecd/rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bef2ef42870f806af93979b457e240f6dfa4f867ca33965c620f3a804409ed3a", size = 1940309, upload-time = "2025-01-24T01:21:52.053Z" },
{ url = "https://files.pythonhosted.org/packages/c6/42/a6d6b3137be55ef1d887becdf6b64b0917c7d437bd483065a88500a55603/rustworkx-0.16.0-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0db3a73bf68b3e66c08322a2fc95d3aa663d037d9b4e49c3509da4898d3529cc", size = 2195350, upload-time = "2025-01-24T01:21:53.785Z" },
{ url = "https://files.pythonhosted.org/packages/59/d2/1bc99df831c132c4b7420a85ce9150e065f4c993798f31b6a4229f238398/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f12a13d7486234fa2a84746d5e41f436bf9df43548043e7a232f48804ff8c61", size = 1971689, upload-time = "2025-01-24T17:09:26.338Z" },
{ url = "https://files.pythonhosted.org/packages/b5/3b/1125e7eb834f4408bcec3cee79947efd504c715fb7ab1876f8cd4bbca497/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89efd5c3a4653ddacc55ca39f28b261d43deec7d678f8f8fc6b76b5087f1dfea", size = 3297342, upload-time = "2025-01-24T03:18:48.885Z" },
{ url = "https://files.pythonhosted.org/packages/4f/e2/e21187b255c6211d71db0d08a44fc16771038b2af41712d66c408d9bec16/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec0c12aac8c54910ace20ac6ada4b890cd39f95f69100514715f8ad7af9041e4", size = 2110107, upload-time = "2025-01-24T01:21:58.884Z" },
{ url = "https://files.pythonhosted.org/packages/3c/79/e3fcff21f31253ea85ef196bf2fcabad7802b11468f7d3a5d592cd0ac789/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d650e39fc1a1534335f7517358ebfc3478bb235428463cfcd7c5750d50377b33", size = 2007544, upload-time = "2025-01-26T04:16:53.807Z" },
{ url = "https://files.pythonhosted.org/packages/67/04/741ed09c2b0dc0f360f85270c1179ed433785372ac9ab6ab26d3dd3ae02d/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:293180b83509ee9bff4c3af7ccc1024f6528d61b65d0cb7320bd31924f10cb71", size = 2172787, upload-time = "2025-01-24T01:22:01.282Z" },
]
[[package]]
name = "sniffio"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "starlette"
version = "0.46.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846, upload-time = "2025-04-13T13:56:17.942Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037, upload-time = "2025-04-13T13:56:16.21Z" },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
{ url = "https://files.pythonhosted.org/packages/96/88/beb33a79a382fcd2aed0be5222bdc47f41e4bfe7aaa90ae1374f1d8ea2af/transformers-4.53.2-py3-none-any.whl", hash = "sha256:db8f4819bb34f000029c73c3c557e7d06fc1b8e612ec142eecdae3947a9c78bf", size = 10826609, upload-time = "2025-07-11T12:39:05.461Z" },
]
[[package]]

View File

@@ -14,13 +14,19 @@ from shared.types.worker.commands_runner import (
### Utils - MESSAGE TO RUNNER
async def supervisor_write_message(proc: asyncio.subprocess.Process, message: RunnerMessage) -> None:
assert proc.stdin is not None, "proc.stdin should not be None when created with stdin=PIPE"
encoded: bytes = message.model_dump_json().encode('utf-8') + b'\n'
async def supervisor_write_message(
proc: asyncio.subprocess.Process, message: RunnerMessage
) -> None:
assert proc.stdin is not None, (
"proc.stdin should not be None when created with stdin=PIPE"
)
encoded: bytes = message.model_dump_json().encode("utf-8") + b"\n"
proc.stdin.write(encoded)
await proc.stdin.drain()
async def runner_read_message() -> RunnerMessage:
loop = asyncio.get_running_loop()
@@ -34,17 +40,24 @@ async def runner_read_message() -> RunnerMessage:
except Exception as e:
raise ValueError(f"Error validating message: {line}") from e
### Utils - RESPONSE FROM RUNNER
def runner_write_response(obj: RunnerResponse) -> None:
encoded: bytes = obj.model_dump_json().encode('utf-8') + b'\n'
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
_ = sys.stdout.buffer.write(encoded)
_ = sys.stdout.buffer.flush()
async def supervisor_read_response(proc: asyncio.subprocess.Process) -> RunnerResponse | None:
assert proc.stdout is not None, "proc.stdout should not be None when created with stdout=PIPE"
async def supervisor_read_response(
proc: asyncio.subprocess.Process,
) -> RunnerResponse | None:
assert proc.stdout is not None, (
"proc.stdout should not be None when created with stdout=PIPE"
)
line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=10)
line: str = line_bytes.decode('utf-8').strip()
line: str = line_bytes.decode("utf-8").strip()
if not line:
raise EOFError("No more data to read")
@@ -57,6 +70,7 @@ async def supervisor_read_response(proc: asyncio.subprocess.Process) -> RunnerRe
### Utils - Runner Prints
def runner_print(text: str) -> None:
obj = PrintResponse(
type=RunnerResponseType.PrintResponse,
@@ -65,11 +79,12 @@ def runner_print(text: str) -> None:
runner_write_response(obj)
def runner_write_error(error: Exception) -> None:
error_response: ErrorResponse = ErrorResponse(
type=RunnerResponseType.ErrorResponse,
error_type=type(error).__name__,
error_message=str(error),
traceback=traceback.format_exc(),
type=RunnerResponseType.ErrorResponse,
error_type=type(error).__name__,
error_message=str(error),
traceback=traceback.format_exc(),
)
runner_write_response(error_response)
runner_write_response(error_response)

View File

@@ -11,7 +11,7 @@ import mlx.nn as nn
from mlx_lm.generate import stream_generate # type: ignore
from mlx_lm.tokenizer_utils import TokenizerWrapper
from shared.mlx.utils_mlx import apply_chat_template, initialize_mlx
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx
from shared.openai import FinishReason
from shared.types.tasks.common import (
TaskData,
@@ -58,13 +58,15 @@ async def _mlx_generate(
response = GenerationResponse(
text=generation_response.text,
token=generation_response.token,
finish_reason=cast(FinishReason | None, generation_response.finish_reason), # has to be considered as a FinishReason instead of a str.
finish_reason=cast(
FinishReason | None, generation_response.finish_reason
), # has to be considered as a FinishReason instead of a str.
)
_ = loop.call_soon_threadsafe(queue.put_nowait, response)
except Exception as e:
_ = loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
_ = loop.call_soon_threadsafe(queue.put_nowait, sentinel)
_ = loop.call_soon_threadsafe(queue.put_nowait, sentinel)
# Currently we support chat-completion tasks only.
task_data = task.task_data
@@ -91,15 +93,16 @@ async def _mlx_generate(
if isinstance(item, Exception):
raise item
assert isinstance(item, GenerationResponse) # constrain datatype
assert isinstance(item, GenerationResponse) # constrain datatype
yield item
assert future.done()
async def main():
try:
runner_print('hello from the runner')
runner_print("hello from the runner")
# Get setup info from worker
init_message: RunnerMessage = await runner_read_message()
@@ -107,10 +110,12 @@ async def main():
model_shard_meta: ShardMeta = setup_message.model_shard_meta
hosts: list[Host] = setup_message.hosts
mlx_executor: ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
mlx_executor: ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor(
max_workers=1
)
loop: AbstractEventLoop = asyncio.get_running_loop()
runner_print(f'got here; {model_shard_meta.model_path}')
runner_print(f"got here; {model_shard_meta.model_path}")
model, tokenizer, sampler = await loop.run_in_executor(
mlx_executor,
@@ -137,7 +142,7 @@ async def main():
task=task_data,
):
runner_write_response(generation_response)
runner_write_response(FinishedResponse())
case ExitMessage():
break
@@ -147,5 +152,6 @@ async def main():
except Exception as e:
runner_write_error(e)
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -2,7 +2,7 @@ import asyncio
import contextlib
import sys
from collections.abc import AsyncGenerator
from typing import Callable
from typing import Any, Callable
from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
from shared.types.tasks.common import Task, TaskStatusType, TaskType
@@ -17,8 +17,7 @@ from shared.types.worker.commands_runner import (
SetupMessage,
)
from shared.types.worker.mlx import Host
from shared.types.worker.runners import RunnerError
from shared.types.worker.shards import ShardMeta
from shared.types.worker.shards import ShardMetadata
from worker.runner.communication import (
supervisor_read_response,
supervisor_write_message,
@@ -31,25 +30,27 @@ class RunnerSupervisor:
RunnerSupervisor manages the lifecycle of a runner subprocess for model inference.
Use the class method `create` to properly initialize an instance.
"""
def __init__(
self,
model_shard_meta: ShardMeta,
model_shard_meta: ShardMetadata[Any],
hosts: list[Host],
runner_process: asyncio.subprocess.Process,
):
"""Private constructor. Use RunnerSupervisor.create() instead."""
self.model_shard_meta: ShardMeta = model_shard_meta
self.model_shard_meta: ShardMetadata[Any] = model_shard_meta
self.hosts: list[Host] = hosts
self.runner_process: asyncio.subprocess.Process = runner_process
self.running: bool = True
self.running_task: asyncio.Task[None] = asyncio.create_task(self._watch_runner())
self.running_task: asyncio.Task[None] = asyncio.create_task(
self._watch_runner()
)
@classmethod
async def create(
cls,
model_shard_meta: ShardMeta,
model_shard_meta: ShardMetadata[Any],
hosts: list[Host],
) -> "RunnerSupervisor":
"""
@@ -57,12 +58,14 @@ class RunnerSupervisor:
The .create() classmethod pattern is used to ensure the constructor is asynchronous.
"""
cmd: list[str] = get_runner_command()
runner_process: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr,
runner_process: asyncio.subprocess.Process = (
await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr,
)
)
await supervisor_write_message(
@@ -91,7 +94,9 @@ class RunnerSupervisor:
if self.runner_process.stdout is not None:
while True:
try:
line = await asyncio.wait_for(self.runner_process.stdout.readline(), timeout=0.01)
line = await asyncio.wait_for(
self.runner_process.stdout.readline(), timeout=0.01
)
if not line:
break
print(f"Remaining stdout: {line.decode('utf-8').strip()}")
@@ -100,7 +105,9 @@ class RunnerSupervisor:
try:
# Give the process a moment to exit gracefully
await supervisor_write_message(proc=self.runner_process, message=ExitMessage())
await supervisor_write_message(
proc=self.runner_process, message=ExitMessage()
)
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
except asyncio.TimeoutError:
print("Runner process did not terminate, killing...")
@@ -114,7 +121,9 @@ class RunnerSupervisor:
def __del__(self) -> None:
if not self.running:
print('Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process.')
print(
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process."
)
with contextlib.suppress(ProcessLookupError):
self.runner_process.kill()
@@ -150,12 +159,16 @@ class RunnerSupervisor:
)
while True:
line: RunnerResponse | None = await supervisor_read_response(self.runner_process)
line: RunnerResponse | None = await supervisor_read_response(
self.runner_process
)
if line is None:
continue
else:
match line:
case GenerationResponse(text=text, token=token, finish_reason=finish_reason):
case GenerationResponse(
text=text, token=token, finish_reason=finish_reason
):
yield TokenChunk(
task_id=task.task_id,
idx=token,
@@ -169,7 +182,11 @@ class RunnerSupervisor:
case FinishedResponse():
break
case PrintResponse(text=text):
print(f'runner printed: {text}')
case ErrorResponse(error_type=error_type, error_message=error_message, traceback=traceback):
print(f"runner printed: {text}")
case ErrorResponse(
error_type=error_type,
error_message=error_message,
traceback=traceback,
):
await self.astop()
raise RunnerError(error_type, error_message, traceback or "")
raise Exception(error_type, error_message, traceback or "")

View File

@@ -3,6 +3,4 @@ import sys
def get_runner_command() -> list[str]:
python = sys.executable
return [
python, '-m', 'worker.runner.runner'
]
return [python, "-m", "worker.runner.runner"]

View File

@@ -3,48 +3,69 @@ from pathlib import Path
from typing import Callable, cast
import pytest
from openai.types.chat import ChatCompletionUserMessageParam
from openai.types.chat.completion_create_params import (
CompletionCreateParamsNonStreaming,
CompletionCreateParamsStreaming,
)
from pydantic import TypeAdapter
from shared.types.models.common import ModelId
from shared.types.tasks.common import (
ChatCompletionMessage,
ChatCompletionParams,
ChatCompletionStreamingTask,
PendingTaskStatus,
Task,
TaskArtifact,
TaskId,
TaskState,
TaskStatusIncompleteType,
TaskStatusOtherType,
TaskStatusType,
TaskType,
)
from shared.types.worker.common import InstanceId
from shared.types.worker.mlx import Host
from shared.types.worker.shards import PipelineShardMeta
from shared.types.worker.shards import PipelineShardMetadata
CompletionCreateParamsStreamingAdapter = TypeAdapter(CompletionCreateParamsStreaming)
CompletionCreateParamsNonStreamingAdapter = TypeAdapter(
CompletionCreateParamsNonStreaming
)
# Concrete TaskArtifact implementation for pending streaming tasks
class PendingStreamingTaskArtifact(TaskArtifact[TaskType.ChatCompletionStreaming, TaskStatusIncompleteType.Pending]):
class PendingStreamingTaskArtifact(
TaskArtifact[TaskType.ChatCompletionStreaming, TaskStatusOtherType.Pending]
):
pass
@pytest.fixture
def pipeline_shard_meta():
def _pipeline_shard_meta(num_nodes: int = 1, device_rank: int = 0) -> PipelineShardMeta:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
total_layers = 16
layers_per_node = total_layers // num_nodes
start_layer = device_rank * layers_per_node
end_layer = start_layer + layers_per_node if device_rank < num_nodes - 1 else total_layers
return PipelineShardMeta(
end_layer = (
start_layer + layers_per_node
if device_rank < num_nodes - 1
else total_layers
)
return PipelineShardMetadata(
device_rank=device_rank,
model_id=ModelId(uuid=uuid.uuid4()),
model_path=Path("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit/").expanduser(),
model_path=Path(
"~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit/"
).expanduser(),
start_layer=start_layer,
end_layer=end_layer,
world_size=num_nodes,
)
return _pipeline_shard_meta
@pytest.fixture
def hosts():
def _hosts(count: int, offset: int = 0) -> list[Host]:
@@ -55,51 +76,57 @@ def hosts():
)
for i in range(count)
]
return _hosts
@pytest.fixture
def hosts_one(hosts: Callable[[int], list[Host]]):
return hosts(1)
@pytest.fixture
def hosts_two(hosts: Callable[[int], list[Host]]):
return hosts(2)
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "Hello, how are you?"
@pytest.fixture
def chat_completion_params(user_message: str):
"""Creates ChatCompletionParams with the given message"""
return ChatCompletionParams(
return CompletionCreateParamsStreaming(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user",
content=user_message
)
],
stream=True
messages=[ChatCompletionUserMessageParam(role="user", content=user_message)],
stream=True,
)
@pytest.fixture
def chat_completion_streaming_task_data(chat_completion_params: ChatCompletionParams):
def chat_completion_streaming_task_data(
chat_completion_params: CompletionCreateParamsStreaming,
):
"""Creates ChatCompletionStreamingTask from params"""
return ChatCompletionStreamingTask(
task_data=chat_completion_params
)
return ChatCompletionStreamingTask(task_data=chat_completion_params)
@pytest.fixture
def streaming_task(chat_completion_streaming_task_data: ChatCompletionStreamingTask) -> Task[TaskType, TaskStatusType]:
def streaming_task(
chat_completion_streaming_task_data: CompletionCreateParamsStreaming,
) -> Task[TaskType, TaskStatusType]:
"""Creates the final Task object"""
task = Task(
task_id=TaskId(),
task_type=TaskType.ChatCompletionStreaming,
task_data=chat_completion_streaming_task_data,
task_params=ChatCompletionStreamingTask(
task_data=chat_completion_streaming_task_data
),
task_state=TaskState(
task_status=PendingTaskStatus(),
task_status=TaskStatusOtherType.Pending,
task_artifact=PendingStreamingTaskArtifact(),
),
on_instance=InstanceId(),

View File

@@ -2,31 +2,41 @@ from typing import Callable, Literal, TypeVar
from pydantic import BaseModel, TypeAdapter
from shared.types.tasks.common import Task, TaskStatusIncompleteType, TaskType
from shared.types.tasks.common import Task, TaskStatusOtherType, TaskType
from shared.types.worker.commands_runner import (
ChatTaskMessage,
RunnerMessageTypeAdapter,
SetupMessage,
)
from shared.types.worker.mlx import Host
from shared.types.worker.shards import PipelineShardMeta
from shared.types.worker.shards import PipelineShardMetadata
T = TypeVar("T", bound=BaseModel)
T = TypeVar('T', bound=BaseModel)
def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]):
encoded: bytes = obj.model_dump_json().encode('utf-8') + b'\n'
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
decoded: T = typeadapter.validate_json(encoded)
assert decoded == obj, f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}"
assert decoded == obj, (
f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}"
)
def test_supervisor_setup_message_serdes(pipeline_shard_meta: Callable[..., PipelineShardMeta], hosts: Callable[..., list[Host]]):
def test_supervisor_setup_message_serdes(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
setup_message = SetupMessage(
model_shard_meta=pipeline_shard_meta(1, 0),
hosts=hosts(1),
)
assert_equal_serdes(setup_message, RunnerMessageTypeAdapter)
def test_supervisor_task_message_serdes(streaming_task: Task[TaskType, Literal[TaskStatusIncompleteType.Pending]]):
def test_supervisor_task_message_serdes(
streaming_task: Task[TaskType, Literal[TaskStatusOtherType.Pending]],
):
task_message = ChatTaskMessage(
task=streaming_task.task_data,
)

View File

@@ -34,7 +34,7 @@ async def test_supervisor_single_node_response(
try:
full_response = ""
stop_reason: FinishReason | None = None
async for chunk in supervisor.stream_response(task=streaming_task):
if isinstance(chunk, TokenChunk):
full_response += chunk.chunk_data.text
@@ -42,12 +42,15 @@ async def test_supervisor_single_node_response(
stop_reason = chunk.chunk_data.finish_reason
# Case-insensitive check for Paris in the response
assert "paris" in full_response.lower(), f"Expected 'Paris' in response, but got: {full_response}"
assert stop_reason == 'stop'
assert "paris" in full_response.lower(), (
f"Expected 'Paris' in response, but got: {full_response}"
)
assert stop_reason == "stop"
finally:
await supervisor.astop()
@pytest.mark.asyncio
async def test_supervisor_two_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMeta],
@@ -70,33 +73,38 @@ async def test_supervisor_two_node_response(
try:
full_response_0 = ""
full_response_1 = ""
async def collect_response_0():
nonlocal full_response_0
async for chunk in supervisor_0.stream_response(task=streaming_task):
if isinstance(chunk, TokenChunk):
full_response_0 += chunk.chunk_data.text
async def collect_response_1():
nonlocal full_response_1
async for chunk in supervisor_1.stream_response(task=streaming_task):
if isinstance(chunk, TokenChunk):
full_response_1 += chunk.chunk_data.text
# Run both stream responses simultaneously
_ = await asyncio.gather(collect_response_0(), collect_response_1())
print(f"full_response_0: {full_response_0}")
print(f"full_response_1: {full_response_1}")
# Case-insensitive check for Paris in both responses
assert "paris" in full_response_0.lower(), f"Expected 'Paris' in response, but got: {full_response_0}"
assert "paris" in full_response_1.lower(), f"Expected 'Paris' in response, but got: {full_response_1}"
assert "paris" in full_response_0.lower(), (
f"Expected 'Paris' in response, but got: {full_response_0}"
)
assert "paris" in full_response_1.lower(), (
f"Expected 'Paris' in response, but got: {full_response_1}"
)
finally:
await supervisor_0.astop()
await supervisor_1.astop()
@pytest.mark.asyncio
async def test_supervisor_early_stopping(
pipeline_shard_meta: Callable[..., PipelineShardMeta],
@@ -115,8 +123,10 @@ async def test_supervisor_early_stopping(
try:
streaming_task.task_data.task_data.max_tokens = max_tokens
streaming_task.task_data.task_data.messages[0].content = "Please count from 1 to 100"
streaming_task.task_data.task_data.messages[
0
].content = "Please count from 1 to 100"
full_response = ""
count = 0
stop_reason: FinishReason | None = None
@@ -127,14 +137,14 @@ async def test_supervisor_early_stopping(
count += 1
if chunk.chunk_data.finish_reason:
stop_reason = chunk.chunk_data.finish_reason
print(f"full_response: {full_response}")
assert count == max_tokens + 1
assert '7' in full_response.lower()
assert '99' not in full_response.lower()
assert "7" in full_response.lower()
assert "99" not in full_response.lower()
assert stop_reason == 'length'
assert stop_reason == "length"
finally:
await supervisor.astop()