Evan
2025-12-23 16:54:02 +00:00
parent 201c61f9cd
commit 5fd080a246
9 changed files with 271 additions and 37 deletions


@@ -30,6 +30,7 @@ dependencies = [
     "anyio==4.11.0",
     "bidict>=0.23.1",
     "mlx>=0.29.3",
+    "mlx[cpu]>=0.29.3; sys_platform == 'linux'",
     "mlx-lm>=0.28.3",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",


@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
 QUANTIZE_MODEL_MODE: str | None = "affine"
 CACHE_GROUP_SIZE: int = 64
 KV_CACHE_BITS: int | None = 8
-TEMPERATURE: float = 1.0
 # TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
 TRUST_REMOTE_CODE: bool = True


@@ -5,6 +5,7 @@ import time
 from pathlib import Path
 from typing import Any, Callable, cast
+from mlx.core.distributed import Group
 from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
 from mlx_lm.models.deepseek_v3 import DeepseekV3Model
 from mlx_lm.sample_utils import make_sampler
@@ -13,7 +14,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
 from exo.worker.engines.mlx.constants import (
     CACHE_GROUP_SIZE,
     KV_CACHE_BITS,
-    TEMPERATURE,
     TRUST_REMOTE_CODE,
 )
@@ -67,7 +67,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
 )


-def mx_barrier(group: mx.distributed.Group | None = None):
+def mx_barrier(group: Group | None = None):
     mx.eval(
         mx.distributed.all_sum(
             mx.array(1.0),
@@ -77,7 +77,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
 )


-def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
+def broadcast_from_zero(value: int, group: Group | None = None):
     if group is None:
         return value
@@ -99,15 +99,9 @@ class HostList(RootModel[list[str]]):
 def mlx_distributed_init(
     bound_instance: BoundInstance,
-) -> mx.distributed.Group:
+) -> Group:
     """
-    Initialize the MLX distributed (runs in thread pool).
-
-    Either hosts or mlx_ibv_devices must be provided:
-    - hosts: traditional host-based connectivity using MLX_HOSTFILE
-    - mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
-    - mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
-    - strict: if True, raise an error if the distributed backend is not available
+    Initialize MLX distributed.
     """
     rank = bound_instance.bound_shard.device_rank
     logger.info(f"Starting initialization for rank {rank}")
@@ -154,36 +148,34 @@ def mlx_distributed_init(
 def initialize_mlx(
     bound_instance: BoundInstance,
-) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
-    """
-    Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
-    """
+) -> Group | None:
+    # should we unseed it?
+    # TODO: pass in seed from params
     mx.random.seed(42)
-    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
-
-    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
+    if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
+        return None
+    return mlx_distributed_init(bound_instance)
+
+
+def load_mlx_items(
+    bound_instance: BoundInstance, group: Group | None
+) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
+    # TODO: pass temperature
+    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
     logger.info("Created a sampler")
-    if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
+    if group is None:
         logger.info(f"Single device used for {bound_instance.instance}")
         model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
         start_time = time.perf_counter()
         model, _ = load_model(model_path, strict=True)
         end_time = time.perf_counter()
         logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
-        if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
-            pass
-            # model, config = quantize_model(
-            #     model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
-            # )
         tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
     else:
         logger.info("Starting distributed init")
-        group = mlx_distributed_init(bound_instance)
         start_time = time.perf_counter()
         model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
         end_time = time.perf_counter()
@@ -193,8 +185,6 @@ def initialize_mlx(
     set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))

-    logger.debug(model)
-
     return cast(Model, model), tokenizer, sampler


@@ -39,6 +39,7 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
 from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
 from exo.worker.engines.mlx.utils_mlx import (
     initialize_mlx,
+    load_mlx_items,
     mlx_force_oom,
 )
 from exo.worker.runner.bootstrap import logger
