Compare commits

...

1 Commits

Author SHA1 Message Date
Evan
45d519cde7 workin on it 2026-01-24 02:39:38 +00:00
12 changed files with 158 additions and 80 deletions

View File

@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.

View File

@@ -75,6 +75,7 @@ from exo.shared.types.chunks import (
ToolCallChunk,
)
from exo.shared.types.commands import (
CancelTask,
ChatCompletion,
Command,
CreateInstance,
@@ -501,12 +502,10 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
command = CancelTask(cancelled_command_id=command_id)
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
command = TaskFinished(finished_command_id=command_id)
@@ -894,6 +893,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = CancelTask(cancelled_command_id=command_id)
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -975,6 +978,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = CancelTask(cancelled_command_id=command_id)
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -12,6 +12,7 @@ from exo.master.placement import (
)
from exo.shared.apply import apply
from exo.shared.types.commands import (
CancelTask,
ChatCompletion,
CreateInstance,
DeleteInstance,
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -278,6 +280,15 @@ class Master:
chunk=chunk,
)
)
case CancelTask():
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=self.command_task_mapping[
command.cancelled_command_id
],
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -286,10 +297,7 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
del self.command_task_mapping[command.finished_command_id]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class CancelTask(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -84,6 +88,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| CancelTask
| TaskFinished
| SendInputChunk
)

View File

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -60,6 +61,10 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +92,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -23,7 +23,6 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -90,10 +89,6 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -186,5 +181,3 @@ def mlx_generate(
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -70,8 +70,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -89,30 +87,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -536,3 +510,33 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -229,6 +230,10 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -351,8 +356,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -2,8 +2,11 @@
from collections.abc import Mapping, Sequence
from loguru import logger
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -48,8 +51,8 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]] = {},
input_chunk_counts: Mapping[CommandId, int] = {},
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
@@ -60,6 +63,7 @@ def plan(
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _cancel_tasks(runners, tasks)
)
@@ -270,7 +274,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +288,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +296,33 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
logger.info(f"{task.task_id} is cancelled!")
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
logger.info(f"{task.task_id} is mine!")
if task.task_id in runner.cancelled:
logger.info(f"{task.task_id} already cancelled")
continue
logger.info(f"cancelling {task.task_id}")
CancelTask(instance_id=task.instance_id, cancelled_task_id=task.task_id)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,4 +1,5 @@
import base64
import contextlib
import json
import time
from collections.abc import Generator
@@ -37,6 +38,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -62,7 +64,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.utils.channels import MpReceiver, MpSender, WouldBlock
from exo.worker.engines.image import (
DistributedImageModel,
generate_image,
@@ -77,6 +79,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -85,6 +88,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -99,8 +103,11 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
group = None
@@ -111,6 +118,10 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
with contextlib.suppress(WouldBlock):
cancelled_tasks.add(cancel_receiver.receive_nowait())
if task.task_id in cancelled_tasks:
continue
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -155,7 +166,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -165,7 +176,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -174,8 +185,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -186,11 +195,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -202,8 +211,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -222,7 +231,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert task_params.messages[0].content is not None
@@ -234,14 +243,14 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
if isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
@@ -272,6 +281,12 @@ def main(
)
for response in mlx_generator:
with contextlib.suppress(WouldBlock):
cancelled_tasks.add(cancel_receiver.receive_nowait())
want_to_cancel = task.task_id in cancelled_tasks
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -338,7 +353,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -352,7 +367,9 @@ def main(
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -399,7 +416,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -411,7 +428,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -474,7 +493,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -49,10 +49,12 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -63,8 +65,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -72,6 +74,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -86,6 +89,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -131,6 +135,7 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -140,7 +145,13 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
async def _forward_events(self):
with self._ev_recv as events: