3 Commits

Author SHA1 Message Date
Matiwos Kebede
eabdcab978 Fix linux docs (#1022)
This PR updates the "Run from Source (Mac & Linux)" section in README.md
to clarify Linux instructions.

Changes include:
- Split the section into macOS and Linux subsections.
- Added native Linux package manager commands (apt, dnf, pacman) for
dependencies: uv, node, npm.
- Clarified that macmon is macOS-only.
- Noted that Homebrew on Linux is optional, with native package managers
preferred.

These changes improve clarity for Linux users and fix confusion from the
previous macOS-centric instructions.
2025-12-27 19:56:44 +00:00
Evan Quiney
8e9332d6a7 Separate out the Runner's behaviour into a "connect" phase and a "load" phase (#1006)
## Motivation

We should ensure all runners are connected before loading the model.
This gives the worker's planning mechanism finer-grained control over
runner state in the future.

## Changes

- Introduced task ConnectToGroup, preceding LoadModel
- Introduced runner statuses Idle, Connecting, Connected
- Separated out initialize_mlx from shard_and_load
- Single instances never go through the connecting phase
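
The two-phase flow listed above can be sketched roughly as follows. This is a simplified illustration, not the planner from this PR: `Status`, `next_task`, and the string task names are hypothetical stand-ins, and the real planner also checks download state and rank ordering, which are omitted here.

```python
from enum import Enum, auto
from typing import Iterable, Optional


class Status(Enum):
    # Hypothetical stand-ins for the runner statuses named in this PR.
    IDLE = auto()
    CONNECTING = auto()
    CONNECTED = auto()
    LOADING = auto()
    LOADED = auto()


def next_task(
    own: Status, peers: Iterable[Status], single_instance: bool
) -> Optional[str]:
    """Return the name of the next task for a runner, or None to wait."""
    if single_instance:
        # Single instances skip the connect phase and load straight from idle.
        return "LoadModel" if own is Status.IDLE else None
    if own is Status.IDLE:
        # Multi-runner instances must join the group before loading.
        return "ConnectToGroup"
    ready = {Status.CONNECTED, Status.LOADING, Status.LOADED}
    if own is Status.CONNECTED and all(p in ready for p in peers):
        # Load only once every peer has finished connecting.
        return "LoadModel"
    return None
```

For example, `next_task(Status.CONNECTED, [Status.CONNECTING], False)` returns `None`, because one peer has not finished connecting yet.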

## Test Plan

### Automated Testing
Added a test for checking event ordering in a standard workflow.

### Manual Testing
Tested that Llama 3.2 1B and Kimi K2 Thinking load and shut down
repeatedly on multiple configurations. This testing was not exhaustive.

---------

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2025-12-27 16:28:42 +00:00
Heath Dutton🕴️
4b65d5f896 Fix race condition in mlx_distributed_init with concurrent instances (#1012)
## Motivation

Fixes #1005

When multiple instances initialize concurrently with the same rank, they
overwrite each other's coordination files (hosts_{rank}.json), causing
"[jaccl] Malformed device file" errors and initialization failures.

## Changes

- Changed coordination filename from `./hosts_{rank}.json` to
`./hosts_{instance_id}_{rank}.json` to make it unique per instance
- Added cleanup in a finally block to remove coordination files after
initialization completes
- Applied fix to both MlxRingInstance and MlxJacclInstance cases

## Why It Works

Each instance now gets a unique coordination file based on its
instance_id, preventing concurrent instances from overwriting each
other's files. The cleanup logic ensures files are removed after use,
preventing accumulation and handling both success and failure cases.
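
As a minimal sketch of the pattern described above (not the code from this PR): the file name includes the instance ID, and a `finally` block removes the file whether or not initialization succeeds. The helpers `coordination_path` and `init_with_coordination_file`, the `directory` parameter, and the `init` callback are assumptions made for illustration.

```python
import contextlib
import json
import os
from typing import Callable, List, TypeVar

T = TypeVar("T")


def coordination_path(directory: str, instance_id: str, rank: int) -> str:
    # Including instance_id keeps concurrent instances with the same rank
    # from writing to the same file.
    return os.path.join(directory, f"hosts_{instance_id}_{rank}.json")


def init_with_coordination_file(
    directory: str,
    instance_id: str,
    rank: int,
    hosts: List[str],
    init: Callable[[str], T],
) -> T:
    """Write the coordination file, run init(path), then always clean up."""
    path = coordination_path(directory, instance_id, rank)
    try:
        with open(path, "w") as f:
            json.dump(hosts, f)
        # `init` stands in for the real distributed initialization call.
        return init(path)
    finally:
        # Runs on success and on failure; a missing file is not an error.
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)
```

Two instances that share rank 0 get distinct paths, so neither overwrites the other's file, and no file is left behind once `init` returns or raises.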

## Test Plan

### Manual Testing
Code review and logic verification. The fix prevents the race condition
by ensuring filename uniqueness per instance.

### Automated Testing
No new tests added. Existing tests continue to pass.

---------

Co-authored-by: Ryuichi Leo Takashige <rl.takashige@gmail.com>
2025-12-27 16:13:26 +00:00
21 changed files with 569 additions and 153 deletions

2
.gitignore vendored
View File

@@ -7,6 +7,8 @@ digest.txt
# nix
.direnv/
# IDEA (PyCharm)
.idea
# xcode / macos
*.xcuserstate

View File

@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
There are two ways to run exo:
### Run from Source (Mac & Linux)
### Run from Source (macOS)
**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,6 +98,62 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
### Run from Source (Linux)
**Prerequisites:**
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
**Installation methods:**
**Option 1: Using system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# Install dependencies
brew install uv node
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Note:** The `macmon` package is macOS-only and not required for Linux.
Clone the repo, build the dashboard, and run exo:
```bash
# Clone exo
git clone https://github.com/exo-explore/exo
# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..
# Run exo
uv run exo
```
This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
### macOS App
exo ships a macOS app that runs in the background on your Mac.

View File

@@ -29,7 +29,8 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)
assert new_state == State(downloads={NodeId("node-1"): [event]})
assert new_state.downloads == {NodeId("node-1"): [event]}
def test_apply_two_node_download_progress():
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -40,6 +40,10 @@ class LoadModel(BaseTask): # emitted by Worker
pass
class ConnectToGroup(BaseTask): # emitted by Worker
pass
class StartWarmup(BaseTask): # emitted by Worker
pass
@@ -57,5 +61,11 @@ class Shutdown(BaseTask): # emitted by Worker
Task = (
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
CreateRunner
| DownloadModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
)

View File

@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
return isinstance(self, RunnerRunning)
class RunnerWaitingForModel(BaseRunnerStatus):
class RunnerIdle(BaseRunnerStatus):
pass
class RunnerConnecting(BaseRunnerStatus):
pass
class RunnerConnected(BaseRunnerStatus):
pass
@@ -54,7 +62,9 @@ class RunnerFailed(BaseRunnerStatus):
RunnerStatus = (
RunnerWaitingForModel
RunnerIdle
| RunnerConnecting
| RunnerConnected
| RunnerLoading
| RunnerLoaded
| RunnerWarmingUp

View File

@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TEMPERATURE,
TRUST_REMOTE_CODE,
)
@@ -21,6 +20,8 @@ try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
import contextlib
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
def mx_barrier(group: mx.distributed.Group | None = None):
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
)
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
@@ -99,91 +101,96 @@ class HostList(RootModel[list[str]]):
def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
) -> Group:
"""
Initialize the MLX distributed (runs in thread pool).
Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
Initialize MLX distributed.
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()
coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_json = HostList.from_hosts(hosts).model_dump_json()
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
with open(coordination_file, "w") as f:
_ = f.write(hosts_json)
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
logger.info(
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
)
os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
ibv_devices_json = json.dumps(ibv_devices)
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
with open(coordination_file, "w") as f:
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(
bound_instance: BoundInstance,
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
"""
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
"""
) -> Group:
# should we unseed it?
# TODO: pass in seed from params
mx.random.seed(42)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
"Tried to initialize mlx for a single node instance"
)
return mlx_distributed_init(bound_instance)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
logger.info("Created a sampler")
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
@@ -193,14 +200,12 @@ def initialize_mlx(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
logger.debug(model)
return cast(Model, model), tokenizer, sampler
def shard_and_load(
shard_metadata: ShardMetadata,
group: mx.distributed.Group,
group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)

View File

@@ -228,7 +228,7 @@ class Worker:
)
)
else:
self.event_sender.send_nowait(
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
@@ -414,7 +414,7 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.node_id, self.state.topology)
conns = await check_reachable(self.state.topology)
for nid in conns:
for ip in conns[nid]:
edge = Connection(

View File

@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
LoadModel,
@@ -14,17 +15,23 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerId,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
@@ -48,6 +55,7 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
@@ -106,9 +114,11 @@ def _model_needs_download(
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard not in download_status
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
@@ -117,14 +127,54 @@ def _model_needs_download(
)
""" --- TODO!
def _init_backend(
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> LoadModel | None:
for runner in runner.values()
pass
"""
):
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
if is_single_node_instance:
continue
runner_is_idle = isinstance(runner.status, RunnerIdle)
all_runners_connecting = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerConnecting, RunnerIdle),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if not (runner_is_idle and all_runners_connecting):
continue
runner_id = runner.bound_instance.bound_runner_id
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
world_size = shard.world_size
assert device_rank < world_size
assert device_rank >= 0
accepting_ranks = device_rank < world_size - 1
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)
if not (accepting_ranks or connecting_rank_ready):
continue
return ConnectToGroup(instance_id=instance.instance_id)
return None
def _load_model(
@@ -136,31 +186,33 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
all_downloads_complete_local = all(
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid, rid in shard_assignments.node_to_runner.items()
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
all_runners_expecting_model = all(
is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
):
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)
return None
@@ -183,8 +235,9 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# Rank != n-1
accepting_ranks_ready = device_rank != world_size - 1 and all(
# TODO: Ensure these align with MLX distributeds expectations.
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -192,8 +245,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
@@ -221,6 +274,8 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard

View File

@@ -11,6 +11,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
@@ -22,20 +23,23 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@@ -63,9 +67,10 @@ def main(
model = None
tokenizer = None
sampler = None
group = None
current_status: RunnerStatus = RunnerWaitingForModel()
logger.info("runner waiting for model")
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
@@ -78,9 +83,26 @@ def main(
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(
current_status, (RunnerWaitingForModel, RunnerFailed)
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
@@ -89,15 +111,12 @@ def main(
)
)
model, tokenizer, sampler = initialize_mlx(bound_instance)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
@@ -123,11 +142,6 @@ def main(
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
@@ -172,11 +186,6 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case Shutdown():
logger.info("runner shutting down")
event_sender.send(
@@ -186,12 +195,19 @@ def main(
)
break
case _:
raise ValueError("Received task outside of state machine")
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)

View File

@@ -19,8 +19,8 @@ from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerStatus,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -41,7 +41,7 @@ class RunnerSupervisor:
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
@classmethod

View File

@@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
SHUTDOWN_TASK_ID = TaskId("shutdown")
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
INITIALIZATION_TASK_ID = TaskId("initialisation")
LOAD_TASK_ID = TaskId("load")
WARMUP_TASK_ID = TaskId("warmup")

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from exo.shared.types.common import NodeId
@@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
@@ -35,6 +38,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
# hidden_size=2048,
# supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -69,3 +74,24 @@ def get_mlx_ring_instance(
),
hosts=[],
)
def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
other_shard = get_pipeline_shard_metadata(
model_id=model_id, device_rank=1, world_size=2
)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={
node_id: runner_id,
NodeId("other_node"): RunnerId("other_runner"),
},
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
)

View File

@@ -4,7 +4,8 @@ from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerWaitingForModel,
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
@@ -38,13 +39,11 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
@@ -82,15 +81,15 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
@@ -133,13 +132,11 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Local status claims the shard is downloaded already
local_download_status = {
@@ -183,14 +180,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
@@ -213,3 +210,22 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
], # NODE_B has no downloads completed yet
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is not None

View File

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerIdle(),
}
task = ChatCompletion(

View File

@@ -2,8 +2,8 @@ import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoaded,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
@@ -128,7 +128,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerIdle(),
RUNNER_2_ID: RunnerLoaded(),
}

View File

@@ -0,0 +1,208 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable
from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from ...constants import (
CHAT_COMPLETION_TASK_ID,
COMMAND_1_ID,
INITIALIZATION_TASK_ID,
INSTANCE_1_ID,
LOAD_TASK_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
SHUTDOWN_TASK_ID,
WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def make_nothin[T, U, V](res: T) -> Callable[[], T]:
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
nothin = make_nothin(None)
INIT_TASK = ConnectToGroup(
task_id=INITIALIZATION_TASK_ID,
instance_id=INSTANCE_1_ID,
)
LOAD_TASK = LoadModel(
task_id=LOAD_TASK_ID,
instance_id=INSTANCE_1_ID,
)
WARMUP_TASK = StartWarmup(
task_id=WARMUP_TASK_ID,
instance_id=INSTANCE_1_ID,
)
SHUTDOWN_TASK = Shutdown(
task_id=SHUTDOWN_TASK_ID,
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=4,
temperature=0.0,
)
CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
instance_id=INSTANCE_1_ID,
)
def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
for test_event, true_event in zip(test_events, true_events, strict=True):
test_event.event_id = true_event.event_id
assert test_event == true_event, f"{test_event} != {true_event}"
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
finish_reason="stop",
),
)
assert_events_equal(
events,
[
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)
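
The expected event list above implies a runner lifecycle in which each task is only valid from one status, passes through an in-progress status, and lands on a resulting status. The sketch below models that ordering as a small transition table. It is illustrative only: `Status`, `TRANSITIONS`, and `walk` are hypothetical names, not part of exo, and it covers only the multi-node path exercised by this test (the PR description notes that single instances skip the connecting phase).

```python
from enum import Enum, auto


class Status(Enum):
    # Hypothetical stand-ins for the Runner* status classes used in the test.
    IDLE = auto()
    CONNECTING = auto()
    CONNECTED = auto()
    LOADING = auto()
    LOADED = auto()
    WARMING_UP = auto()
    READY = auto()
    RUNNING = auto()
    SHUTDOWN = auto()


# task name -> (status required before, status while running, status after)
TRANSITIONS = {
    "ConnectToGroup": (Status.IDLE, Status.CONNECTING, Status.CONNECTED),
    "LoadModel": (Status.CONNECTED, Status.LOADING, Status.LOADED),
    "StartWarmup": (Status.LOADED, Status.WARMING_UP, Status.READY),
    "ChatCompletion": (Status.READY, Status.RUNNING, Status.READY),
}


def walk(tasks: list[str]) -> list[Status]:
    """Return every status a runner passes through for the given tasks."""
    status = Status.IDLE
    seen = [status]
    for task in tasks:
        if task == "Shutdown":
            # Shutdown has no in-progress status in the expected events.
            seen.append(Status.SHUTDOWN)
            break
        required, during, after = TRANSITIONS[task]
        if status is not required:
            raise ValueError(f"{task} requires {required.name}, got {status.name}")
        seen.extend([during, after])
        status = after
    return seen
```

Walking the same five tasks as `test_events_processed_in_correct_order` yields the status order asserted there, and asking for `LoadModel` before `ConnectToGroup` raises `ValueError` in this sketch.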


@@ -0,0 +1 @@
# TODO:


@@ -1,5 +1,4 @@
import socket
from ipaddress import ip_address
from anyio import create_task_group, to_thread
@@ -28,23 +27,13 @@ async def check_reachability(
out[target_node_id].add(target_ip)
async def check_reachable(our_node_id: NodeId, topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = topology.get_node_profile(our_node_id)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node in topology.list_nodes():
if node.node_id == our_node_id or node.node_profile is None:
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
)

uv.lock generated

@@ -334,6 +334,7 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -374,7 +375,8 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx", marker = "sys_platform != 'linux'", specifier = ">=0.30.1" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "protobuf", specifier = ">=6.32.0" },
@@ -801,6 +803,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[package.optional-dependencies]
cpu = [
{ name = "mlx-cpu", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx-cpu"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
{ url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
]
[[package]]
name = "mlx-lm"
version = "0.28.3"