@@ -66,9 +67,10 @@ def main(
     model = None
     tokenizer = None
     sampler = None
+    group = None
     current_status: RunnerStatus = RunnerIdle()
-    logger.info("runner waiting for model")
+    logger.info("runner created")
     event_sender.send(
         RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
     )
@@ -81,7 +83,9 @@ def main(
         )
         event_sender.send(TaskAcknowledged(task_id=task.task_id))
         match task:
-            case ConnectToGroup() if isinstance (RunnerIdle, RunnerFailed):
+            case ConnectToGroup() if isinstance(
+                current_status, (RunnerIdle, RunnerFailed)
+            ):
                 logger.info("runner connecting")
                 current_status = RunnerConnecting()
                 event_sender.send(
@@ -89,15 +93,12 @@ def main(
                         runner_id=runner_id, runner_status=current_status
                     )
                 )
-                model, tokenizer, sampler = initialize_mlx(bound_instance)
+                group = initialize_mlx(bound_instance)
                 logger.info("runner connected")
                 current_status = RunnerConnected()
-            case LoadModel() if isinstance(current_status, RunnerConnected):
+            case LoadModel() if isinstance(
+                current_status, RunnerConnected
+            ):
                 current_status = RunnerLoading()
                 logger.info("runner loading")
                 event_sender.send(
@@ -106,7 +107,9 @@ def main(
                     )
                 )
-                model, tokenizer, sampler = initialize_mlx(bound_instance)
+                model, tokenizer, sampler = load_mlx_items(
+                    bound_instance, group
+                )
                 current_status = RunnerLoaded()
                 logger.info("runner loaded")


@@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
 COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
 COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
+
+SHUTDOWN_TASK_ID = TaskId("shutdown")
+CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
+INITIALIZATION_TASK_ID = TaskId("initialisation")
+LOAD_TASK_ID = TaskId("load")
+WARMUP_TASK_ID = TaskId("warmup")


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass

 from exo.shared.types.common import NodeId
@@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
 from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


+# Runner supervisor without multiprocessing logic.
 @dataclass(frozen=True)
 class FakeRunnerSupervisor:
     bound_instance: BoundInstance
@@ -35,6 +38,8 @@ def get_pipeline_shard_metadata(
             pretty_name=str(model_id),
             storage_size=Memory.from_mb(100000),
             n_layers=32,
+            # hidden_size=2048,
+            # supports_tensor=False,
         ),
         device_rank=device_rank,
         world_size=world_size,
@@ -69,3 +74,18 @@ def get_mlx_ring_instance(
         ),
         hosts=[],
     )
+
+
+def get_bound_mlx_ring_instance(
+    instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
+) -> BoundInstance:
+    shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=1)
+    instance = get_mlx_ring_instance(
+        instance_id=instance_id,
+        model_id=model_id,
+        node_to_runner={node_id: runner_id},
+        runner_to_shard={runner_id: shard},
+    )
+    return BoundInstance(
+        instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+    )


@@ -0,0 +1,199 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
    Event,
    ChunkGenerated,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import (
    ChatCompletion,
    ChatCompletionTaskParams,
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    TaskStatus,
    Task,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
    RunnerIdle,
    RunnerLoaded,
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerConnecting,
    RunnerConnected,
    RunnerWarmingUp,
)
from exo.utils.channels import mp_channel

