From 5fd080a2469c2c3f88491d127e7640ded423a645 Mon Sep 17 00:00:00 2001
From: Evan
Date: Tue, 23 Dec 2025 16:54:02 +0000
Subject: [PATCH] Split MLX init into group connect and model load; add runner
 event-ordering tests

---
 pyproject.toml                              |   1 +
 src/exo/worker/engines/mlx/constants.py     |   1 -
 src/exo/worker/engines/mlx/utils_mlx.py     |  46 ++--
 src/exo/worker/runner/runner.py             |  19 +-
 src/exo/worker/tests/constants.py           |   6 +
 src/exo/worker/tests/unittests/conftest.py  |  20 ++
 .../test_runner/test_event_ordering.py      | 199 ++++++++++++++++++
 .../test_runner/test_runner_supervisor.py   |   1 +
 uv.lock                                     |  15 ++
 9 files changed, 271 insertions(+), 37 deletions(-)
 create mode 100644 src/exo/worker/tests/unittests/test_runner/test_event_ordering.py
 create mode 100644 src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py

diff --git a/pyproject.toml b/pyproject.toml
index 3c007f09..af919f6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "anyio==4.11.0",
     "bidict>=0.23.1",
     "mlx>=0.29.3",
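+    # mlx needs the separate mlx-cpu backend wheel on linux; the [cpu] extra
+    # pulls it in (see the matching uv.lock change below).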
""" rank = bound_instance.bound_shard.device_rank logger.info(f"Starting initialization for rank {rank}") @@ -154,36 +148,34 @@ def mlx_distributed_init( def initialize_mlx( bound_instance: BoundInstance, -) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]: - """ - Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread. - """ +) -> Group | None: + # should we unseed it? + # TODO: pass in seed from params mx.random.seed(42) - set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard)) + if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1: + return None + return mlx_distributed_init(bound_instance) - sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE) + +def load_mlx_items( + bound_instance: BoundInstance, group: Group | None +) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]: + # TODO: pass temperature + sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7) logger.info("Created a sampler") - if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1: + if group is None: logger.info(f"Single device used for {bound_instance.instance}") model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id) start_time = time.perf_counter() model, _ = load_model(model_path, strict=True) end_time = time.perf_counter() logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s") - if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore - pass - # model, config = quantize_model( - # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE - # ) - tokenizer = get_tokenizer(model_path, bound_instance.bound_shard) else: logger.info("Starting distributed init") - group = mlx_distributed_init(bound_instance) - start_time = time.perf_counter() model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group) end_time = time.perf_counter() @@ -193,8 +185,6 @@ def initialize_mlx( set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard)) - logger.debug(model) - return cast(Model, model), tokenizer, sampler diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index 54e32605..c0b20e40 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -39,6 +39,7 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference from exo.worker.engines.mlx.utils_mlx import ( initialize_mlx, + load_mlx_items, mlx_force_oom, ) from exo.worker.runner.bootstrap import logger @@ -66,9 +67,10 @@ def main( model = None tokenizer = None sampler = None + group = None current_status: RunnerStatus = RunnerIdle() - logger.info("runner waiting for model") + logger.info("runner created") event_sender.send( RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status) ) @@ -81,7 +83,9 @@ def main( ) event_sender.send(TaskAcknowledged(task_id=task.task_id)) match task: - case ConnectToGroup() if isinstance (RunnerIdle, RunnerFailed): + case ConnectToGroup() if isinstance( + current_status, (RunnerIdle, RunnerFailed) + ): logger.info("runner connecting") current_status = RunnerConnecting() event_sender.send( @@ -89,15 +93,12 @@ def main( runner_id=runner_id, runner_status=current_status ) ) - model, tokenizer, sampler = initialize_mlx(bound_instance) + group = initialize_mlx(bound_instance) logger.info("runner 
connected") current_status = RunnerConnected() - - case LoadModel() if isinstance( - current_status, RunnerConnected - ): + case LoadModel() if isinstance(current_status, RunnerConnected): current_status = RunnerLoading() logger.info("runner loading") event_sender.send( @@ -106,7 +107,9 @@ def main( ) ) - model, tokenizer, sampler = initialize_mlx(bound_instance) + model, tokenizer, sampler = load_mlx_items( + bound_instance, group + ) current_status = RunnerLoaded() logger.info("runner loaded") diff --git a/src/exo/worker/tests/constants.py b/src/exo/worker/tests/constants.py index 787f2ff7..6d7fabe7 100644 --- a/src/exo/worker/tests/constants.py +++ b/src/exo/worker/tests/constants.py @@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666") COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777") COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888") + +SHUTDOWN_TASK_ID = TaskId("shutdown") +CHAT_COMPLETION_TASK_ID = TaskId("chat-completion") +INITIALIZATION_TASK_ID = TaskId("initialisation") +LOAD_TASK_ID = TaskId("load") +WARMUP_TASK_ID = TaskId("warmup") diff --git a/src/exo/worker/tests/unittests/conftest.py b/src/exo/worker/tests/unittests/conftest.py index 48fc387a..dde0058b 100644 --- a/src/exo/worker/tests/unittests/conftest.py +++ b/src/exo/worker/tests/unittests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from exo.shared.types.common import NodeId @@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata +# Runner supervisor without multiprocessing logic. @dataclass(frozen=True) class FakeRunnerSupervisor: bound_instance: BoundInstance @@ -35,6 +38,8 @@ def get_pipeline_shard_metadata( pretty_name=str(model_id), storage_size=Memory.from_mb(100000), n_layers=32, + # hidden_size=2048, + # supports_tensor=False, ), device_rank=device_rank, world_size=world_size, @@ -69,3 +74,18 @@ def get_mlx_ring_instance( ), hosts=[], ) + + +def get_bound_mlx_ring_instance( + instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId +) -> BoundInstance: + shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=1) + instance = get_mlx_ring_instance( + instance_id=instance_id, + model_id=model_id, + node_to_runner={node_id: runner_id}, + runner_to_shard={runner_id: shard}, + ) + return BoundInstance( + instance=instance, bound_runner_id=runner_id, bound_node_id=node_id + ) diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py new file mode 100644 index 00000000..e1b78e64 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -0,0 +1,199 @@ +# Check tasks are complete before runner is ever ready. 
+from collections.abc import Iterable
+
+import pytest
+
+import exo.worker.runner.runner as mlx_runner
+from exo.shared.types.api import ChatCompletionMessage
+from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.events import (
+    Event,
+    ChunkGenerated,
+    RunnerStatusUpdated,
+    TaskAcknowledged,
+    TaskStatusUpdated,
+)
+from exo.shared.types.tasks import (
+    ChatCompletion,
+    ChatCompletionTaskParams,
+    ConnectToGroup,
+    LoadModel,
+    Shutdown,
+    StartWarmup,
+    TaskStatus,
+    Task,
+)
+from exo.shared.types.worker.runner_response import GenerationResponse
+from exo.shared.types.worker.runners import (
+    RunnerIdle,
+    RunnerLoaded,
+    RunnerLoading,
+    RunnerReady,
+    RunnerRunning,
+    RunnerShutdown,
+    RunnerConnecting,
+    RunnerConnected,
+    RunnerWarmingUp,
+)
+from exo.utils.channels import mp_channel
+
+from ...constants import (
+    CHAT_COMPLETION_TASK_ID,
+    COMMAND_1_ID,
+    INITIALIZATION_TASK_ID,
+    INSTANCE_1_ID,
+    LOAD_TASK_ID,
+    MODEL_A_ID,
+    NODE_A,
+    RUNNER_1_ID,
+    SHUTDOWN_TASK_ID,
+    WARMUP_TASK_ID,
+)
+from ..conftest import get_bound_mlx_ring_instance
+
+
+INIT_TASK = ConnectToGroup(
+    task_id=INITIALIZATION_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+LOAD_TASK = LoadModel(
+    task_id=LOAD_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+WARMUP_TASK = StartWarmup(
+    task_id=WARMUP_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+SHUTDOWN_TASK = Shutdown(
+    task_id=SHUTDOWN_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+    runner_id=RUNNER_1_ID,
+)
+
+CHAT_PARAMS = ChatCompletionTaskParams(
+    model=str(MODEL_A_ID),
+    messages=[ChatCompletionMessage(role="user", content="hello")],
+    stream=True,
+    max_tokens=4,
+    temperature=0.0,
+)
+
+CHAT_TASK = ChatCompletion(
+    task_id=CHAT_COMPLETION_TASK_ID,
+    command_id=COMMAND_1_ID,
+    task_params=CHAT_PARAMS,
+    instance_id=INSTANCE_1_ID,
+)
+
+
+def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
+    for test_event, true_event in zip(test_events, true_events, strict=True):
+        test_event.event_id = true_event.event_id
+        assert test_event == true_event, f"{test_event} != {true_event}"
+
+
+@pytest.fixture
+def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr(mlx_runner, "initialize_mlx", lambda bound_instance: object())
+    monkeypatch.setattr(
+        mlx_runner,
+        "load_mlx_items",
+        lambda bound_instance, group: (object(), object(), object()),
+    )
+    monkeypatch.setattr(mlx_runner, "warmup_inference", lambda **kwargs: 1)
+    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", lambda *_: None)
+
+    def fake_generate(model, tokenizer, sampler, task):
+        yield GenerationResponse(token=0, text="hi", finish_reason="stop")
+
+    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+
+
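+# Drives mlx_runner.main on the current thread: every task is queued up front,
+# and main() returns once it has processed the Shutdown task.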
+def _run(tasks: Iterable[Task]):
+    bound_instance = get_bound_mlx_ring_instance(
+        instance_id=INSTANCE_1_ID,
+        model_id=MODEL_A_ID,
+        runner_id=RUNNER_1_ID,
+        node_id=NODE_A,
+    )
+
+    task_sender, task_receiver = mp_channel[Task]()
+    event_sender, event_receiver = mp_channel[Event]()
+
+    with task_sender, event_receiver:
+        for t in tasks:
+            task_sender.send(t)
+
+        # Worst monkeypatch known to man: stub out close/join so main()'s
+        # shutdown path cannot tear down the channel ends the test still
+        # needs for event_receiver.collect() below.
+        def nothin() -> None:
+            pass
+
+        event_sender.close = nothin
+        event_sender.join = nothin
+        task_receiver.close = nothin
+        task_receiver.join = nothin
+
+        mlx_runner.main(bound_instance, event_sender, task_receiver)
+
+    return event_receiver.collect()
+
+
+def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
+    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+
+    expected_chunk = ChunkGenerated(
+        command_id=COMMAND_1_ID,
+        chunk=TokenChunk(
+            idx=0,
+            model=MODEL_A_ID,
+            text="hi",
+            token_id=0,
+            finish_reason="stop",
+        ),
+    )
+
+    assert_events_equal(
+        events,
+        [
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
+            RunnerStatusUpdated(
+                runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
+            ),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=LOAD_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=WARMUP_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
+            expected_chunk,
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
+            TaskStatusUpdated(
+                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
+        ],
+    )
diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py
new file mode 100644
index 00000000..e151d4aa
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py
@@ -0,0 +1 @@
+# TODO:
diff --git a/uv.lock b/uv.lock
index 50884363..f9fb2fa1 100644
--- a/uv.lock
+++ b/uv.lock
@@ -334,6 +334,7 @@ dependencies = [
     { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -375,6 +376,7 @@ requires-dist = [
     { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "mlx", specifier = ">=0.29.3" },
+    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.29.3" },
     { name = "mlx-lm", specifier = ">=0.28.3" },
     { name = "networkx", specifier = ">=3.5" },
">=3.5" }, { name = "protobuf", specifier = ">=6.32.0" }, @@ -795,6 +797,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" }, ] +[package.optional-dependencies] +cpu = [ + { name = "mlx-cpu", marker = "sys_platform == 'linux'" }, +] + +[[package]] +name = "mlx-cpu" +version = "0.29.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/ff/474abb13000ca641985084055c145a70c1214973d867979ebfe7420c2df2/mlx_cpu-0.29.3-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:e76763434a9d1d878bb0d6dd965ad319a0a63b0b1d69314e4c97d8332f5e7170", size = 10225301, upload-time = "2025-10-17T19:24:03.544Z" }, +] + [[package]] name = "mlx-lm" version = "0.28.3"