From 370e6de07859f71db18e72bc72046910ce2bfb4a Mon Sep 17 00:00:00 2001
From: Evan
Date: Fri, 23 Jan 2026 23:56:13 +0000
Subject: [PATCH] api cancellation

closing the http request to the api now:
- sends a cancellation from the api
- writes that cancellation in the master
- worker plans off the cancellation
- runner observes that cancellation every few generation steps (one
  collective communication per check; the interval is measured at warmup)
- cancellation happens synchronously to prevent gpu locks

sketches of the decode-loop vote and the api-side shielding follow the diff.
---
 MISSED_THINGS.md                              | 22 +++---
 src/exo/master/adapters/chat_completions.py   |  7 +-
 src/exo/master/api.py                         | 41 ++++++----
 src/exo/master/main.py                        | 27 +++++--
 src/exo/master/placement.py                   | 21 ++++-
 src/exo/master/tests/test_placement.py        |  6 +-
 src/exo/shared/types/commands.py              |  5 ++
 src/exo/shared/types/tasks.py                 |  7 ++
 src/exo/utils/channels.py                     |  4 +-
 .../worker/engines/mlx/generator/generate.py  |  7 --
 src/exo/worker/engines/mlx/utils_mlx.py       | 46 +++++------
 src/exo/worker/main.py                        | 26 ++++---
 src/exo/worker/plan.py                        | 37 ++++++---
 src/exo/worker/runner/bootstrap.py            |  5 +-
 src/exo/worker/runner/runner.py               | 76 ++++++++++++++-----
 src/exo/worker/runner/runner_supervisor.py    | 27 ++++---
 .../test_runner/test_event_ordering.py        |  7 +-
 tests/run_exo_on.sh                           |  2 +-
 18 files changed, 247 insertions(+), 126 deletions(-)

diff --git a/MISSED_THINGS.md b/MISSED_THINGS.md
index 8cb11130..76083edb 100644
--- a/MISSED_THINGS.md
+++ b/MISSED_THINGS.md
@@ -5,21 +5,21 @@
 [X] Fetching download status of all models on start
 [X] Deduplication of tasks in plan_step.
 [X] resolve_allow_patterns should just be wildcard now.
-[] no mx_barrier in genreate.py mlx_generate at the end.
+[X] no mx_barrier in generate.py mlx_generate at the end.
 [] cache assertion not needed in auto_parallel.py PipelineLastLayer.
-[] GPTOSS support dropped in auto_parallel.py.
-[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
-[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
-[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
+[X] GPTOSS support dropped in auto_parallel.py.
+[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
+[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
+[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
 [] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
 [X] KV_CACHE_BITS should be None to disable quantized KV cache.
-[] Dropped _set_nofile_limit in utils_mlx.py.
-[] We have group optional in load_mlx_items in utils_mlx.py.
-[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
-[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
+[X] Dropped _set_nofile_limit in utils_mlx.py.
+[X] We have group optional in load_mlx_items in utils_mlx.py.
+[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
+[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
 [X] We put cache limit back in utils_mlx.py.
-[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
-[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
+[X] topology.py remove_node removes the connections after checking if node is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
+[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
 [] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
 [X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
 [X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
diff --git a/src/exo/master/adapters/chat_completions.py b/src/exo/master/adapters/chat_completions.py
index 3e013079..c9228e25 100644
--- a/src/exo/master/adapters/chat_completions.py
+++ b/src/exo/master/adapters/chat_completions.py
@@ -176,7 +176,7 @@ async def generate_chat_stream(
 async def collect_chat_response(
     command_id: CommandId,
     chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
-) -> ChatCompletionResponse:
+) -> AsyncGenerator[str]:
     """Collect all token chunks and return a single ChatCompletionResponse."""
     text_parts: list[str] = []
     tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
     combined_text = "".join(text_parts)
 
     assert model is not None
-    return ChatCompletionResponse(
+    yield ChatCompletionResponse(
         id=command_id,
         created=int(time.time()),
         model=model,
@@ -241,4 +241,5 @@ async def collect_chat_response(
                 finish_reason=finish_reason,
             )
         ],
-    )
+    ).model_dump_json()
+    return
diff --git a/src/exo/master/api.py b/src/exo/master/api.py
index 935c41ea..bfc0832b 100644
--- a/src/exo/master/api.py
+++ b/src/exo/master/api.py
@@ -123,6 +123,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     SendInputChunk,
     StartDownload,
+    TaskCancelled,
     TaskFinished,
     TextGeneration,
 )
@@ -529,16 +530,14 @@ class API:
                     break
             except anyio.get_cancelled_exc_class():
-                # TODO: TaskCancelled
-                """
-                self.command_sender.send_nowait(
-                    ForwarderCommand(origin=self.node_id, command=command)
-                )
-                """
+                command = TaskCancelled(cancelled_command_id=command_id)
+                with anyio.CancelScope(shield=True):
+                    await self.command_sender.send(
+                        ForwarderCommand(origin=self.node_id, command=command)
+                    )
                 raise
             finally:
-                command = TaskFinished(finished_command_id=command_id)
-                await self._send(command)
+                await self._send(TaskFinished(finished_command_id=command_id))
                 if command_id in self._text_generation_queues:
                     del self._text_generation_queues[command_id]
@@ -633,11 +632,14 @@ class API:
                 "X-Accel-Buffering": "no",
             },
         )
-
-        return await collect_chat_response(
-            command.command_id,
-            self._token_chunk_stream(command.command_id),
-        )
+        else:
+            return StreamingResponse(
+                collect_chat_response(
+                    command.command_id,
+                    self._token_chunk_stream(command.command_id),
+                ),
+                media_type="application/json",
+            )
 
     async def bench_chat_completions(
         self, payload: BenchChatCompletionRequest
@@ -653,8 +655,7 @@ class API:
         command = TextGeneration(task_params=task_params)
         await self._send(command)
 
-        response = await 
self._collect_text_generation_with_stats(command.command_id) - return response + return await self._collect_text_generation_with_stats(command.command_id) async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId: """Validate a text model exists and return the resolved model ID. @@ -856,6 +857,11 @@ class API: del image_metadata[key] except anyio.get_cancelled_exc_class(): + command = TaskCancelled(cancelled_command_id=command_id) + with anyio.CancelScope(shield=True): + await self.command_sender.send( + ForwarderCommand(origin=self.node_id, command=command) + ) raise finally: await self._send(TaskFinished(finished_command_id=command_id)) @@ -937,6 +943,11 @@ class API: return (images, stats if capture_stats else None) except anyio.get_cancelled_exc_class(): + command = TaskCancelled(cancelled_command_id=command_id) + with anyio.CancelScope(shield=True): + await self.command_sender.send( + ForwarderCommand(origin=self.node_id, command=command) + ) raise finally: await self._send(TaskFinished(finished_command_id=command_id)) diff --git a/src/exo/master/main.py b/src/exo/master/main.py index ea8c5387..e596c855 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -21,6 +21,7 @@ from exo.shared.types.commands import ( PlaceInstance, RequestEventLog, SendInputChunk, + TaskCancelled, TaskFinished, TestCommand, TextGeneration, @@ -36,6 +37,7 @@ from exo.shared.types.events import ( NodeTimedOut, TaskCreated, TaskDeleted, + TaskStatusUpdated, TraceEventData, TracesCollected, TracesMerged, @@ -278,7 +280,7 @@ class Master: case DeleteInstance(): placement = delete_instance(command, self.state.instances) transition_events = get_transition_events( - self.state.instances, placement + self.state.instances, placement, self.state.tasks ) generated_events.extend(transition_events) case PlaceInstance(): @@ -290,7 +292,7 @@ class Master: self.state.node_network, ) transition_events = get_transition_events( - self.state.instances, placement + self.state.instances, placement, self.state.tasks ) generated_events.extend(transition_events) case CreateInstance(): @@ -300,7 +302,7 @@ class Master: self.state.instances, ) transition_events = get_transition_events( - self.state.instances, placement + self.state.instances, placement, self.state.tasks ) generated_events.extend(transition_events) case SendInputChunk(chunk=chunk): @@ -310,6 +312,18 @@ class Master: chunk=chunk, ) ) + case TaskCancelled(): + if ( + task_id := self.command_task_mapping.get( + command.cancelled_command_id + ) + ) is not None: + generated_events.append( + TaskStatusUpdated( + task_status=TaskStatus.Cancelled, + task_id=task_id, + ) + ) case TaskFinished(): generated_events.append( TaskDeleted( @@ -318,10 +332,9 @@ class Master: ] ) ) - if command.finished_command_id in self.command_task_mapping: - del self.command_task_mapping[ - command.finished_command_id - ] + self.command_task_mapping.pop( + command.finished_command_id, None + ) case RequestEventLog(): # We should just be able to send everything, since other buffers will ignore old messages for i in range(command.since_idx, len(self._event_log)): diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index e37f533f..627d4143 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -20,9 +20,15 @@ from exo.shared.types.commands import ( PlaceInstance, ) from exo.shared.types.common import NodeId -from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted +from exo.shared.types.events import ( + 
Event, + InstanceCreated, + InstanceDeleted, + TaskStatusUpdated, +) from exo.shared.types.memory import Memory from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo +from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.instances import ( Instance, InstanceId, @@ -180,6 +186,7 @@ def delete_instance( def get_transition_events( current_instances: Mapping[InstanceId, Instance], target_instances: Mapping[InstanceId, Instance], + tasks: Mapping[TaskId, Task], ) -> Sequence[Event]: events: list[Event] = [] @@ -195,6 +202,18 @@ def get_transition_events( # find instances to delete for instance_id in current_instances: if instance_id not in target_instances: + for task in tasks.values(): + if task.instance_id == instance_id and task.task_status in [ + TaskStatus.Pending, + TaskStatus.Running, + ]: + events.append( + TaskStatusUpdated( + task_status=TaskStatus.Cancelled, + task_id=task.task_id, + ) + ) + events.append( InstanceDeleted( instance_id=instance_id, diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index baa1a0f8..ad5638e7 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance): target_instances = {instance_id: instance} # act - events = get_transition_events(current_instances, target_instances) + events = get_transition_events(current_instances, target_instances, {}) # assert assert len(events) == 0 @@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance): target_instances: dict[InstanceId, Instance] = {instance_id: instance} # act - events = get_transition_events(current_instances, target_instances) + events = get_transition_events(current_instances, target_instances, {}) # assert assert len(events) == 1 @@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance): target_instances: dict[InstanceId, Instance] = {} # act - events = get_transition_events(current_instances, target_instances) + events = get_transition_events(current_instances, target_instances, {}) # assert assert len(events) == 1 diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 115df719..b4b29df0 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand): instance_id: InstanceId +class TaskCancelled(BaseCommand): + cancelled_command_id: CommandId + + class TaskFinished(BaseCommand): finished_command_id: CommandId @@ -84,6 +88,7 @@ Command = ( | PlaceInstance | CreateInstance | DeleteInstance + | TaskCancelled | TaskFinished | SendInputChunk ) diff --git a/src/exo/shared/types/tasks.py b/src/exo/shared/types/tasks.py index 91fb44b1..8d866456 100644 --- a/src/exo/shared/types/tasks.py +++ b/src/exo/shared/types/tasks.py @@ -24,6 +24,7 @@ class TaskStatus(str, Enum): Complete = "Complete" TimedOut = "TimedOut" Failed = "Failed" + Cancelled = "Cancelled" class BaseTask(TaggedModel): @@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master error_message: str | None = Field(default=None) +class CancelTask(BaseTask): + cancelled_task_id: TaskId + runner_id: RunnerId + + class ImageGeneration(BaseTask): # emitted by Master command_id: CommandId task_params: ImageGenerationTaskParams @@ -87,6 +93,7 @@ Task = ( | LoadModel | StartWarmup | TextGeneration + | CancelTask | ImageGeneration | ImageEdits | Shutdown diff --git 
a/src/exo/utils/channels.py b/src/exo/utils/channels.py index ebf0165f..646ac8f6 100644 --- a/src/exo/utils/channels.py +++ b/src/exo/utils/channels.py @@ -125,7 +125,9 @@ class MpSender[T]: self._state.buffer.put(item, block=True) async def send_async(self, item: T) -> None: - await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1)) + await to_thread.run_sync( + self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True + ) def close(self) -> None: if not self._state.closed.is_set(): diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 67a31ae0..fc000d98 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -32,7 +32,6 @@ from exo.worker.engines.mlx.constants import ( ) from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, - mx_barrier, ) from exo.worker.runner.bootstrap import logger @@ -136,10 +135,6 @@ def warmup_inference( logger.info("Generated ALL warmup tokens") - # TODO: Do we want an mx_barrier? - # At least this version is actively incorrect, as it should use mx_barrier(group) - mx_barrier() - return tokens_generated @@ -403,5 +398,3 @@ def mlx_generate( # Limit accumulated_text to what's needed for stop sequence detection if max_stop_len > 0 and len(accumulated_text) > max_stop_len: accumulated_text = accumulated_text[-max_stop_len:] - - # TODO: Do we want an mx_barrier? diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 4f1140fb..7e9bf3ed 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -67,8 +67,6 @@ Group = mx.distributed.Group resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) -# TODO: Test this -# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673 def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: return Memory.from_float_kb( (model_shard_meta.end_layer - model_shard_meta.start_layer) @@ -86,30 +84,6 @@ class ModelLoadingTimeoutError(Exception): pass -def mx_barrier(group: Group | None = None): - mx.eval( - mx.distributed.all_sum( - mx.array(1.0), - stream=mx.default_stream(mx.Device(mx.cpu)), - group=group, - ) - ) - - -def broadcast_from_zero(value: int, group: Group | None = None): - if group is None: - return value - - if group.rank() == 0: - a = mx.array([value], dtype=mx.int32) - else: - a = mx.array([0], dtype=mx.int32) - - m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group) - mx.eval(m) - return int(m.item()) - - class HostList(RootModel[list[str]]): @classmethod def from_hosts(cls, hosts: list[Host]) -> "HostList": @@ -562,3 +536,23 @@ def mlx_cleanup( import gc gc.collect() + + +def mx_any(bool_: bool, group: Group | None) -> bool: + if group is None: + return bool_ + num_true = mx.distributed.all_sum( + mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu)) + ) + mx.eval(num_true) + return num_true.item() > 0 + + +def mx_barrier(group: Group | None): + if group is None: + return + mx.eval( + mx.distributed.all_sum( + mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu)) + ) + ) diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 35834360..9811e362 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -32,6 +32,7 @@ from exo.shared.types.events import ( from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.state import State from 
exo.shared.types.tasks import ( + CancelTask, CreateRunner, DownloadModel, ImageEdits, @@ -218,15 +219,22 @@ class Worker: ) ) case Shutdown(runner_id=runner_id): + runner = self.runners.pop(runner_id) try: with fail_after(3): - await self.runners.pop(runner_id).start_task(task) + await runner.start_task(task) except TimeoutError: await self.event_sender.send( TaskStatusUpdated( task_id=task.task_id, task_status=TaskStatus.TimedOut ) ) + finally: + runner.shutdown() + case CancelTask( + cancelled_task_id=cancelled_task_id, runner_id=runner_id + ): + await self.runners[runner_id].cancel_task(cancelled_task_id) case ImageEdits() if task.task_params.total_input_chunks > 0: # Assemble image from chunks and inject into task cmd_id = task.command_id @@ -264,18 +272,18 @@ class Worker: del self.input_chunk_buffer[cmd_id] if cmd_id in self.input_chunk_counts: del self.input_chunk_counts[cmd_id] - await self.runners[self._task_to_runner_id(task)].start_task( - modified_task - ) + await self._start_runner_task(modified_task) case task: - await self.runners[self._task_to_runner_id(task)].start_task(task) + await self._start_runner_task(task) def shutdown(self): self._tg.cancel_scope.cancel() - def _task_to_runner_id(self, task: Task): - instance = self.state.instances[task.instance_id] - return instance.shard_assignments.node_to_runner[self.node_id] + async def _start_runner_task(self, task: Task): + if (instance := self.state.instances.get(task.instance_id)) is not None: + await self.runners[ + instance.shard_assignments.node_to_runner[self.node_id] + ].start_task(task) async def _nack_request(self, since_idx: int) -> None: # We request all events after (and including) the missing index. @@ -314,8 +322,6 @@ class Worker: for event in self.out_for_delivery.copy().values(): await self.local_event_sender.send(event) - ## Op Executors - def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor: """Creates and stores a new AssignedRunner with initial downloading status.""" runner = RunnerSupervisor.create( diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index c173c143..ce2eb4a9 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from exo.shared.types.common import CommandId, NodeId from exo.shared.types.tasks import ( + CancelTask, ConnectToGroup, CreateRunner, DownloadModel, @@ -53,13 +54,14 @@ def plan( ) -> Task | None: # Python short circuiting OR logic should evaluate these sequentially. 
return ( - _kill_runner(runners, all_runners, instances) + _cancel_tasks(runners, tasks) + or _kill_runner(runners, all_runners, instances) or _create_runner(node_id, runners, instances) or _model_needs_download(node_id, runners, global_download_status) or _init_distributed_backend(runners, all_runners) or _load_model(runners, all_runners, global_download_status) or _ready_to_warmup(runners, all_runners) - or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer) + or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {}) ) @@ -270,7 +272,7 @@ def _pending_tasks( runners: Mapping[RunnerId, RunnerSupervisor], tasks: Mapping[TaskId, Task], all_runners: Mapping[RunnerId, RunnerStatus], - input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None, + input_chunk_buffer: Mapping[CommandId, dict[int, str]], ) -> Task | None: for task in tasks.values(): # for now, just forward chat completions @@ -284,7 +286,7 @@ def _pending_tasks( if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0: cmd_id = task.command_id expected = task.task_params.total_input_chunks - received = len((input_chunk_buffer or {}).get(cmd_id, {})) + received = len(input_chunk_buffer.get(cmd_id, {})) if received < expected: continue # Wait for all chunks to arrive @@ -292,16 +294,33 @@ def _pending_tasks( if task.instance_id != runner.bound_instance.instance.instance_id: continue - # I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough - # however, realistically the task status should be set to completed by the LAST runner, so this is a true race - # the actual solution is somewhat deeper than this bypass - TODO! + # the task status _should_ be set to completed by the LAST runner + # it is currently set by the first + # this is definitely a hack if task.task_id in runner.completed: continue - # TODO: Check ordering aligns with MLX distributeds expectations. 
- if isinstance(runner.status, RunnerReady) and all( isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning)) for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard ): return task + + +def _cancel_tasks( + runners: Mapping[RunnerId, RunnerSupervisor], + tasks: Mapping[TaskId, Task], +) -> Task | None: + for task in tasks.values(): + if task.task_status != TaskStatus.Cancelled: + continue + for runner_id, runner in runners.items(): + if task.instance_id != runner.bound_instance.instance.instance_id: + continue + if task.task_id in runner.cancelled: + continue + return CancelTask( + instance_id=task.instance_id, + cancelled_task_id=task.task_id, + runner_id=runner_id, + ) diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index bf08ab6c..ed420aab 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -3,7 +3,7 @@ import os import loguru from exo.shared.types.events import Event, RunnerStatusUpdated -from exo.shared.types.tasks import Task +from exo.shared.types.tasks import Task, TaskId from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance from exo.shared.types.worker.runners import RunnerFailed from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender @@ -15,6 +15,7 @@ def entrypoint( bound_instance: BoundInstance, event_sender: MpSender[Event], task_receiver: MpReceiver[Task], + cancel_receiver: MpReceiver[TaskId], _logger: "loguru.Logger", ) -> None: fast_synch_override = os.environ.get("EXO_FAST_SYNCH") @@ -38,7 +39,7 @@ def entrypoint( try: from exo.worker.runner.runner import main - main(bound_instance, event_sender, task_receiver) + main(bound_instance, event_sender, task_receiver, cancel_receiver) except ClosedResourceError: logger.warning("Runner communication closed unexpectedly") except Exception as e: diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index b0e655cc..527df3e5 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -1,5 +1,6 @@ import base64 import json +import math import time from collections.abc import Generator from functools import cache @@ -87,6 +88,7 @@ from exo.worker.engines.mlx.utils_mlx import ( initialize_mlx, load_mlx_items, mlx_force_oom, + mx_any, ) from exo.worker.runner.bootstrap import logger @@ -111,6 +113,7 @@ def main( bound_instance: BoundInstance, event_sender: MpSender[Event], task_receiver: MpReceiver[Task], + cancel_receiver: MpReceiver[TaskId], ): instance, runner_id, shard_metadata = ( bound_instance.instance, @@ -125,11 +128,15 @@ def main( time.sleep(timeout) setup_start_time = time.time() + cancelled_tasks = set[TaskId]() - model: Model | DistributedImageModel | None = None + # type checker was unhappy with me - splitting these fixed it + inference_model: Model | None = None + image_model: DistributedImageModel | None = None tokenizer = None group = None kv_prefix_cache: KVPrefixCache | None = None + check_for_cancel_every: int | None = None current_status: RunnerStatus = RunnerIdle() logger.info("runner created") @@ -142,6 +149,7 @@ def main( if task.task_id in seen: logger.warning("repeat task - potential error") seen.add(task.task_id) + cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK")) event_sender.send( TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running) ) @@ -187,7 +195,7 @@ def main( time.sleep(0.5) if ModelTask.TextGeneration in shard_metadata.model_card.tasks: - model, tokenizer = 
load_mlx_items( + inference_model, tokenizer = load_mlx_items( bound_instance, group, on_timeout=on_model_load_timeout ) logger.info( @@ -199,7 +207,7 @@ def main( ModelTask.TextToImage in shard_metadata.model_card.tasks or ModelTask.ImageToImage in shard_metadata.model_card.tasks ): - model = initialize_image_model(bound_instance) + image_model = initialize_image_model(bound_instance) else: raise ValueError( f"Unknown model task(s): {shard_metadata.model_card.tasks}" @@ -207,8 +215,6 @@ def main( current_status = RunnerLoaded() logger.info("runner loaded") case StartWarmup() if isinstance(current_status, RunnerLoaded): - assert model - current_status = RunnerWarmingUp() logger.info("runner warming up") event_sender.send( @@ -220,15 +226,30 @@ def main( logger.info(f"warming up inference for instance: {instance}") if ModelTask.TextGeneration in shard_metadata.model_card.tasks: - assert not isinstance(model, DistributedImageModel) + assert inference_model assert tokenizer + t = time.perf_counter() toks = warmup_inference( - model=model, + model=inference_model, tokenizer=tokenizer, - # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching ) logger.info(f"warmed up by generating {toks} tokens") + check_for_cancel_every = min( + math.ceil(toks / (time.perf_counter() - t)), 100 + ) + if group is not None: + check_for_cancel_every = int( + mx.max( + mx.distributed.all_gather( + mx.array([check_for_cancel_every]), group=group + ) + ).item() + ) + + logger.info( + f"runner checking for cancellation every {check_for_cancel_every} tokens" + ) logger.info( f"runner initialized in {time.time() - setup_start_time} seconds" ) @@ -236,8 +257,8 @@ def main( ModelTask.TextToImage in shard_metadata.model_card.tasks or ModelTask.ImageToImage in shard_metadata.model_card.tasks ): - assert isinstance(model, DistributedImageModel) - image = warmup_image_generator(model=model) + assert image_model + image = warmup_image_generator(model=image_model) if image is not None: logger.info(f"warmed up by generating {image.size} image") else: @@ -257,9 +278,9 @@ def main( ) ) event_sender.send(TaskAcknowledged(task_id=task.task_id)) - - assert model and not isinstance(model, DistributedImageModel) + assert inference_model assert tokenizer + assert check_for_cancel_every try: _check_for_debug_prompts(task_params) @@ -269,7 +290,7 @@ def main( # Generate responses using the actual MLX generation mlx_generator = mlx_generate( - model=model, + model=inference_model, tokenizer=tokenizer, task=task_params, prompt=prompt, @@ -293,11 +314,11 @@ def main( patch_glm_tokenizer(tokenizer) # GPT-OSS specific parsing to match other model formats. 
- elif isinstance(model, GptOssModel): + elif isinstance(inference_model, GptOssModel): mlx_generator = parse_gpt_oss(mlx_generator) if tokenizer.has_tool_calling and not isinstance( - model, GptOssModel + inference_model, GptOssModel ): assert tokenizer.tool_call_start assert tokenizer.tool_call_end @@ -310,7 +331,18 @@ def main( ) completion_tokens = 0 + tokens_since_last_cancel_check = 0 for response in mlx_generator: + tokens_since_last_cancel_check += 1 + if tokens_since_last_cancel_check >= check_for_cancel_every: + tokens_since_last_cancel_check = 0 + cancelled_tasks.update(cancel_receiver.collect()) + want_to_cancel = (task.task_id in cancelled_tasks) or ( + TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks + ) + if mx_any(want_to_cancel, group): + break + match response: case GenerationResponse(): completion_tokens += 1 @@ -382,7 +414,7 @@ def main( case ImageGeneration( task_params=task_params, command_id=command_id ) if isinstance(current_status, RunnerReady): - assert isinstance(model, DistributedImageModel) + assert image_model logger.info(f"received image generation request: {str(task)[:500]}") current_status = RunnerRunning() logger.info("runner running") @@ -395,7 +427,9 @@ def main( try: image_index = 0 - for response in generate_image(model=model, task=task_params): + for response in generate_image( + model=image_model, task=task_params + ): is_primary_output = _is_primary_output_node(shard_metadata) if is_primary_output: @@ -445,7 +479,7 @@ def main( case ImageEdits(task_params=task_params, command_id=command_id) if ( isinstance(current_status, RunnerReady) ): - assert isinstance(model, DistributedImageModel) + assert image_model logger.info(f"received image edits request: {str(task)[:500]}") current_status = RunnerRunning() logger.info("runner running") @@ -458,7 +492,9 @@ def main( try: image_index = 0 - for response in generate_image(model=model, task=task_params): + for response in generate_image( + model=image_model, task=task_params + ): if _is_primary_output_node(shard_metadata): match response: case PartialImageResponse(): @@ -524,7 +560,7 @@ def main( RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status) ) if isinstance(current_status, RunnerShutdown): - del model, tokenizer, group + del inference_model, image_model, tokenizer, group mx.clear_cache() import gc diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 3b859711..5d39a881 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -47,9 +47,11 @@ class RunnerSupervisor: _ev_recv: MpReceiver[Event] _task_sender: MpSender[Task] _event_sender: Sender[Event] + _cancel_sender: MpSender[TaskId] status: RunnerStatus = field(default_factory=RunnerIdle, init=False) pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False) completed: set[TaskId] = field(default_factory=set, init=False) + cancelled: set[TaskId] = field(default_factory=set, init=False) @classmethod def create( @@ -60,8 +62,8 @@ class RunnerSupervisor: initialize_timeout: float = 400, ) -> Self: ev_send, ev_recv = mp_channel[Event]() - # A task is kind of a runner command task_sender, task_recv = mp_channel[Task]() + cancel_sender, cancel_recv = mp_channel[TaskId]() runner_process = Process( target=entrypoint, @@ -69,6 +71,7 @@ class RunnerSupervisor: bound_instance, ev_send, task_recv, + cancel_recv, logger, ), daemon=True, @@ -83,6 +86,7 @@ class RunnerSupervisor: initialize_timeout=initialize_timeout, 
            _ev_recv=ev_recv,
             _task_sender=task_sender,
+            _cancel_sender=cancel_sender,
             _event_sender=event_sender,
         )
 
@@ -97,6 +101,8 @@ class RunnerSupervisor:
         self._ev_recv.close()
         self._task_sender.close()
         self._event_sender.close()
+        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+        self._cancel_sender.close()
         self.runner_process.join(1)
         if not self.runner_process.is_alive():
             logger.info("Runner process succesfully terminated")
@@ -112,14 +118,6 @@ class RunnerSupervisor:
             logger.critical("Runner process didn't respond to SIGTERM, killing")
             self.runner_process.kill()
 
-        self.runner_process.join(1)
-        if not self.runner_process.is_alive():
-            return
-
-        logger.critical(
-            "Runner process didn't respond to SIGKILL. System resources may have leaked"
-        )
-
     async def start_task(self, task: Task):
         if task.task_id in self.pending:
             logger.warning(
@@ -141,6 +139,17 @@ class RunnerSupervisor:
             return
         await event.wait()
 
+    async def cancel_task(self, task_id: TaskId):
+        if task_id in self.completed:
+            logger.info(f"Unable to cancel {task_id} as it has been completed")
+            return
+        self.cancelled.add(task_id)
+        with anyio.move_on_after(0.5) as scope:
+            await self._cancel_sender.send_async(task_id)
+        if scope.cancel_called:
+            logger.error("RunnerSupervisor cancel pipe blocked")
+            await self._check_runner(TimeoutError("cancel pipe blocked"))
+
     async def _forward_events(self):
         with self._ev_recv as events:
             try:
diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py
index edf5ef3a..b582efa0 100644
--- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py
+++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py
@@ -2,6 +2,7 @@ from collections.abc import Iterable
 
 from typing import Callable
 
+import mlx.core as mx
 import pytest
 
 import exo.worker.runner.runner as mlx_runner
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
     Shutdown,
     StartWarmup,
     Task,
+    TaskId,
     TaskStatus,
     TextGeneration,
 )
@@ -113,6 +115,8 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
     monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
     monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
+    monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", make_nothin(mx.array([1])))
+    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
 
     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
     # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt")) @@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]): ) task_sender, task_receiver = mp_channel[Task]() + _cancel_sender, cancel_receiver = mp_channel[TaskId]() event_sender = EventCollector() with task_sender: @@ -174,7 +179,7 @@ def _run(tasks: Iterable[Task]): task_receiver.close = nothin task_receiver.join = nothin - mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type] + mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType] return event_sender.events diff --git a/tests/run_exo_on.sh b/tests/run_exo_on.sh index 6dcd62d9..3cbc3bc0 100755 --- a/tests/run_exo_on.sh +++ b/tests/run_exo_on.sh @@ -35,7 +35,7 @@ i=0 for host; do colour=${colours[i++ % 4]} ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \ - "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |& + "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |& awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' & done
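
Reviewer sketch of the runner-side pattern: every rank has to agree to stop on
the same decode step, because a rank that broke out unilaterally would leave
its peers blocked in the model's next collective, which is the GPU lockup the
synchronous check prevents. A minimal sketch, assuming hypothetical
generate_tokens() and poll_cancellations() stand-ins (the real loop in
src/exo/worker/runner/runner.py drains cancel_receiver.collect() and votes
with mx_any):

    import mlx.core as mx

    CHECK_EVERY = 25  # placeholder; the runner measures roughly one second's
                      # worth of tokens at warmup, caps it at 100, and takes
                      # the group-wide max so all ranks share one interval

    def mx_any(flag: bool, group: mx.distributed.Group | None) -> bool:
        # Collective OR: all_sum the per-rank flags so every rank sees the
        # same verdict and breaks on the same step.
        if group is None:
            return flag
        votes = mx.distributed.all_sum(
            mx.array(flag),
            group=group,
            stream=mx.default_stream(mx.Device(mx.cpu)),
        )
        mx.eval(votes)
        return votes.item() > 0

    def decode(task_id, group, cancelled: set):
        since_check = 0
        for token in generate_tokens():  # hypothetical token generator
            yield token
            since_check += 1
            if since_check >= CHECK_EVERY:
                since_check = 0
                cancelled.update(poll_cancellations())  # non-blocking drain
                if mx_any(task_id in cancelled, group):
                    break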
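
On the API side, a client disconnect surfaces as anyio cancellation, and the
TaskCancelled command has to outlive the dying scope: once the scope is
cancelled, any unshielded await would be interrupted before the command got
out. A sketch of the shielding idiom, with a hypothetical send_command()
standing in for the ForwarderCommand plumbing:

    import anyio

    async def stream_generation(command_id, chunks, send_command):
        try:
            async for chunk in chunks:
                yield chunk
        except anyio.get_cancelled_exc_class():
            # The surrounding scope is already cancelled; shielding this
            # block lets the cancellation command reach the forwarder
            # before the exception propagates.
            with anyio.CancelScope(shield=True):
                await send_command(command_id)
            raise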