from ...constants import (
    CHAT_COMPLETION_TASK_ID,
    COMMAND_1_ID,
    INITIALIZATION_TASK_ID,
    INSTANCE_1_ID,
    LOAD_TASK_ID,
    MODEL_A_ID,
    NODE_A,
    RUNNER_1_ID,
    SHUTDOWN_TASK_ID,
    WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance

INIT_TASK = ConnectToGroup(
    task_id=INITIALIZATION_TASK_ID,
    instance_id=INSTANCE_1_ID,
)
LOAD_TASK = LoadModel(
    task_id=LOAD_TASK_ID,
    instance_id=INSTANCE_1_ID,
)
WARMUP_TASK = StartWarmup(
    task_id=WARMUP_TASK_ID,
    instance_id=INSTANCE_1_ID,
)
SHUTDOWN_TASK = Shutdown(
    task_id=SHUTDOWN_TASK_ID,
    instance_id=INSTANCE_1_ID,
    runner_id=RUNNER_1_ID,
)

CHAT_PARAMS = ChatCompletionTaskParams(
    model=str(MODEL_A_ID),
    messages=[ChatCompletionMessage(role="user", content="hello")],
    stream=True,
    max_tokens=4,
    temperature=0.0,
)
CHAT_TASK = ChatCompletion(
    task_id=CHAT_COMPLETION_TASK_ID,
    command_id=COMMAND_1_ID,
    task_params=CHAT_PARAMS,
    instance_id=INSTANCE_1_ID,
)


def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
    for test_event, true_event in zip(test_events, true_events, strict=True):
        test_event.event_id = true_event.event_id
        assert test_event == true_event, f"{test_event} != {true_event}"


@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "initialize_mlx", lambda bound_instance: object())
    monkeypatch.setattr(
        mlx_runner,
        "load_mlx_items",
        lambda bound_instance, group: (object(), object(), object()),
    )
    monkeypatch.setattr(mlx_runner, "warmup_inference", lambda **kwargs: 1)
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", lambda *_: None)

    def fake_generate(model, tokenizer, sampler, task):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop")

    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)


def _run(tasks: Iterable[Task]):
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        runner_id=RUNNER_1_ID,
        node_id=NODE_A,
    )
    task_sender, task_receiver = mp_channel[Task]()
    event_sender, event_receiver = mp_channel[Event]()
    with task_sender, event_receiver:
        for t in tasks:
            task_sender.send(t)

        # worst monkeypatch known to man
        def nothin() -> None: pass

        event_sender.close = nothin
        event_sender.join = nothin
        task_receiver.close = nothin
        task_receiver.join = nothin

        mlx_runner.main(bound_instance, event_sender, task_receiver)

    return event_receiver.collect()


def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

    expected_chunk = ChunkGenerated(
        command_id=COMMAND_1_ID,
        chunk=TokenChunk(
            idx=0,
            model=MODEL_A_ID,
            text="hi",
            token_id=0,
            finish_reason="stop",
        ),
    )

    assert_events_equal(
        events,
        [
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
            TaskStatusUpdated(
                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
            ),
            TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
            RunnerStatusUpdated(
                runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
            ),
            TaskStatusUpdated(
                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
            ),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
            TaskAcknowledged(task_id=LOAD_TASK_ID),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
            TaskAcknowledged(task_id=WARMUP_TASK_ID),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
            TaskStatusUpdated(
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
            ),
            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
            expected_chunk,
            TaskStatusUpdated(
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
            ),
            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
            TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
            TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
            TaskStatusUpdated(
                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
            ),
            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
        ],
    )


@@ -0,0 +1 @@
# TODO:

uv.lock (generated)

@@ -334,6 +334,7 @@ dependencies = [
     { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -375,6 +376,7 @@ requires-dist = [
     { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "mlx", specifier = ">=0.29.3" },
+    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.29.3" },
     { name = "mlx-lm", specifier = ">=0.28.3" },
     { name = "networkx", specifier = ">=3.5" },
     { name = "protobuf", specifier = ">=6.32.0" },
@@ -795,6 +797,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" },
 ]
+
+[package.optional-dependencies]
+cpu = [
+    { name = "mlx-cpu", marker = "sys_platform == 'linux'" },
+]
+
+[[package]]
+name = "mlx-cpu"
+version = "0.29.3"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/6d/ff/474abb13000ca641985084055c145a70c1214973d867979ebfe7420c2df2/mlx_cpu-0.29.3-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:e76763434a9d1d878bb0d6dd965ad319a0a63b0b1d69314e4c97d8332f5e7170", size = 10225301, upload-time = "2025-10-17T19:24:03.544Z" },
+]
+
 [[package]]
 name = "mlx-lm"
 version = "0.28.3"