diff --git a/engines/mlx/auto_parallel.py b/engines/mlx/auto_parallel.py new file mode 100644 index 00000000..3b8531bb --- /dev/null +++ b/engines/mlx/auto_parallel.py @@ -0,0 +1,114 @@ +from typing import Protocol, cast, override + +import mlx.core as mx +import mlx.nn as nn + +from shared.types.worker.shards import PipelineShardMetadata + + +class IdentityLayer(nn.Module): + @override + def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: + return x + + +class _LayerCallable(Protocol): + """Structural type that any compatible layer must satisfy. + + We require a single positional input of type ``mx.array`` and an + ``mx.array`` output, while permitting arbitrary *args / **kwargs so this + protocol matches the vast majority of `mlx.nn.Module` subclasses. + """ + + def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ... + + +class PipelineFirstLayer(nn.Module): + def __init__(self, original_layer: _LayerCallable, r: int, s: int): + super().__init__() + self.original_layer: _LayerCallable = original_layer + self.r: int = r + self.s: int = s + + @override + def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: + if self.r != 0: + x = mx.distributed.recv_like(x, (self.r - 1)) + return self.original_layer(x, *args, **kwargs) + + +class PipelineLastLayer(nn.Module): + def __init__(self, original_layer: _LayerCallable, r: int, s: int): + super().__init__() + self.original_layer: _LayerCallable = original_layer + self.r: int = r + self.s: int = s + + @override + def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: + output: mx.array = self.original_layer(x, *args, **kwargs) + if self.r != self.s - 1: + output = mx.distributed.send(output, (self.r + 1) % self.s) + output = mx.distributed.all_gather(output)[-output.shape[0] :] # pyright: ignore[reportUnknownMemberType] + return output + + +def inner_model(model: nn.Module) -> nn.Module: + inner = getattr(model, "model", None) + if isinstance(inner, nn.Module): + return inner + + inner = getattr(model, "transformer", None) + if isinstance(inner, nn.Module): + return inner + + raise ValueError("Model must either have a 'model' or 'transformer' attribute") + + +# def auto_parallel(model: nn.Module, rank: int, size: int, start_layer: int, end_layer: int) -> nn.Module: +def auto_parallel( + model: nn.Module, model_shard_meta: PipelineShardMetadata +) -> nn.Module: + """ + Automatically parallelize a model across multiple devices. + + Args: + model: The model to parallelize (must have a 'layers' or 'h' property) + model_shard_meta: The metadata for the model shard + + Returns: + The parallelized model + """ + + inner_model_instance: nn.Module = inner_model(model) + + # Handle both model.layers and model.h cases + layers: list[_LayerCallable] + if hasattr(inner_model_instance, "layers"): + layers = cast(list[_LayerCallable], inner_model_instance.layers) + else: + layers = cast(list[_LayerCallable], inner_model_instance.h) + + layers[: model_shard_meta.start_layer] = [ + IdentityLayer() for _ in range(model_shard_meta.start_layer) + ] + layers[model_shard_meta.end_layer :] = [ + IdentityLayer() for _ in range(len(layers) - model_shard_meta.end_layer) + ] + layers[model_shard_meta.start_layer] = PipelineFirstLayer( + layers[model_shard_meta.start_layer], + model_shard_meta.device_rank, + model_shard_meta.world_size, + ) + layers[model_shard_meta.end_layer - 1] = PipelineLastLayer( + layers[model_shard_meta.end_layer - 1], + model_shard_meta.device_rank, + model_shard_meta.world_size, + ) + + # At this point `layers` *must* be a concrete list. + assert isinstance(layers, list), ( + "Expected a list of layers after auto-parallel initialisation" + ) + + return model diff --git a/shared/mlx/utils_mlx.py b/engines/mlx/utils_mlx.py similarity index 85% rename from shared/mlx/utils_mlx.py rename to engines/mlx/utils_mlx.py index 397593d3..5de40e63 100644 --- a/shared/mlx/utils_mlx.py +++ b/engines/mlx/utils_mlx.py @@ -21,15 +21,20 @@ from shared.types.worker.shards import ShardMeta from worker.runner.communication import runner_print -def mx_barrier(): - mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu)))) # type: ignore +def mx_barrier(): + mx.eval( + mx.distributed.all_sum( + mx.array(1.0), stream=mx.default_stream(mx.Device(mx.cpu)) + ) + ) + class HostList(RootModel[list[str]]): - @classmethod def from_hosts(cls, hosts: list[Host]) -> "HostList": return cls(root=[str(host) for host in hosts]) + def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: """ Initialize the MLX distributed (runs in thread pool) @@ -37,10 +42,10 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: runner_print(f"Starting initialization for rank {rank}") # Setup distributed environment - hostfile = f"./hosts_{rank}.json" # TODO: this needs to be unique? + hostfile = f"./hosts_{rank}.json" # TODO: this needs to be unique? hosts_json = HostList.from_hosts(hosts).model_dump_json() - runner_print(f'rank {rank} hostfile: {hostfile} hosts: {hosts_json}') + runner_print(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}") with open(hostfile, "w") as f: _ = f.write(hosts_json) @@ -55,6 +60,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: return group + def initialize_mlx( model_shard_meta: ShardMeta, hosts: list[Host], @@ -71,8 +77,9 @@ def initialize_mlx( return model, tokenizer, sampler + def shard_and_load(model_shard_meta: ShardMeta) -> tuple[nn.Module, TokenizerWrapper]: - runner_print(f'loading model from {model_shard_meta.model_path}') + runner_print(f"loading model from {model_shard_meta.model_path}") model, config = load_model(model_shard_meta.model_path, lazy=True, strict=False) @@ -102,9 +109,11 @@ async def apply_chat_template( for message in messages_dicts: filtered_message = {k: v for k, v in message.items() if v is not None} # Verify we have exactly the expected keys - assert set(filtered_message.keys()) == {'role', 'content'}, f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}" + assert set(filtered_message.keys()) == {"role", "content"}, ( + f"Expected only 'role' and 'content' keys, got: {filtered_message.keys()}" + ) formatted_messages.append(filtered_message) - + messages_dicts = formatted_messages prompt: str = await loop.run_in_executor( @@ -113,7 +122,7 @@ async def apply_chat_template( messages_dicts, tokenize=False, add_generation_prompt=True, - ) + ), ) - return prompt \ No newline at end of file + return prompt diff --git a/master/logging.py b/master/logging.py index 36ee3a1b..40d6812d 100644 --- a/master/logging.py +++ b/master/logging.py @@ -26,10 +26,22 @@ class MasterInvalidCommandReceivedLogEntry( command_name: str -class MasterCommandRunnerNotRunningLogEntry: ... +class MasterCommandRunnerNotRunningLogEntry( + LogEntry[Literal["master_command_runner_not_running"]] +): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["master_command_runner_not_running"] = ( + "master_command_runner_not_running" + ) + message: str = "Command Runner Not Running" -class MasterStateManagerStoppedLogEntry: ... +class MasterStateManagerStoppedLogEntry( + LogEntry[Literal["master_state_manager_stopped"]] +): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["master_state_manager_stopped"] = "master_state_manager_stopped" + message: str = "State Manager Stopped" class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]): diff --git a/master/main.py b/master/main.py index 58b1d20e..0a395b69 100644 --- a/master/main.py +++ b/master/main.py @@ -1,23 +1,16 @@ -from asyncio import CancelledError, Lock, Task, create_task -from asyncio import Queue as AsyncQueue from contextlib import asynccontextmanager from logging import Logger, LogRecord from queue import Queue as PQueue -from typing import Callable, Sequence +from typing import Literal -from fastapi import FastAPI, Response -from fastapi.responses import StreamingResponse +from fastapi import FastAPI -from master.commands import ExternalCommand from master.env import MasterEnvironmentSchema from master.logging import ( - MasterCommandRunnerNotRunningLogEntry, - MasterStateManagerStoppedLogEntry, MasterUninitializedLogEntry, ) -from master.router import QueueMapping -from master.state_manager.sync import SyncStateManagerMapping from shared.constants import EXO_MASTER_STATE +from shared.event_loops.main import NodeEventLoopProtocol from shared.logger import ( FilterLogByType, LogEntryType, @@ -27,11 +20,7 @@ from shared.logger import ( log, ) from shared.types.events.common import ( - Apply, - EventCategory, - EventFromEventLog, - EventPublisher, - State, + EventCategoryEnum, ) from shared.types.models.common import ModelId from shared.types.models.model import ModelInfo @@ -63,93 +52,20 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState: return data -# Safety on Apply. -def safely_apply[T: EventCategory]( - state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]] -) -> State[T]: - sorted_events = sorted(events, key=lambda event: event.idx_in_log) - state = state.model_copy() - for event in sorted_events: - if event.idx_in_log <= state.last_event_applied_idx: - continue - state.last_event_applied_idx = event.idx_in_log - state = apply_fn(state, event) - return state +# What The Master Cares About +MasterEventCategories = ( + Literal[EventCategoryEnum.MutatesControlPlaneState] + | Literal[EventCategoryEnum.MutatesTaskState] + | Literal[EventCategoryEnum.MutatesTaskSagaState] + | Literal[EventCategoryEnum.MutatesRunnerStatus] + | Literal[EventCategoryEnum.MutatesInstanceState] + | Literal[EventCategoryEnum.MutatesNodePerformanceState] + | Literal[EventCategoryEnum.MutatesDataPlaneState] +) -class MasterEventLoop: - """Thread-safe manager for MasterState with independent event loop.""" - - def __init__( - self, - initial_state: MasterState, - push_events_to_queue: Callable[[QueueMapping], None], - event_publisher: EventPublisher[EventCategory], - state_managers: SyncStateManagerMapping, - logger: Logger, - ): - self._state = initial_state - self._state_lock = Lock() - self._event_queues: QueueMapping - self._command_runner: ... - self._command_run_task: Task[None] | None = None - self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue() - self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue() - self._state_managers: SyncStateManagerMapping - self._state_global_lock: Lock = Lock() - self._push_events_to_queue: Callable[[QueueMapping], None] - self._event_fetch_task: Task[None] | None = None - self._logger = logger - - @property - def _is_command_runner_running(self) -> bool: - return self._command_run_task is not None and not self._command_run_task.done() - - @property - def _is_event_fetcher_running(self) -> bool: - return self._event_fetch_task is not None and not self._event_fetch_task.done() - - async def send_command( - self, command: ExternalCommand - ) -> Response | StreamingResponse: - """Send a command to the background event loop.""" - if self._is_command_runner_running: - await self._command_queue.put(command) - return await self._response_queue.get() - else: - log(self._logger, MasterCommandRunnerNotRunningLogEntry()) - raise RuntimeError("Command Runner Is Not Running") - - async def start(self) -> None: - """Start the background event loop.""" - - async def fetch_and_apply_events() -> None: - while True: - async with self._state_global_lock: - for state in self._state_managers.values(): - self._push_events_to_queue(self._event_queues) - safely_apply( - state, apply_fn, self._event_queues[state.event_category] - ) - - self._event_fetch_task = create_task(fetch_and_apply_events()) - self._command_run_task = create_task(self._command_runner()) - - async def stop(self) -> None: - """Stop the background event loop and persist state.""" - if not self._is_command_runner_running or not self._is_event_fetcher_running: - raise RuntimeError("Command Runner Is Not Running") - - assert self._command_run_task is not None and self._event_fetch_task is not None - - for service in [self._event_fetch_task, self._command_run_task]: - service.cancel() - try: - await service - except CancelledError: - pass - - log(self._logger, MasterStateManagerStoppedLogEntry()) +# Takes Care Of All States And Events Related To The Master +class MasterEventLoopProtocol(NodeEventLoopProtocol[MasterEventCategories]): ... @asynccontextmanager @@ -182,7 +98,7 @@ async def lifespan(app: FastAPI): cluster_queue, ) - # TODO: Add handlers + # TODO: Add Handlers For Pushing Logs To Remote Services telemetry_listener = create_queue_listener(telemetry_queue, []) metrics_listener = create_queue_listener(metrics_queue, []) cluster_listener = create_queue_listener(cluster_queue, []) @@ -191,15 +107,13 @@ async def lifespan(app: FastAPI): metrics_listener.start() cluster_listener.start() - initial_state = get_master_state(logger) - app.state.master_event_loop = MasterEventLoop( - initial_state, None, None, None, logger - ) - await app.state.master_event_loop.start() + # initial_state = get_master_state(logger) + # app.state.master_event_loop = MasterEventLoop() + # await app.state.master_event_loop.start() yield - await app.state.master_event_loop.stop() + # await app.state.master_event_loop.stop() app = FastAPI(lifespan=lifespan) diff --git a/master/router.py b/master/router.py deleted file mode 100644 index 196896a8..00000000 --- a/master/router.py +++ /dev/null @@ -1,90 +0,0 @@ -from asyncio import Queue, gather -from logging import Logger -from typing import Literal, Protocol, TypedDict - -from master.sanity_checking import check_keys_in_map_match_enum_values -from shared.types.events.common import ( - EventCategories, - EventCategory, - EventCategoryEnum, - EventFromEventLog, - narrow_event_from_event_log_type, -) - - -class QueueMapping(TypedDict): - MutatesTaskState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]] - ] - MutatesTaskSagaState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]] - ] - MutatesControlPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]] - ] - MutatesDataPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]] - ] - MutatesRunnerStatus: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]] - ] - MutatesInstanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]] - ] - MutatesNodePerformanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]] - ] - - -check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum) - - -class EventRouterProtocol(Protocol): - queue_map: QueueMapping - start_idx: int - - def sync_queues(self) -> None: ... - - -class EventRouter(EventRouterProtocol): - """Routes events to appropriate services based on event categories.""" - - queue_map: QueueMapping - start_idx: int - logger: Logger - - async def _get_queue_by_category[T: EventCategory]( - self, category: T - ) -> Queue[EventFromEventLog[T]]: - """Get the queue for a given category.""" - category_str: str = category.value - queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str] - return queue - - async def _process_events[T: EventCategory](self, category: T) -> None: - """Process events for a given domain.""" - queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category) - events_to_process: list[EventFromEventLog[T]] = [] - while not queue.empty(): - events_to_process.append(await queue.get()) - for event_to_process in events_to_process: - await self.queue_map[category.value].put(event_to_process) - return None - - async def _submit_events[T: EventCategory | EventCategories]( - self, events: list[EventFromEventLog[T]] - ) -> None: - """Route multiple events to their appropriate services.""" - for event in events: - if isinstance(event.event.event_category, EventCategory): - q1: Queue[EventFromEventLog[T]] = self.queue_map[ - event.event.event_category.value - ] - await q1.put(event) - elif isinstance(event.event.event_category, EventCategories): - for category in event.event.event_category: - narrow_event = narrow_event_from_event_log_type(event, category) - q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value] - await q2.put(narrow_event) - - await gather(*[self._process_events(domain) for domain in EventCategoryEnum]) diff --git a/master/state_manager/async.py b/master/state_manager/async.py index dcddfa25..1fe77663 100644 --- a/master/state_manager/async.py +++ b/master/state_manager/async.py @@ -10,8 +10,8 @@ from master.logging import ( StateUpdateLoopStartedLogEntry, StateUpdateLoopStoppedLogEntry, ) -from master.router import check_keys_in_map_match_enum_values -from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from master.sanity_checking import check_keys_in_map_match_enum_values +from shared.constants import get_error_reporting_message from shared.logger import log from shared.types.events.common import ( Apply, @@ -74,7 +74,7 @@ class AsyncStateManager[EventCategoryT: EventCategory](Protocol): raise RuntimeError("State Update Loop Not Running") assert self._task is not None, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" "BUG: is_running is True but _task is None, this should never happen!" ) self._task.cancel() diff --git a/shared/constants.py b/shared/constants.py index de681821..a69b161a 100644 --- a/shared/constants.py +++ b/shared/constants.py @@ -21,7 +21,8 @@ def get_caller_module_name() -> str: return mod.__name__ -EXO_ERROR_REPORTING_MESSAGE = lambda: ( - f"THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n" - f"The module that raised the error was: {get_caller_module_name()}" -) +def get_error_reporting_message() -> str: + return ( + f"THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n" + f"The module that raised the error was: {get_caller_module_name()}" + ) diff --git a/master/commands.py b/shared/event_loops/commands.py similarity index 82% rename from master/commands.py rename to shared/event_loops/commands.py index da83b1ff..ac79b3b8 100644 --- a/master/commands.py +++ b/shared/event_loops/commands.py @@ -2,8 +2,15 @@ from typing import Annotated, Literal from pydantic import BaseModel, Field, TypeAdapter +from shared.types.common import NewUUID + + +class ExternalCommandId(NewUUID): + pass + class BaseExternalCommand[T: str](BaseModel): + command_id: ExternalCommandId command_type: T diff --git a/shared/event_loops/main.py b/shared/event_loops/main.py new file mode 100644 index 00000000..c997028d --- /dev/null +++ b/shared/event_loops/main.py @@ -0,0 +1,121 @@ +from asyncio import Lock, Task +from asyncio import Queue as AsyncQueue +from collections.abc import MutableMapping +from logging import Logger +from typing import Any, Hashable, Mapping, Protocol, Sequence + +from fastapi.responses import Response, StreamingResponse + +from shared.event_loops.commands import ExternalCommand +from shared.types.events.common import Apply, EventCategory, EventFromEventLog, State + + +class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]): + __slots__ = ("_store",) + + required_keys: frozenset[K] = frozenset() + + def __init__(self, data: Mapping[K, V]): + missing = self.required_keys - data.keys() + extra = data.keys() - self.required_keys + if missing or extra: + raise ValueError(f"missing={missing!r}, extra={extra!r}") + self._store: dict[K, V] = dict(data) + + def __getitem__(self, k: K) -> V: + return self._store[k] + + def __setitem__(self, k: K, v: V) -> None: + self._store[k] = v + + def __delitem__(self, k: K) -> None: + del self._store[k] + + def __iter__(self): + return iter(self._store) + + def __len__(self) -> int: + return len(self._store) + + +# Safety on Apply. +def safely_apply[T: EventCategory]( + state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]] +) -> State[T]: + sorted_events = sorted(events, key=lambda event: event.idx_in_log) + state = state.model_copy() + for event in sorted_events: + if event.idx_in_log <= state.last_event_applied_idx: + continue + state.last_event_applied_idx = event.idx_in_log + state = apply_fn(state, event) + return state + + +class NodeCommandLoopProtocol(Protocol): + _command_runner: Task[Any] | None = None + _command_queue: AsyncQueue[ExternalCommand] + _response_queue: AsyncQueue[Response | StreamingResponse] + _logger: Logger + + @property + def is_command_runner_running(self) -> bool: + return self._command_runner is not None and not self._command_runner.done() + + async def start_command_runner(self) -> None: ... + async def stop_command_runner(self) -> None: ... + async def push_command(self, command: ExternalCommand) -> None: ... + async def pop_response(self) -> Response | StreamingResponse: ... + async def _handle_command(self, command: ExternalCommand) -> None: ... + + +class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol): + _event_fetcher: Task[Any] | None = None + _event_queues: ExhaustiveMapping[ + EventCategoryT, AsyncQueue[EventFromEventLog[EventCategory]] + ] + _logger: Logger + + @property + async def is_event_fetcher_running(self) -> bool: + return self._event_fetcher is not None and not self._event_fetcher.done() + + async def start_event_fetcher(self) -> None: ... + async def stop_event_fetcher(self) -> None: ... + + +class NodeStateStorageProtocol[EventCategoryT: EventCategory](Protocol): + _state_managers: ExhaustiveMapping[EventCategoryT, State[EventCategoryT]] + _state_lock: Lock + _logger: Logger + + async def _read_state( + self, event_category: EventCategoryT + ) -> State[EventCategoryT]: ... + + +class NodeStateManagerProtocol[EventCategoryT: EventCategory]( + NodeEventGetterProtocol[EventCategoryT], NodeStateStorageProtocol[EventCategoryT] +): + _state_manager: Task[Any] | None = None + _logger: Logger + + @property + async def is_state_manager_running(self) -> bool: + is_task_running = ( + self._state_manager is not None and not self._state_manager.done() + ) + return ( + is_task_running + and await self.is_event_fetcher_running + and await self.is_state_manager_running + ) + + async def start_state_manager(self) -> None: ... + async def stop_state_manager(self) -> None: ... + async def _apply_queued_events(self) -> None: ... + + +class NodeEventLoopProtocol[EventCategoryT: EventCategory]( + NodeCommandLoopProtocol, NodeStateManagerProtocol[EventCategoryT] +): ... diff --git a/shared/event_loops/router.py b/shared/event_loops/router.py new file mode 100644 index 00000000..3dc27efe --- /dev/null +++ b/shared/event_loops/router.py @@ -0,0 +1,78 @@ +from asyncio.queues import Queue +from typing import Sequence, cast, get_args + +from shared.event_loops.main import ExhaustiveMapping +from shared.types.events.common import ( + EventCategories, + EventCategory, + EventCategoryEnum, + EventFromEventLog, + narrow_event_from_event_log_type, +) + +""" +from asyncio import gather +from logging import Logger +from typing import Literal, Protocol, Sequence, TypedDict + +from master.sanity_checking import check_keys_in_map_match_enum_values +from shared.types.events.common import EventCategoryEnum +""" + +""" +class EventQueues(TypedDict): + MutatesTaskState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]] + ] + MutatesTaskSagaState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]] + ] + MutatesControlPlaneState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]] + ] + MutatesDataPlaneState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]] + ] + MutatesRunnerStatus: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]] + ] + MutatesInstanceState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]] + ] + MutatesNodePerformanceState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]] + ] + + +check_keys_in_map_match_enum_values(EventQueues, EventCategoryEnum) +""" + + +async def route_events[UnionOfRelevantEvents: EventCategory]( + queue_map: ExhaustiveMapping[ + UnionOfRelevantEvents, Queue[EventFromEventLog[EventCategory]] + ], + events: Sequence[EventFromEventLog[EventCategory | EventCategories]], +) -> None: + """Route an event to the appropriate queue.""" + tuple_of_categories: tuple[EventCategoryEnum, ...] = get_args(UnionOfRelevantEvents) + print(tuple_of_categories) + for event in events: + if isinstance(event.event.event_category, EventCategoryEnum): + category: EventCategory = event.event.event_category + if category not in tuple_of_categories: + continue + narrowed_event = narrow_event_from_event_log_type(event, category) + q1: Queue[EventFromEventLog[EventCategory]] = queue_map[ + cast(UnionOfRelevantEvents, category) + ] # TODO: make casting unnecessary + await q1.put(narrowed_event) + else: + for category in event.event.event_category: + if category not in tuple_of_categories: + continue + narrow_event = narrow_event_from_event_log_type(event, category) + q2 = queue_map[ + cast(UnionOfRelevantEvents, category) + ] # TODO: make casting unnecessary + await q2.put(narrow_event) diff --git a/shared/logger.py b/shared/logger.py index 75fb4f29..efe6f66b 100644 --- a/shared/logger.py +++ b/shared/logger.py @@ -4,7 +4,7 @@ from collections.abc import Sequence, Set from queue import Queue from typing import Annotated -from pydantic import Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter from rich.logging import RichHandler from master.logging import MasterLogEntries @@ -28,12 +28,6 @@ class FilterLogByType(logging.Filter): return True -class LogEntryType(str, Enum): - telemetry = "telemetry" - metrics = "metrics" - cluster = "cluster" - - class LogEntry(BaseModel): event_type: Set[LogEntryType] diff --git a/shared/mlx/auto_parallel.py b/shared/mlx/auto_parallel.py deleted file mode 100644 index 987933bf..00000000 --- a/shared/mlx/auto_parallel.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Protocol, cast, override - -import mlx.core as mx -import mlx.nn as nn - -from shared.types.worker.shards import PipelineShardMeta - - -class IdentityLayer(nn.Module): - @override - def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: - return x - -class _LayerCallable(Protocol): - """Structural type that any compatible layer must satisfy. - - We require a single positional input of type ``mx.array`` and an - ``mx.array`` output, while permitting arbitrary *args / **kwargs so this - protocol matches the vast majority of `mlx.nn.Module` subclasses. - """ - - def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ... - -class PipelineFirstLayer(nn.Module): - def __init__(self, original_layer: _LayerCallable, r: int, s: int): - super().__init__() - self.original_layer: _LayerCallable = original_layer - self.r: int = r - self.s: int = s - - @override - def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: - if self.r != 0: - x = mx.distributed.recv_like(x, (self.r - 1)) - return self.original_layer(x, *args, **kwargs) - -class PipelineLastLayer(nn.Module): - def __init__(self, original_layer: _LayerCallable, r: int, s: int): - super().__init__() - self.original_layer: _LayerCallable = original_layer - self.r: int = r - self.s: int = s - - @override - def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: - output: mx.array = self.original_layer(x, *args, **kwargs) - if self.r != self.s - 1: - output = mx.distributed.send(output, (self.r + 1) % self.s) - output = mx.distributed.all_gather(output)[-output.shape[0]:] # pyright: ignore[reportUnknownMemberType] - return output - -def inner_model(model: nn.Module) -> nn.Module: - inner = getattr(model, 'model', None) - if isinstance(inner, nn.Module): - return inner - - inner = getattr(model, 'transformer', None) - if isinstance(inner, nn.Module): - return inner - - raise ValueError("Model must either have a 'model' or 'transformer' attribute") - -# def auto_parallel(model: nn.Module, rank: int, size: int, start_layer: int, end_layer: int) -> nn.Module: -def auto_parallel(model: nn.Module, model_shard_meta: PipelineShardMeta) -> nn.Module: - """ - Automatically parallelize a model across multiple devices. - - Args: - model: The model to parallelize (must have a 'layers' or 'h' property) - model_shard_meta: The metadata for the model shard - - Returns: - The parallelized model - """ - - inner_model_instance: nn.Module = inner_model(model) - - # Handle both model.layers and model.h cases - layers: list[_LayerCallable] - if hasattr(inner_model_instance, 'layers'): - layers = cast(list[_LayerCallable], inner_model_instance.layers) - else: - layers = cast(list[_LayerCallable], inner_model_instance.h) - - layers[:model_shard_meta.start_layer] = [IdentityLayer() for _ in range(model_shard_meta.start_layer)] - layers[model_shard_meta.end_layer:] = [IdentityLayer() for _ in range(len(layers) - model_shard_meta.end_layer)] - layers[model_shard_meta.start_layer] = PipelineFirstLayer(layers[model_shard_meta.start_layer], model_shard_meta.device_rank, model_shard_meta.world_size) - layers[model_shard_meta.end_layer - 1] = PipelineLastLayer(layers[model_shard_meta.end_layer - 1], model_shard_meta.device_rank, model_shard_meta.world_size) - - # At this point `layers` *must* be a concrete list. - assert isinstance(layers, list), "Expected a list of layers after auto-parallel initialisation" - - return model \ No newline at end of file diff --git a/shared/types/events/common.py b/shared/types/events/common.py index 364d256f..a451efda 100644 --- a/shared/types/events/common.py +++ b/shared/types/events/common.py @@ -1,5 +1,6 @@ from enum import Enum, StrEnum from typing import ( + Any, Callable, FrozenSet, Literal, @@ -205,9 +206,7 @@ class StateAndEvent[EventCategoryT: EventCategory](NamedTuple): type EffectHandler[EventCategoryT: EventCategory] = Callable[ [StateAndEvent[EventCategoryT], State[EventCategoryT]], None ] -type EventPublisher[EventCategoryT: EventCategory] = Callable[ - [Event[EventCategoryT]], None -] +type EventPublisher = Callable[[Event[Any]], None] # A component that can publish events @@ -224,7 +223,7 @@ class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol): # A component that can get the effect handler for a saga def get_saga_effect_handler[EventCategoryT: EventCategory]( - saga: Saga[EventCategoryT], event_publisher: EventPublisher[EventCategoryT] + saga: Saga[EventCategoryT], event_publisher: EventPublisher ) -> EffectHandler[EventCategoryT]: def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None: trigger_state, trigger_event = state_and_event @@ -236,7 +235,7 @@ def get_saga_effect_handler[EventCategoryT: EventCategory]( def get_effects_from_sagas[EventCategoryT: EventCategory]( sagas: Sequence[Saga[EventCategoryT]], - event_publisher: EventPublisher[EventCategoryT], + event_publisher: EventPublisher, ) -> Sequence[EffectHandler[EventCategoryT]]: return [get_saga_effect_handler(saga, event_publisher) for saga in sagas] diff --git a/shared/types/events/events.py b/shared/types/events/events.py index 0a00dd6c..aabd081b 100644 --- a/shared/types/events/events.py +++ b/shared/types/events/events.py @@ -41,10 +41,10 @@ from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType MLXEvent = Event[ frozenset( - { + ( EventCategoryEnum.MutatesTaskState, EventCategoryEnum.MutatesControlPlaneState, - } + ) ) ] TaskEvent = Event[EventCategoryEnum.MutatesTaskState] diff --git a/shared/types/events/registry.py b/shared/types/events/registry.py index 5fa1f4f7..299b42ee 100644 --- a/shared/types/events/registry.py +++ b/shared/types/events/registry.py @@ -3,7 +3,7 @@ from typing import Annotated, Any, Mapping, Type, get_args from pydantic import Field, TypeAdapter -from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from shared.constants import get_error_reporting_message from shared.types.events.common import ( ControlPlaneEventTypes, DataPlaneEventTypes, @@ -50,7 +50,6 @@ class EventTypeNames(StrEnum): check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames) """ - EventRegistry: Mapping[EventTypes, Type[Any]] = { TaskEventTypes.TaskCreated: TaskCreated, TaskEventTypes.TaskStateUpdated: TaskStateUpdated, @@ -78,7 +77,7 @@ def check_registry_has_all_event_types() -> None: missing_event_types = set(event_types) - set(EventRegistry.keys()) assert not missing_event_types, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"There's an event missing from the registry: {missing_event_types}" ) @@ -91,14 +90,14 @@ def check_union_of_all_events_is_consistent_with_registry( missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union assert not missing_from_union, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"Event classes in registry are missing from all_events union: {missing_from_union}" ) extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry assert not extra_in_union, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"Event classes in all_events union are missing from registry: {extra_in_union}" ) diff --git a/shared/types/events/sanity_checking.py b/shared/types/events/sanity_checking.py index a6413b52..ca489f23 100644 --- a/shared/types/events/sanity_checking.py +++ b/shared/types/events/sanity_checking.py @@ -2,7 +2,7 @@ from enum import Enum, StrEnum from types import UnionType from typing import Any, LiteralString, Sequence, Set, Type, get_args -from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from shared.constants import get_error_reporting_message def check_event_type_union_is_consistent_with_registry( @@ -20,7 +20,7 @@ def check_event_type_union_is_consistent_with_registry( for tag_of_event_type in event_types_inferred_from_registry: event_type = type(tag_of_event_type) assert event_type in event_types_inferred_from_union, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"There's a mismatch between the registry of event types and the union of possible event types." f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}." ) @@ -36,7 +36,7 @@ def check_event_categories_are_defined_for_all_event_types( ] tag_of_event_categories: list[str] = list(event_categories.__members__.values()) assert tag_of_event_categories == expected_category_tags, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"The values of the enum EventCategories are not named after the event type enums." f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}" f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}" @@ -61,7 +61,7 @@ def assert_literal_union_covers_enum[TEnum: StrEnum]( literal_values: Set[Any] = _flatten(literal_union) assert enum_values == literal_values, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"{get_error_reporting_message()}" f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n" f"These are the missing values: {enum_values - literal_values}\n" f"These are the extra values: {literal_values - enum_values}\n" diff --git a/shared/types/states/worker.py b/shared/types/states/worker.py index a57dcd06..dfddc265 100644 --- a/shared/types/states/worker.py +++ b/shared/types/states/worker.py @@ -2,14 +2,14 @@ from collections.abc import Mapping from shared.types.common import NodeId from shared.types.events.common import ( - EventCategory, + EventCategoryEnum, State, ) from shared.types.states.shared import SharedState from shared.types.worker.common import NodeStatus -class NodeStatusState(State[EventCategory.MutatesControlPlaneState]): +class NodeStatusState(State[EventCategoryEnum.MutatesControlPlaneState]): node_status: Mapping[NodeId, NodeStatus] diff --git a/shared/types/tasks/common.py b/shared/types/tasks/common.py index b1aa8a6b..2b422d6e 100644 --- a/shared/types/tasks/common.py +++ b/shared/types/tasks/common.py @@ -29,7 +29,7 @@ class ChatCompletionNonStreamingTask(TaskParams[TaskType.ChatCompletionNonStream task_type: Literal[TaskType.ChatCompletionNonStreaming] = ( TaskType.ChatCompletionNonStreaming ) - task_data: openai.completion_create_params.CompletionCreateParams + task_data: openai.completion_create_params.CompletionCreateParamsNonStreaming @final @@ -37,7 +37,7 @@ class ChatCompletionStreamingTask(TaskParams[TaskType.ChatCompletionStreaming]): task_type: Literal[TaskType.ChatCompletionStreaming] = ( TaskType.ChatCompletionStreaming ) - task_data: openai.completion_create_params.CompletionCreateParams + task_data: openai.completion_create_params.CompletionCreateParamsStreaming @final @@ -83,7 +83,7 @@ class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel) class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel): task_type: TaskTypeT task_params: TaskParams[TaskTypeT] - task_stats: TaskState[TaskStatusTypeT, TaskTypeT] + task_state: TaskState[TaskStatusTypeT, TaskTypeT] on_instance: InstanceId diff --git a/uv.lock b/uv.lock index 866aa987..015412d4 100644 --- a/uv.lock +++ b/uv.lock @@ -44,11 +44,31 @@ wheels = [ [[package]] name = "certifi" -version = "2025.6.15" +version = "2025.7.14" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/73/f7/f14b46d4bcd21092d7d3ccef689615220d8a08fb25e564b65d20738e672e/certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b", size = 158753, upload-time = "2025-06-15T02:45:51.329Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/76/52c535bcebe74590f296d6c77c86dabf761c41980e1347a2422e4aa2ae41/certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995", size = 163981, upload-time = "2025-07-14T03:29:28.449Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650, upload-time = "2025-06-15T02:45:49.977Z" }, + { url = "https://files.pythonhosted.org/packages/4f/52/34c6cf5bb9285074dc3531c437b3919e825d976fde097a7a73f79e726d03/certifi-2025.7.14-py3-none-any.whl", hash = "sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2", size = 162722, upload-time = "2025-07-14T03:29:26.863Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/12/a93df3366ed32db1d907d7593a94f1fe6293903e3e92967bebd6950ed12c/charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0", size = 199622, upload-time = "2025-05-02T08:32:56.363Z" }, + { url = "https://files.pythonhosted.org/packages/04/93/bf204e6f344c39d9937d3c13c8cd5bbfc266472e51fc8c07cb7f64fcd2de/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf", size = 143435, upload-time = "2025-05-02T08:32:58.551Z" }, + { url = "https://files.pythonhosted.org/packages/22/2a/ea8a2095b0bafa6c5b5a55ffdc2f924455233ee7b91c69b7edfcc9e02284/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e", size = 153653, upload-time = "2025-05-02T08:33:00.342Z" }, + { url = "https://files.pythonhosted.org/packages/b6/57/1b090ff183d13cef485dfbe272e2fe57622a76694061353c59da52c9a659/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1", size = 146231, upload-time = "2025-05-02T08:33:02.081Z" }, + { url = "https://files.pythonhosted.org/packages/e2/28/ffc026b26f441fc67bd21ab7f03b313ab3fe46714a14b516f931abe1a2d8/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c", size = 148243, upload-time = "2025-05-02T08:33:04.063Z" }, + { url = "https://files.pythonhosted.org/packages/c0/0f/9abe9bd191629c33e69e47c6ef45ef99773320e9ad8e9cb08b8ab4a8d4cb/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691", size = 150442, upload-time = "2025-05-02T08:33:06.418Z" }, + { url = "https://files.pythonhosted.org/packages/67/7c/a123bbcedca91d5916c056407f89a7f5e8fdfce12ba825d7d6b9954a1a3c/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0", size = 145147, upload-time = "2025-05-02T08:33:08.183Z" }, + { url = "https://files.pythonhosted.org/packages/ec/fe/1ac556fa4899d967b83e9893788e86b6af4d83e4726511eaaad035e36595/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b", size = 153057, upload-time = "2025-05-02T08:33:09.986Z" }, + { url = "https://files.pythonhosted.org/packages/2b/ff/acfc0b0a70b19e3e54febdd5301a98b72fa07635e56f24f60502e954c461/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff", size = 156454, upload-time = "2025-05-02T08:33:11.814Z" }, + { url = "https://files.pythonhosted.org/packages/92/08/95b458ce9c740d0645feb0e96cea1f5ec946ea9c580a94adfe0b617f3573/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b", size = 154174, upload-time = "2025-05-02T08:33:13.707Z" }, + { url = "https://files.pythonhosted.org/packages/78/be/8392efc43487ac051eee6c36d5fbd63032d78f7728cb37aebcc98191f1ff/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148", size = 149166, upload-time = "2025-05-02T08:33:15.458Z" }, + { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, ] [[package]] @@ -173,6 +193,20 @@ requires-dist = [ { name = "mlx-lm", specifier = ">=0.25.3" }, ] +[[package]] +name = "fastapi" +version = "0.116.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "starlette", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -184,11 +218,11 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.5.1" +version = "2025.7.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/00/f7/27f15d41f0ed38e8fcc488584b57e902b331da7f7c6dcda53721b15838fc/fsspec-2025.5.1.tar.gz", hash = "sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475", size = 303033, upload-time = "2025-05-24T12:03:23.792Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432, upload-time = "2025-07-15T16:05:21.19Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/61/78c7b3851add1481b048b5fdc29067397a1784e2910592bc81bb3f608635/fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462", size = 199052, upload-time = "2025-05-24T12:03:21.66Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597, upload-time = "2025-07-15T16:05:19.529Z" }, ] [[package]] @@ -244,7 +278,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.33.2" +version = "0.33.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -256,69 +290,9 @@ dependencies = [ { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, upload-time = "2025-07-02T06:26:05.156Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4b/9e/9366b7349fc125dd68b9d384a0fea84d67b7497753fe92c71b67e13f47c4/huggingface_hub-0.33.4.tar.gz", hash = "sha256:6af13478deae120e765bfd92adad0ae1aec1ad8c439b46f23058ad5956cbca0a", size = 426674, upload-time = "2025-07-11T12:32:48.694Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "fastapi" -version = "0.116.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "starlette", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/20/38/e1da78736143fd885c36213a3ccc493c384ae8fea6a0f0bc272ef42ebea8/fastapi-0.116.0.tar.gz", hash = "sha256:80dc0794627af0390353a6d1171618276616310d37d24faba6648398e57d687a", size = 296518, upload-time = "2025-07-07T15:09:27.82Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/68/d80347fe2360445b5f58cf290e588a4729746e7501080947e6cdae114b1f/fastapi-0.116.0-py3-none-any.whl", hash = "sha256:fdcc9ed272eaef038952923bef2b735c02372402d1203ee1210af4eea7a78d2b", size = 95625, upload-time = "2025-07-07T15:09:26.348Z" }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, + { url = "https://files.pythonhosted.org/packages/46/7b/98daa50a2db034cab6cd23a3de04fa2358cb691593d28e9130203eb7a805/huggingface_hub-0.33.4-py3-none-any.whl", hash = "sha256:09f9f4e7ca62547c70f8b82767eefadd2667f4e116acba2e3e62a5a81815a7bb", size = 515339, upload-time = "2025-07-11T12:32:46.346Z" }, ] [[package]] @@ -339,6 +313,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + [[package]] name = "jiter" version = "0.10.0" @@ -446,7 +432,7 @@ wheels = [ [[package]] name = "mlx-lm" -version = "0.25.3" +version = "0.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -454,11 +440,11 @@ dependencies = [ { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "transformers", extra = ["sentencepiece"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/bc/0c3f69a8ff78fc8152985be99b2f83dc7e902b9b96ff5260c6a4958c10f1/mlx_lm-0.25.3.tar.gz", hash = "sha256:40ea0a2849abd804a40a3e388627ae5327918a8656287022610150fd453a2242", size = 154221, upload-time = "2025-07-01T03:04:07.056Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/aa/a2f02e67736a2bf57acefb3a1a342005586f1be8d7b2fb37ca5f3d4f3049/mlx_lm-0.26.0.tar.gz", hash = "sha256:78980ad994baf976779cc1c34c0d55c1c6b63dffef4899d67fec240d0c443b52", size = 159064, upload-time = "2025-07-08T20:21:31.393Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/ce/3484a973943572461765977231e3b9b68876a8d7e16c3e6110b81c180a89/mlx_lm-0.25.3-py3-none-any.whl", hash = "sha256:56a84f1ae4a3581b13c84c4d8edaa6704b971b40090b725dfc3b719b522ccc2b", size = 203913, upload-time = "2025-07-01T03:04:05.928Z" }, + { url = "https://files.pythonhosted.org/packages/08/e7/d0e576397b61bf90a0bb27819443f723258acd8dd1207684fdef29243ce4/mlx_lm-0.26.0-py3-none-any.whl", hash = "sha256:b00294c26242cd50db4b6e3ec3a2baf1cfdf8ca49a5e6057dce14642fabe0d21", size = 217671, upload-time = "2025-07-08T20:21:29.448Z" }, ] [[package]] @@ -496,7 +482,7 @@ wheels = [ [[package]] name = "openai" -version = "1.93.0" +version = "1.96.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -508,9 +494,9 @@ dependencies = [ { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/d7/e91c6a9cf71726420cddf539852ee4c29176ebb716a702d9118d0409fd8e/openai-1.93.0.tar.gz", hash = "sha256:988f31ade95e1ff0585af11cc5a64510225e4f5cd392698c675d0a9265b8e337", size = 486573, upload-time = "2025-06-27T21:21:39.421Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/b5/18fd5e1b6b6c7dca52d60307b3637f9e9e3206a8041a9c8028985dbc6260/openai-1.96.1.tar.gz", hash = "sha256:6d505b5cc550e036bfa3fe99d6cff565b11491d12378d4c353f92ef72b0a408a", size = 489065, upload-time = "2025-07-15T21:39:37.215Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/46/a10d9df4673df56f71201d129ba1cb19eaff3366d08c8664d61a7df52e65/openai-1.93.0-py3-none-any.whl", hash = "sha256:3d746fe5498f0dd72e0d9ab706f26c91c0f646bf7459e5629af8ba7c9dbdf090", size = 755038, upload-time = "2025-06-27T21:21:37.532Z" }, + { url = "https://files.pythonhosted.org/packages/4f/57/325bbdbdc27b47309be35cb4e0eb8980b0c1bc997194c797c3691d88ae41/openai-1.96.1-py3-none-any.whl", hash = "sha256:0afaab2019bae8e145e7a1baf6953167084f019dd15042c65edd117398c1eb1c", size = 757454, upload-time = "2025-07-15T21:39:34.517Z" }, ] [[package]] @@ -617,14 +603,14 @@ wheels = [ [[package]] name = "pytest-asyncio" -version = "1.0.0" +version = "1.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960, upload-time = "2025-05-26T04:54:40.484Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" }, + { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, ] [[package]] @@ -693,24 +679,43 @@ wheels = [ [[package]] name = "ruff" -version = "0.12.2" +version = "0.12.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/3d/d9a195676f25d00dbfcf3cf95fdd4c685c497fcfa7e862a44ac5e4e96480/ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e", size = 4432239, upload-time = "2025-07-03T16:40:19.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/b6/2098d0126d2d3318fd5bec3ad40d06c25d377d95749f7a0c5af17129b3b1/ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be", size = 10369761, upload-time = "2025-07-03T16:39:38.847Z" }, - { url = "https://files.pythonhosted.org/packages/b1/4b/5da0142033dbe155dc598cfb99262d8ee2449d76920ea92c4eeb9547c208/ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e", size = 11155659, upload-time = "2025-07-03T16:39:42.294Z" }, - { url = "https://files.pythonhosted.org/packages/3e/21/967b82550a503d7c5c5c127d11c935344b35e8c521f52915fc858fb3e473/ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc", size = 10537769, upload-time = "2025-07-03T16:39:44.75Z" }, - { url = "https://files.pythonhosted.org/packages/33/91/00cff7102e2ec71a4890fb7ba1803f2cdb122d82787c7d7cf8041fe8cbc1/ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922", size = 10717602, upload-time = "2025-07-03T16:39:47.652Z" }, - { url = "https://files.pythonhosted.org/packages/9b/eb/928814daec4e1ba9115858adcda44a637fb9010618721937491e4e2283b8/ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b", size = 10198772, upload-time = "2025-07-03T16:39:49.641Z" }, - { url = "https://files.pythonhosted.org/packages/50/fa/f15089bc20c40f4f72334f9145dde55ab2b680e51afb3b55422effbf2fb6/ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d", size = 11845173, upload-time = "2025-07-03T16:39:52.069Z" }, - { url = "https://files.pythonhosted.org/packages/43/9f/1f6f98f39f2b9302acc161a4a2187b1e3a97634fe918a8e731e591841cf4/ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1", size = 12553002, upload-time = "2025-07-03T16:39:54.551Z" }, - { url = "https://files.pythonhosted.org/packages/d8/70/08991ac46e38ddd231c8f4fd05ef189b1b94be8883e8c0c146a025c20a19/ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4", size = 12171330, upload-time = "2025-07-03T16:39:57.55Z" }, - { url = "https://files.pythonhosted.org/packages/88/a9/5a55266fec474acfd0a1c73285f19dd22461d95a538f29bba02edd07a5d9/ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9", size = 11774717, upload-time = "2025-07-03T16:39:59.78Z" }, - { url = "https://files.pythonhosted.org/packages/87/e5/0c270e458fc73c46c0d0f7cf970bb14786e5fdb88c87b5e423a4bd65232b/ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da", size = 11646659, upload-time = "2025-07-03T16:40:01.934Z" }, - { url = "https://files.pythonhosted.org/packages/b7/b6/45ab96070c9752af37f0be364d849ed70e9ccede07675b0ec4e3ef76b63b/ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce", size = 10604012, upload-time = "2025-07-03T16:40:04.363Z" }, - { url = "https://files.pythonhosted.org/packages/86/91/26a6e6a424eb147cc7627eebae095cfa0b4b337a7c1c413c447c9ebb72fd/ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d", size = 10176799, upload-time = "2025-07-03T16:40:06.514Z" }, - { url = "https://files.pythonhosted.org/packages/f5/0c/9f344583465a61c8918a7cda604226e77b2c548daf8ef7c2bfccf2b37200/ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04", size = 11241507, upload-time = "2025-07-03T16:40:08.708Z" }, - { url = "https://files.pythonhosted.org/packages/1c/b7/99c34ded8fb5f86c0280278fa89a0066c3760edc326e935ce0b1550d315d/ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342", size = 11717609, upload-time = "2025-07-03T16:40:10.836Z" }, + { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" }, + { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" }, + { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" }, + { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" }, + { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" }, + { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" }, + { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" }, + { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" }, + { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" }, + { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" }, + { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" }, + { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" }, +] + +[[package]] +name = "rustworkx" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/c4/6d6ef39e57610d54c5f106dc3dece9eebce8b9d52d561ae092e3aede1b66/rustworkx-0.16.0.tar.gz", hash = "sha256:9f0dcb83f38d5ca2c3a683eb9b6951c8aec3262fbfe5141946a7ee5ba37e0bb6", size = 349524, upload-time = "2025-01-24T01:22:34.686Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/70/36f5916aee41ffe4f604ad75742eb1bb1b849fb568e010555f9d159cd93e/rustworkx-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:476a6c67b0142acd941691943750cc6737a48372304489969c2b62d30aaf4c27", size = 2141999, upload-time = "2025-01-24T01:21:50.3Z" }, + { url = "https://files.pythonhosted.org/packages/94/47/7e7c37fb73efcc87be6414b235534605c4008a4cdbd92a61db23b878eecd/rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bef2ef42870f806af93979b457e240f6dfa4f867ca33965c620f3a804409ed3a", size = 1940309, upload-time = "2025-01-24T01:21:52.053Z" }, + { url = "https://files.pythonhosted.org/packages/c6/42/a6d6b3137be55ef1d887becdf6b64b0917c7d437bd483065a88500a55603/rustworkx-0.16.0-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0db3a73bf68b3e66c08322a2fc95d3aa663d037d9b4e49c3509da4898d3529cc", size = 2195350, upload-time = "2025-01-24T01:21:53.785Z" }, + { url = "https://files.pythonhosted.org/packages/59/d2/1bc99df831c132c4b7420a85ce9150e065f4c993798f31b6a4229f238398/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f12a13d7486234fa2a84746d5e41f436bf9df43548043e7a232f48804ff8c61", size = 1971689, upload-time = "2025-01-24T17:09:26.338Z" }, + { url = "https://files.pythonhosted.org/packages/b5/3b/1125e7eb834f4408bcec3cee79947efd504c715fb7ab1876f8cd4bbca497/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89efd5c3a4653ddacc55ca39f28b261d43deec7d678f8f8fc6b76b5087f1dfea", size = 3297342, upload-time = "2025-01-24T03:18:48.885Z" }, + { url = "https://files.pythonhosted.org/packages/4f/e2/e21187b255c6211d71db0d08a44fc16771038b2af41712d66c408d9bec16/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec0c12aac8c54910ace20ac6ada4b890cd39f95f69100514715f8ad7af9041e4", size = 2110107, upload-time = "2025-01-24T01:21:58.884Z" }, + { url = "https://files.pythonhosted.org/packages/3c/79/e3fcff21f31253ea85ef196bf2fcabad7802b11468f7d3a5d592cd0ac789/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d650e39fc1a1534335f7517358ebfc3478bb235428463cfcd7c5750d50377b33", size = 2007544, upload-time = "2025-01-26T04:16:53.807Z" }, + { url = "https://files.pythonhosted.org/packages/67/04/741ed09c2b0dc0f360f85270c1179ed433785372ac9ab6ab26d3dd3ae02d/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:293180b83509ee9bff4c3af7ccc1024f6528d61b65d0cb7320bd31924f10cb71", size = 2172787, upload-time = "2025-01-24T01:22:01.282Z" }, ] [[package]] @@ -733,12 +738,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/ad/2b113098e69c985a3d8fbda4b902778eae4a35b7d5188859b4a63d30c161/safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04", size = 643147, upload-time = "2025-02-26T09:15:11.185Z" }, ] -[[package]] -name = "sentencepiece" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c9/d2/b9c7ca067c26d8ff085d252c89b5f69609ca93fb85a00ede95f4857865d4/sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843", size = 2632106, upload-time = "2024-02-19T17:06:47.428Z" } - [[package]] name = "sniffio" version = "1.3.1" @@ -748,6 +747,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "starlette" +version = "0.47.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072, upload-time = "2025-06-21T04:03:17.337Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" }, +] + [[package]] name = "tokenizers" version = "0.21.2" @@ -782,7 +793,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.53.1" +version = "4.53.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -796,64 +807,9 @@ dependencies = [ { name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/2c/68a0024c311db41bb92d4ec17d22e90b7406a4d28aa18d87662f2bbebcd9/transformers-4.53.1.tar.gz", hash = "sha256:da5a9f66ad480bc2a7f75bc32eaf735fd20ac56af4325ca4ce994021ceb37710", size = 9192189, upload-time = "2025-07-04T08:28:40.571Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/67/80f51466ec447028fd84469b208eb742533ce06cc8fad2e3181380199e5c/transformers-4.53.2.tar.gz", hash = "sha256:6c3ed95edfb1cba71c4245758f1b4878c93bf8cde77d076307dacb2cbbd72be2", size = 9201233, upload-time = "2025-07-11T12:39:08.742Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/10/8cef2288810a3210659eb3a20711e8387cc35a881a7762ae387806e2d651/transformers-4.53.1-py3-none-any.whl", hash = "sha256:c84f3c3e41c71fdf2c60c8a893e1cd31191b0cb463385f4c276302d2052d837b", size = 10825681, upload-time = "2025-07-04T08:28:37.318Z" }, -] - -[package.optional-dependencies] -sentencepiece = [ - { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] - -[[package]] -name = "rustworkx" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a5/c4/6d6ef39e57610d54c5f106dc3dece9eebce8b9d52d561ae092e3aede1b66/rustworkx-0.16.0.tar.gz", hash = "sha256:9f0dcb83f38d5ca2c3a683eb9b6951c8aec3262fbfe5141946a7ee5ba37e0bb6", size = 349524, upload-time = "2025-01-24T01:22:34.686Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/70/36f5916aee41ffe4f604ad75742eb1bb1b849fb568e010555f9d159cd93e/rustworkx-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:476a6c67b0142acd941691943750cc6737a48372304489969c2b62d30aaf4c27", size = 2141999, upload-time = "2025-01-24T01:21:50.3Z" }, - { url = "https://files.pythonhosted.org/packages/94/47/7e7c37fb73efcc87be6414b235534605c4008a4cdbd92a61db23b878eecd/rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bef2ef42870f806af93979b457e240f6dfa4f867ca33965c620f3a804409ed3a", size = 1940309, upload-time = "2025-01-24T01:21:52.053Z" }, - { url = "https://files.pythonhosted.org/packages/c6/42/a6d6b3137be55ef1d887becdf6b64b0917c7d437bd483065a88500a55603/rustworkx-0.16.0-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0db3a73bf68b3e66c08322a2fc95d3aa663d037d9b4e49c3509da4898d3529cc", size = 2195350, upload-time = "2025-01-24T01:21:53.785Z" }, - { url = "https://files.pythonhosted.org/packages/59/d2/1bc99df831c132c4b7420a85ce9150e065f4c993798f31b6a4229f238398/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f12a13d7486234fa2a84746d5e41f436bf9df43548043e7a232f48804ff8c61", size = 1971689, upload-time = "2025-01-24T17:09:26.338Z" }, - { url = "https://files.pythonhosted.org/packages/b5/3b/1125e7eb834f4408bcec3cee79947efd504c715fb7ab1876f8cd4bbca497/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89efd5c3a4653ddacc55ca39f28b261d43deec7d678f8f8fc6b76b5087f1dfea", size = 3297342, upload-time = "2025-01-24T03:18:48.885Z" }, - { url = "https://files.pythonhosted.org/packages/4f/e2/e21187b255c6211d71db0d08a44fc16771038b2af41712d66c408d9bec16/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec0c12aac8c54910ace20ac6ada4b890cd39f95f69100514715f8ad7af9041e4", size = 2110107, upload-time = "2025-01-24T01:21:58.884Z" }, - { url = "https://files.pythonhosted.org/packages/3c/79/e3fcff21f31253ea85ef196bf2fcabad7802b11468f7d3a5d592cd0ac789/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d650e39fc1a1534335f7517358ebfc3478bb235428463cfcd7c5750d50377b33", size = 2007544, upload-time = "2025-01-26T04:16:53.807Z" }, - { url = "https://files.pythonhosted.org/packages/67/04/741ed09c2b0dc0f360f85270c1179ed433785372ac9ab6ab26d3dd3ae02d/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:293180b83509ee9bff4c3af7ccc1024f6528d61b65d0cb7320bd31924f10cb71", size = 2172787, upload-time = "2025-01-24T01:22:01.282Z" }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - -[[package]] -name = "starlette" -version = "0.46.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846, upload-time = "2025-04-13T13:56:17.942Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037, upload-time = "2025-04-13T13:56:16.21Z" }, -] - -[[package]] -name = "tqdm" -version = "4.67.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, + { url = "https://files.pythonhosted.org/packages/96/88/beb33a79a382fcd2aed0be5222bdc47f41e4bfe7aaa90ae1374f1d8ea2af/transformers-4.53.2-py3-none-any.whl", hash = "sha256:db8f4819bb34f000029c73c3c557e7d06fc1b8e612ec142eecdae3947a9c78bf", size = 10826609, upload-time = "2025-07-11T12:39:05.461Z" }, ] [[package]] diff --git a/worker/runner/communication.py b/worker/runner/communication.py index 2b5cee12..5491f171 100644 --- a/worker/runner/communication.py +++ b/worker/runner/communication.py @@ -14,13 +14,19 @@ from shared.types.worker.commands_runner import ( ### Utils - MESSAGE TO RUNNER -async def supervisor_write_message(proc: asyncio.subprocess.Process, message: RunnerMessage) -> None: - assert proc.stdin is not None, "proc.stdin should not be None when created with stdin=PIPE" - - encoded: bytes = message.model_dump_json().encode('utf-8') + b'\n' + +async def supervisor_write_message( + proc: asyncio.subprocess.Process, message: RunnerMessage +) -> None: + assert proc.stdin is not None, ( + "proc.stdin should not be None when created with stdin=PIPE" + ) + + encoded: bytes = message.model_dump_json().encode("utf-8") + b"\n" proc.stdin.write(encoded) await proc.stdin.drain() + async def runner_read_message() -> RunnerMessage: loop = asyncio.get_running_loop() @@ -34,17 +40,24 @@ async def runner_read_message() -> RunnerMessage: except Exception as e: raise ValueError(f"Error validating message: {line}") from e + ### Utils - RESPONSE FROM RUNNER + def runner_write_response(obj: RunnerResponse) -> None: - encoded: bytes = obj.model_dump_json().encode('utf-8') + b'\n' + encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" _ = sys.stdout.buffer.write(encoded) _ = sys.stdout.buffer.flush() -async def supervisor_read_response(proc: asyncio.subprocess.Process) -> RunnerResponse | None: - assert proc.stdout is not None, "proc.stdout should not be None when created with stdout=PIPE" + +async def supervisor_read_response( + proc: asyncio.subprocess.Process, +) -> RunnerResponse | None: + assert proc.stdout is not None, ( + "proc.stdout should not be None when created with stdout=PIPE" + ) line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=10) - line: str = line_bytes.decode('utf-8').strip() + line: str = line_bytes.decode("utf-8").strip() if not line: raise EOFError("No more data to read") @@ -57,6 +70,7 @@ async def supervisor_read_response(proc: asyncio.subprocess.Process) -> RunnerRe ### Utils - Runner Prints + def runner_print(text: str) -> None: obj = PrintResponse( type=RunnerResponseType.PrintResponse, @@ -65,11 +79,12 @@ def runner_print(text: str) -> None: runner_write_response(obj) + def runner_write_error(error: Exception) -> None: error_response: ErrorResponse = ErrorResponse( - type=RunnerResponseType.ErrorResponse, - error_type=type(error).__name__, - error_message=str(error), - traceback=traceback.format_exc(), + type=RunnerResponseType.ErrorResponse, + error_type=type(error).__name__, + error_message=str(error), + traceback=traceback.format_exc(), ) - runner_write_response(error_response) \ No newline at end of file + runner_write_response(error_response) diff --git a/worker/runner/runner.py b/worker/runner/runner.py index b7a7f852..3e4d76b3 100644 --- a/worker/runner/runner.py +++ b/worker/runner/runner.py @@ -11,7 +11,7 @@ import mlx.nn as nn from mlx_lm.generate import stream_generate # type: ignore from mlx_lm.tokenizer_utils import TokenizerWrapper -from shared.mlx.utils_mlx import apply_chat_template, initialize_mlx +from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx from shared.openai import FinishReason from shared.types.tasks.common import ( TaskData, @@ -58,13 +58,15 @@ async def _mlx_generate( response = GenerationResponse( text=generation_response.text, token=generation_response.token, - finish_reason=cast(FinishReason | None, generation_response.finish_reason), # has to be considered as a FinishReason instead of a str. + finish_reason=cast( + FinishReason | None, generation_response.finish_reason + ), # has to be considered as a FinishReason instead of a str. ) _ = loop.call_soon_threadsafe(queue.put_nowait, response) except Exception as e: _ = loop.call_soon_threadsafe(queue.put_nowait, e) finally: - _ = loop.call_soon_threadsafe(queue.put_nowait, sentinel) + _ = loop.call_soon_threadsafe(queue.put_nowait, sentinel) # Currently we support chat-completion tasks only. task_data = task.task_data @@ -91,15 +93,16 @@ async def _mlx_generate( if isinstance(item, Exception): raise item - - assert isinstance(item, GenerationResponse) # constrain datatype + + assert isinstance(item, GenerationResponse) # constrain datatype yield item assert future.done() + async def main(): try: - runner_print('hello from the runner') + runner_print("hello from the runner") # Get setup info from worker init_message: RunnerMessage = await runner_read_message() @@ -107,10 +110,12 @@ async def main(): model_shard_meta: ShardMeta = setup_message.model_shard_meta hosts: list[Host] = setup_message.hosts - mlx_executor: ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + mlx_executor: ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor( + max_workers=1 + ) loop: AbstractEventLoop = asyncio.get_running_loop() - runner_print(f'got here; {model_shard_meta.model_path}') + runner_print(f"got here; {model_shard_meta.model_path}") model, tokenizer, sampler = await loop.run_in_executor( mlx_executor, @@ -137,7 +142,7 @@ async def main(): task=task_data, ): runner_write_response(generation_response) - + runner_write_response(FinishedResponse()) case ExitMessage(): break @@ -147,5 +152,6 @@ async def main(): except Exception as e: runner_write_error(e) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/worker/runner/runner_supervisor.py b/worker/runner/runner_supervisor.py index 2b85d82b..ba15bf4a 100644 --- a/worker/runner/runner_supervisor.py +++ b/worker/runner/runner_supervisor.py @@ -2,7 +2,7 @@ import asyncio import contextlib import sys from collections.abc import AsyncGenerator -from typing import Callable +from typing import Any, Callable from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData from shared.types.tasks.common import Task, TaskStatusType, TaskType @@ -17,8 +17,7 @@ from shared.types.worker.commands_runner import ( SetupMessage, ) from shared.types.worker.mlx import Host -from shared.types.worker.runners import RunnerError -from shared.types.worker.shards import ShardMeta +from shared.types.worker.shards import ShardMetadata from worker.runner.communication import ( supervisor_read_response, supervisor_write_message, @@ -31,25 +30,27 @@ class RunnerSupervisor: RunnerSupervisor manages the lifecycle of a runner subprocess for model inference. Use the class method `create` to properly initialize an instance. """ - + def __init__( self, - model_shard_meta: ShardMeta, + model_shard_meta: ShardMetadata[Any], hosts: list[Host], runner_process: asyncio.subprocess.Process, ): """Private constructor. Use RunnerSupervisor.create() instead.""" - self.model_shard_meta: ShardMeta = model_shard_meta + self.model_shard_meta: ShardMetadata[Any] = model_shard_meta self.hosts: list[Host] = hosts self.runner_process: asyncio.subprocess.Process = runner_process self.running: bool = True - self.running_task: asyncio.Task[None] = asyncio.create_task(self._watch_runner()) + self.running_task: asyncio.Task[None] = asyncio.create_task( + self._watch_runner() + ) @classmethod async def create( cls, - model_shard_meta: ShardMeta, + model_shard_meta: ShardMetadata[Any], hosts: list[Host], ) -> "RunnerSupervisor": """ @@ -57,12 +58,14 @@ class RunnerSupervisor: The .create() classmethod pattern is used to ensure the constructor is asynchronous. """ cmd: list[str] = get_runner_command() - - runner_process: asyncio.subprocess.Process = await asyncio.create_subprocess_exec( - *cmd, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=sys.stderr, + + runner_process: asyncio.subprocess.Process = ( + await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr, + ) ) await supervisor_write_message( @@ -91,7 +94,9 @@ class RunnerSupervisor: if self.runner_process.stdout is not None: while True: try: - line = await asyncio.wait_for(self.runner_process.stdout.readline(), timeout=0.01) + line = await asyncio.wait_for( + self.runner_process.stdout.readline(), timeout=0.01 + ) if not line: break print(f"Remaining stdout: {line.decode('utf-8').strip()}") @@ -100,7 +105,9 @@ class RunnerSupervisor: try: # Give the process a moment to exit gracefully - await supervisor_write_message(proc=self.runner_process, message=ExitMessage()) + await supervisor_write_message( + proc=self.runner_process, message=ExitMessage() + ) _ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1) except asyncio.TimeoutError: print("Runner process did not terminate, killing...") @@ -114,7 +121,9 @@ class RunnerSupervisor: def __del__(self) -> None: if not self.running: - print('Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process.') + print( + "Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process." + ) with contextlib.suppress(ProcessLookupError): self.runner_process.kill() @@ -150,12 +159,16 @@ class RunnerSupervisor: ) while True: - line: RunnerResponse | None = await supervisor_read_response(self.runner_process) + line: RunnerResponse | None = await supervisor_read_response( + self.runner_process + ) if line is None: continue else: match line: - case GenerationResponse(text=text, token=token, finish_reason=finish_reason): + case GenerationResponse( + text=text, token=token, finish_reason=finish_reason + ): yield TokenChunk( task_id=task.task_id, idx=token, @@ -169,7 +182,11 @@ class RunnerSupervisor: case FinishedResponse(): break case PrintResponse(text=text): - print(f'runner printed: {text}') - case ErrorResponse(error_type=error_type, error_message=error_message, traceback=traceback): + print(f"runner printed: {text}") + case ErrorResponse( + error_type=error_type, + error_message=error_message, + traceback=traceback, + ): await self.astop() - raise RunnerError(error_type, error_message, traceback or "") + raise Exception(error_type, error_message, traceback or "") diff --git a/worker/runner/utils.py b/worker/runner/utils.py index 0f252633..41b168ba 100644 --- a/worker/runner/utils.py +++ b/worker/runner/utils.py @@ -3,6 +3,4 @@ import sys def get_runner_command() -> list[str]: python = sys.executable - return [ - python, '-m', 'worker.runner.runner' - ] \ No newline at end of file + return [python, "-m", "worker.runner.runner"] diff --git a/worker/runner/conftest.py b/worker/tests/conftest.py similarity index 54% rename from worker/runner/conftest.py rename to worker/tests/conftest.py index 57c5d8f1..a631cb4c 100644 --- a/worker/runner/conftest.py +++ b/worker/tests/conftest.py @@ -3,48 +3,69 @@ from pathlib import Path from typing import Callable, cast import pytest +from openai.types.chat import ChatCompletionUserMessageParam +from openai.types.chat.completion_create_params import ( + CompletionCreateParamsNonStreaming, + CompletionCreateParamsStreaming, +) +from pydantic import TypeAdapter from shared.types.models.common import ModelId from shared.types.tasks.common import ( - ChatCompletionMessage, - ChatCompletionParams, ChatCompletionStreamingTask, - PendingTaskStatus, Task, TaskArtifact, TaskId, TaskState, - TaskStatusIncompleteType, + TaskStatusOtherType, TaskStatusType, TaskType, ) from shared.types.worker.common import InstanceId from shared.types.worker.mlx import Host -from shared.types.worker.shards import PipelineShardMeta +from shared.types.worker.shards import PipelineShardMetadata + +CompletionCreateParamsStreamingAdapter = TypeAdapter(CompletionCreateParamsStreaming) +CompletionCreateParamsNonStreamingAdapter = TypeAdapter( + CompletionCreateParamsNonStreaming +) # Concrete TaskArtifact implementation for pending streaming tasks -class PendingStreamingTaskArtifact(TaskArtifact[TaskType.ChatCompletionStreaming, TaskStatusIncompleteType.Pending]): +class PendingStreamingTaskArtifact( + TaskArtifact[TaskType.ChatCompletionStreaming, TaskStatusOtherType.Pending] +): pass + @pytest.fixture def pipeline_shard_meta(): - def _pipeline_shard_meta(num_nodes: int = 1, device_rank: int = 0) -> PipelineShardMeta: + def _pipeline_shard_meta( + num_nodes: int = 1, device_rank: int = 0 + ) -> PipelineShardMetadata: total_layers = 16 layers_per_node = total_layers // num_nodes start_layer = device_rank * layers_per_node - end_layer = start_layer + layers_per_node if device_rank < num_nodes - 1 else total_layers - - return PipelineShardMeta( + end_layer = ( + start_layer + layers_per_node + if device_rank < num_nodes - 1 + else total_layers + ) + + return PipelineShardMetadata( device_rank=device_rank, model_id=ModelId(uuid=uuid.uuid4()), - model_path=Path("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit/").expanduser(), + model_path=Path( + "~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit/" + ).expanduser(), start_layer=start_layer, end_layer=end_layer, world_size=num_nodes, ) + return _pipeline_shard_meta + @pytest.fixture def hosts(): def _hosts(count: int, offset: int = 0) -> list[Host]: @@ -55,51 +76,57 @@ def hosts(): ) for i in range(count) ] + return _hosts + @pytest.fixture def hosts_one(hosts: Callable[[int], list[Host]]): return hosts(1) + @pytest.fixture def hosts_two(hosts: Callable[[int], list[Host]]): return hosts(2) + @pytest.fixture def user_message(): """Override this fixture in tests to customize the message""" return "Hello, how are you?" + @pytest.fixture def chat_completion_params(user_message: str): """Creates ChatCompletionParams with the given message""" - return ChatCompletionParams( + return CompletionCreateParamsStreaming( model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", - content=user_message - ) - ], - stream=True + messages=[ChatCompletionUserMessageParam(role="user", content=user_message)], + stream=True, ) + @pytest.fixture -def chat_completion_streaming_task_data(chat_completion_params: ChatCompletionParams): +def chat_completion_streaming_task_data( + chat_completion_params: CompletionCreateParamsStreaming, +): """Creates ChatCompletionStreamingTask from params""" - return ChatCompletionStreamingTask( - task_data=chat_completion_params - ) + return ChatCompletionStreamingTask(task_data=chat_completion_params) + @pytest.fixture -def streaming_task(chat_completion_streaming_task_data: ChatCompletionStreamingTask) -> Task[TaskType, TaskStatusType]: +def streaming_task( + chat_completion_streaming_task_data: CompletionCreateParamsStreaming, +) -> Task[TaskType, TaskStatusType]: """Creates the final Task object""" task = Task( task_id=TaskId(), task_type=TaskType.ChatCompletionStreaming, - task_data=chat_completion_streaming_task_data, + task_params=ChatCompletionStreamingTask( + task_data=chat_completion_streaming_task_data + ), task_state=TaskState( - task_status=PendingTaskStatus(), + task_status=TaskStatusOtherType.Pending, task_artifact=PendingStreamingTaskArtifact(), ), on_instance=InstanceId(), diff --git a/worker/runner/test_serdes.py b/worker/tests/test_serdes.py similarity index 53% rename from worker/runner/test_serdes.py rename to worker/tests/test_serdes.py index fe85da0e..8119aa4a 100644 --- a/worker/runner/test_serdes.py +++ b/worker/tests/test_serdes.py @@ -2,31 +2,41 @@ from typing import Callable, Literal, TypeVar from pydantic import BaseModel, TypeAdapter -from shared.types.tasks.common import Task, TaskStatusIncompleteType, TaskType +from shared.types.tasks.common import Task, TaskStatusOtherType, TaskType from shared.types.worker.commands_runner import ( ChatTaskMessage, RunnerMessageTypeAdapter, SetupMessage, ) from shared.types.worker.mlx import Host -from shared.types.worker.shards import PipelineShardMeta +from shared.types.worker.shards import PipelineShardMetadata + +T = TypeVar("T", bound=BaseModel) -T = TypeVar('T', bound=BaseModel) def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]): - encoded: bytes = obj.model_dump_json().encode('utf-8') + b'\n' + encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" decoded: T = typeadapter.validate_json(encoded) - assert decoded == obj, f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}" + assert decoded == obj, ( + f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}" + ) -def test_supervisor_setup_message_serdes(pipeline_shard_meta: Callable[..., PipelineShardMeta], hosts: Callable[..., list[Host]]): + +def test_supervisor_setup_message_serdes( + pipeline_shard_meta: Callable[..., PipelineShardMetadata], + hosts: Callable[..., list[Host]], +): setup_message = SetupMessage( model_shard_meta=pipeline_shard_meta(1, 0), hosts=hosts(1), ) assert_equal_serdes(setup_message, RunnerMessageTypeAdapter) -def test_supervisor_task_message_serdes(streaming_task: Task[TaskType, Literal[TaskStatusIncompleteType.Pending]]): + +def test_supervisor_task_message_serdes( + streaming_task: Task[TaskType, Literal[TaskStatusOtherType.Pending]], +): task_message = ChatTaskMessage( task=streaming_task.task_data, ) diff --git a/worker/runner/test_supervisor.py b/worker/tests/test_supervisor.py similarity index 88% rename from worker/runner/test_supervisor.py rename to worker/tests/test_supervisor.py index 46a93883..3c17099d 100644 --- a/worker/runner/test_supervisor.py +++ b/worker/tests/test_supervisor.py @@ -34,7 +34,7 @@ async def test_supervisor_single_node_response( try: full_response = "" stop_reason: FinishReason | None = None - + async for chunk in supervisor.stream_response(task=streaming_task): if isinstance(chunk, TokenChunk): full_response += chunk.chunk_data.text @@ -42,12 +42,15 @@ async def test_supervisor_single_node_response( stop_reason = chunk.chunk_data.finish_reason # Case-insensitive check for Paris in the response - assert "paris" in full_response.lower(), f"Expected 'Paris' in response, but got: {full_response}" - assert stop_reason == 'stop' - + assert "paris" in full_response.lower(), ( + f"Expected 'Paris' in response, but got: {full_response}" + ) + assert stop_reason == "stop" + finally: await supervisor.astop() + @pytest.mark.asyncio async def test_supervisor_two_node_response( pipeline_shard_meta: Callable[..., PipelineShardMeta], @@ -70,33 +73,38 @@ async def test_supervisor_two_node_response( try: full_response_0 = "" full_response_1 = "" - + async def collect_response_0(): nonlocal full_response_0 async for chunk in supervisor_0.stream_response(task=streaming_task): if isinstance(chunk, TokenChunk): full_response_0 += chunk.chunk_data.text - + async def collect_response_1(): nonlocal full_response_1 async for chunk in supervisor_1.stream_response(task=streaming_task): if isinstance(chunk, TokenChunk): full_response_1 += chunk.chunk_data.text - + # Run both stream responses simultaneously _ = await asyncio.gather(collect_response_0(), collect_response_1()) print(f"full_response_0: {full_response_0}") print(f"full_response_1: {full_response_1}") - + # Case-insensitive check for Paris in both responses - assert "paris" in full_response_0.lower(), f"Expected 'Paris' in response, but got: {full_response_0}" - assert "paris" in full_response_1.lower(), f"Expected 'Paris' in response, but got: {full_response_1}" - + assert "paris" in full_response_0.lower(), ( + f"Expected 'Paris' in response, but got: {full_response_0}" + ) + assert "paris" in full_response_1.lower(), ( + f"Expected 'Paris' in response, but got: {full_response_1}" + ) + finally: await supervisor_0.astop() await supervisor_1.astop() + @pytest.mark.asyncio async def test_supervisor_early_stopping( pipeline_shard_meta: Callable[..., PipelineShardMeta], @@ -115,8 +123,10 @@ async def test_supervisor_early_stopping( try: streaming_task.task_data.task_data.max_tokens = max_tokens - streaming_task.task_data.task_data.messages[0].content = "Please count from 1 to 100" - + streaming_task.task_data.task_data.messages[ + 0 + ].content = "Please count from 1 to 100" + full_response = "" count = 0 stop_reason: FinishReason | None = None @@ -127,14 +137,14 @@ async def test_supervisor_early_stopping( count += 1 if chunk.chunk_data.finish_reason: stop_reason = chunk.chunk_data.finish_reason - + print(f"full_response: {full_response}") assert count == max_tokens + 1 - assert '7' in full_response.lower() - assert '99' not in full_response.lower() + assert "7" in full_response.lower() + assert "99" not in full_response.lower() - assert stop_reason == 'length' + assert stop_reason == "length" finally: await supervisor.astop()