mirror of https://github.com/exo-explore/exo.git
synced 2025-12-28 00:29:48 -05:00

Compare commits: main...dont-disco (1 commit)
Commit eb16352c85
.gitignore (vendored): 2 changes
@@ -7,8 +7,6 @@ digest.txt
# nix
.direnv/

# IDEA (PyCharm)
.idea

# xcode / macos
*.xcuserstate

README.md: 60 changes
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua

There are two ways to run exo:

### Run from Source (macOS)
### Run from Source (Mac & Linux)

**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)

```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,62 +98,6 @@ uv run exo

This starts the exo dashboard and API at http://localhost:52415/

### Run from Source (Linux)

**Prerequisites:**

- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)

**Installation methods:**

**Option 1: Using system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm

# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

# Install dependencies
brew install uv node

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Note:** The `macmon` package is macOS-only and not required for Linux.

Clone the repo, build the dashboard, and run exo:

```bash
# Clone exo
git clone https://github.com/exo-explore/exo

# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..

# Run exo
uv run exo
```

This starts the exo dashboard and API at http://localhost:52415/

**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.

### macOS App

exo ships a macOS app that runs in the background on your Mac.

@@ -29,8 +29,7 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx>=0.30.1",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)

assert new_state.downloads == {NodeId("node-1"): [event]}
assert new_state == State(downloads={NodeId("node-1"): [event]})


def test_apply_two_node_download_progress():
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})

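Comparing the whole State instead of one field means an unintended change anywhere else in the state also fails the test. A tiny self-contained illustration of the difference (stand-in dataclass, not exo's State):

```python
from dataclasses import dataclass, field


@dataclass
class State:
    downloads: dict[str, list[str]] = field(default_factory=dict)
    failures: int = 0


new_state = State(downloads={"node-1": ["event"]}, failures=1)

# Field-level assert passes even though `failures` changed unexpectedly:
assert new_state.downloads == {"node-1": ["event"]}

# Whole-object equality catches the stray change:
assert new_state != State(downloads={"node-1": ["event"]})
```
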
@@ -40,10 +40,6 @@ class LoadModel(BaseTask): # emitted by Worker
pass


class ConnectToGroup(BaseTask): # emitted by Worker
pass


class StartWarmup(BaseTask): # emitted by Worker
pass

@@ -61,11 +57,5 @@ class Shutdown(BaseTask): # emitted by Worker


Task = (
CreateRunner
| DownloadModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
)

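With ConnectToGroup removed, handling a task collapses into a single exhaustive match over the slimmer union. A minimal sketch of the pattern with stand-in dataclasses (not the real exo task types):

```python
from dataclasses import dataclass


@dataclass
class CreateRunner: ...
@dataclass
class DownloadModel: ...
@dataclass
class LoadModel: ...
@dataclass
class StartWarmup: ...
@dataclass
class ChatCompletion: ...
@dataclass
class Shutdown: ...


# Mirrors the new union: connecting to the group is no longer a
# separate task; it happens inside LoadModel handling.
Task = CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown


def describe(task: Task) -> str:
    match task:
        case LoadModel():
            return "init distributed group (if any) and load weights"
        case StartWarmup():
            return "run warmup inference"
        case ChatCompletion():
            return "generate tokens"
        case Shutdown():
            return "exit the runner loop"
        case CreateRunner() | DownloadModel():
            return "handled by the worker, not the runner"


print(describe(LoadModel()))  # init distributed group (if any) and load weights
```
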
@@ -21,15 +21,7 @@ class BaseRunnerStatus(TaggedModel):
return isinstance(self, RunnerRunning)


class RunnerIdle(BaseRunnerStatus):
pass


class RunnerConnecting(BaseRunnerStatus):
pass


class RunnerConnected(BaseRunnerStatus):
class RunnerWaitingForModel(BaseRunnerStatus):
pass


@@ -62,9 +54,7 @@ class RunnerFailed(BaseRunnerStatus):


RunnerStatus = (
RunnerIdle
| RunnerConnecting
| RunnerConnected
RunnerWaitingForModel
| RunnerLoading
| RunnerLoaded
| RunnerWarmingUp

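Read top to bottom, the trimmed union is a linear lifecycle (the Ready/Running/Failed/Shutdown states sit outside this hunk). A hedged sketch of the progression as inferred from this diff, not from exo docs:

```python
# Inferred runner lifecycle after this change; RunnerIdle, RunnerConnecting
# and RunnerConnected are gone, and RunnerWaitingForModel is the new
# initial state.
LIFECYCLE = [
    "RunnerWaitingForModel",  # initial: model not yet downloaded
    "RunnerLoading",          # LoadModel task accepted
    "RunnerLoaded",           # weights in memory
    "RunnerWarmingUp",        # StartWarmup task accepted
    "RunnerReady",            # can serve ChatCompletion
    "RunnerRunning",          # generating tokens
]


def next_status(current: str) -> str | None:
    """Return the next state on the happy path, or None at the end."""
    i = LIFECYCLE.index(current)
    return LIFECYCLE[i + 1] if i + 1 < len(LIFECYCLE) else None


assert next_status("RunnerLoaded") == "RunnerWarmingUp"
```
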
@@ -10,6 +10,7 @@ KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

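TEMPERATURE divides the logits before sampling, so 1.0 keeps the model's distribution unchanged while lower values sharpen it. A self-contained sketch of temperature sampling in plain Python (illustrative only, not the mlx_lm sampler):

```python
import math
import random


def sample_with_temperature(logits: list[float], temperature: float = 1.0) -> int:
    """Sample a token index from logits softened by `temperature`.

    temperature=1.0 keeps the model's distribution; lower values sharpen it.
    """
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return random.choices(range(len(probs)), weights=probs, k=1)[0]


print(sample_with_temperature([2.0, 1.0, 0.1], temperature=1.0))
```
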
@@ -13,6 +13,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TEMPERATURE,
TRUST_REMOTE_CODE,
)

@@ -20,8 +21,6 @@ try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
import contextlib

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -49,7 +48,6 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))

@@ -69,7 +67,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)


def mx_barrier(group: Group | None = None):
def mx_barrier(group: mx.distributed.Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
@@ -79,7 +77,7 @@ def mx_barrier(group: Group | None = None):
)


def broadcast_from_zero(value: int, group: Group | None = None):
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
if group is None:
return value

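Both helpers lean on mx.distributed.all_sum: evaluating an all_sum of a scalar forces every rank to rendezvous (a barrier), and summing a value that only rank 0 contributes acts as a broadcast. A hedged sketch of the broadcast trick, assuming an already initialized group (the exo implementation may differ in detail):

```python
import mlx.core as mx


def broadcast_from_zero_sketch(value: int, group: mx.distributed.Group) -> int:
    # Every rank contributes 0 except rank 0, so the all-reduce sum
    # equals rank 0's value on all ranks.
    contribution = mx.array(float(value) if group.rank() == 0 else 0.0)
    total = mx.distributed.all_sum(contribution, group=group)
    mx.eval(total)  # force the collective to actually run
    return int(total.item())
```
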
@@ -101,96 +99,91 @@ class HostList(RootModel[list[str]]):

def mlx_distributed_init(
bound_instance: BoundInstance,
) -> Group:
) -> mx.distributed.Group:
"""
Initialize MLX distributed.
Initialize the MLX distributed (runs in thread pool).

Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")

coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_json = HostList.from_hosts(hosts).model_dump_json()
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()

with open(coordination_file, "w") as f:
_ = f.write(hosts_json)
with open(hostfile, "w") as f:
_ = f.write(hosts_json)

logger.info(
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
)
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")

os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)

case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
ibv_devices_json = json.dumps(ibv_devices)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)

with open(coordination_file, "w") as f:
_ = f.write(ibv_devices_json)
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)

jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"Rank {rank} mlx distributed initialization complete")

return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
return group


def initialize_mlx(
bound_instance: BoundInstance,
) -> Group:
# should we unseed it?
# TODO: pass in seed from params
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
"""
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
"""
mx.random.seed(42)

assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
"Tried to initialize mlx for a single node instance"
)
return mlx_distributed_init(bound_instance)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))


def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
logger.info("Created a sampler")

if group is None:
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )

tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)

else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)

start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
@@ -200,12 +193,14 @@ def load_mlx_items(

set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))

logger.debug(model)

return cast(Model, model), tokenizer, sampler


def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
group: mx.distributed.Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)

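The ring path above reduces to a small handshake: write the hosts JSON to a file, point MLX at it through environment variables, then initialize the backend. A minimal sketch with a hypothetical hosts list, mirroring the env vars used in the diff:

```python
import json
import os

import mlx.core as mx


def init_ring_sketch(hosts: list[str], rank: int) -> mx.distributed.Group:
    hostfile = f"./hosts_{rank}.json"
    with open(hostfile, "w") as f:
        json.dump(hosts, f)  # e.g. ["192.168.1.10:5000", "192.168.1.11:5000"]

    # MLX's ring backend reads its configuration from the environment.
    os.environ["MLX_HOSTFILE"] = hostfile
    os.environ["MLX_RANK"] = str(rank)
    return mx.distributed.init(backend="ring", strict=True)
```
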
@@ -228,7 +228,7 @@ class Worker:
)
)
else:
await self.event_sender.send(
self.event_sender.send_nowait(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
@@ -414,7 +414,7 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology)
conns = await check_reachable(self.node_id, self.state.topology)
for nid in conns:
for ip in conns[nid]:
edge = Connection(

@@ -5,7 +5,6 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
LoadModel,
@@ -15,23 +14,17 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerId,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata

@@ -55,7 +48,6 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
@@ -114,11 +106,9 @@ def _model_needs_download(
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard not in download_status
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
@@ -127,54 +117,14 @@
)


def _init_distributed_backend(
""" --- TODO!
def _init_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
):
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
if is_single_node_instance:
continue

runner_is_idle = isinstance(runner.status, RunnerIdle)
all_runners_connecting = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerConnecting, RunnerIdle),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if not (runner_is_idle and all_runners_connecting):
continue

runner_id = runner.bound_instance.bound_runner_id

shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
world_size = shard.world_size

assert device_rank < world_size
assert device_rank >= 0

accepting_ranks = device_rank < world_size - 1

# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)

if not (accepting_ranks or connecting_rank_ready):
continue

return ConnectToGroup(instance_id=instance.instance_id)

return None
) -> LoadModel | None:
for runner in runner.values()
pass
"""


def _load_model(
@@ -186,33 +136,31 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

all_local_downloads_complete = all(
all_downloads_complete_local = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
for nid, rid in shard_assignments.node_to_runner.items()
)
if not all_local_downloads_complete:
continue

is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)

is_runner_waiting = isinstance(runner.status, RunnerConnected)

all_ready_for_model = all(
all_runners_expecting_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if is_runner_waiting and all_ready_for_model:
if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
):
return LoadModel(instance_id=instance.instance_id)

return None

|
||||
assert device_rank < world_size
|
||||
assert device_rank >= 0
|
||||
|
||||
# TODO: Ensure these align with MLX distributeds expectations.
|
||||
# Rank != 0
|
||||
accepting_ranks_ready = device_rank > 0 and all(
|
||||
# Rank != n-1
|
||||
accepting_ranks_ready = device_rank != world_size - 1 and all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerLoaded, RunnerWarmingUp),
|
||||
@@ -245,8 +192,8 @@ def _ready_to_warmup(
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
# Rank = 0
|
||||
connecting_rank_ready = device_rank == 0 and all(
|
||||
# Rank = n-1
|
||||
connecting_rank_ready = device_rank == world_size - 1 and all(
|
||||
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
if global_runner_id != runner_id
|
||||
@@ -274,8 +221,6 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
|
||||
@@ -11,7 +11,6 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -23,23 +22,20 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -67,10 +63,9 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
sampler = None
|
||||
group = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
current_status: RunnerStatus = RunnerWaitingForModel()
|
||||
logger.info("runner waiting for model")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
@@ -83,26 +78,9 @@ def main(
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
case LoadModel() if isinstance(
|
||||
current_status, (RunnerWaitingForModel, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected)
|
||||
and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
@@ -111,12 +89,15 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
model, tokenizer, sampler = initialize_mlx(bound_instance)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
@@ -142,6 +123,11 @@ def main(
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
)
|
||||
)
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
@@ -186,6 +172,11 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
)
|
||||
)
|
||||
case Shutdown():
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
@@ -195,19 +186,12 @@ def main(
|
||||
)
|
||||
break
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
raise ValueError("Received task outside of state machine")
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
|
||||
@@ -19,8 +19,8 @@ from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerStatus,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -41,7 +41,7 @@ class RunnerSupervisor:
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)

@classmethod

@@ -24,9 +24,3 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")

COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")

SHUTDOWN_TASK_ID = TaskId("shutdown")
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
INITIALIZATION_TASK_ID = TaskId("initialisation")
LOAD_TASK_ID = TaskId("load")
WARMUP_TASK_ID = TaskId("warmup")

@@ -1,5 +1,3 @@
from __future__ import annotations

from dataclasses import dataclass

from exo.shared.types.common import NodeId
@@ -16,7 +14,6 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
@@ -38,8 +35,6 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
# hidden_size=2048,
# supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -74,24 +69,3 @@ def get_mlx_ring_instance(
),
hosts=[],
)


def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
other_shard = get_pipeline_shard_metadata(
model_id=model_id, device_rank=1, world_size=2
)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={
node_id: runner_id,
NodeId("other_node"): RunnerId("other_runner"),
},
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
)

@@ -4,8 +4,7 @@ from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
@@ -39,11 +38,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)

runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}

# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
@@ -81,15 +82,15 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
bound_instance=bound_instance, status=RunnerWaitingForModel()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}

all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}

# Local node has already marked its shard as downloaded (not actually used by _load_model)
@@ -132,11 +133,13 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)

runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}

# Local status claims the shard is downloaded already
local_download_status = {
@@ -180,14 +183,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
bound_instance=bound_instance, status=RunnerWaitingForModel()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}

# Only NODE_A's shard is recorded as downloaded globally
@@ -210,22 +213,3 @@
)

assert result is None

global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
], # NODE_B has no downloads completed yet
}

result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)

assert result is not None

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerIdle(),
RUNNER_2_ID: RunnerWaitingForModel(),
}

task = ChatCompletion(

@@ -2,8 +2,8 @@ import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoaded,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
@@ -128,7 +128,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerIdle(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}

@@ -1,208 +0,0 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable
from typing import Callable

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel

from ...constants import (
CHAT_COMPLETION_TASK_ID,
COMMAND_1_ID,
INITIALIZATION_TASK_ID,
INSTANCE_1_ID,
LOAD_TASK_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
SHUTDOWN_TASK_ID,
WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance


def make_nothin[T, U, V](res: T) -> Callable[[], T]:
def nothin(*_1: U, **_2: V) -> T:
return res

return nothin


nothin = make_nothin(None)


INIT_TASK = ConnectToGroup(
task_id=INITIALIZATION_TASK_ID,
instance_id=INSTANCE_1_ID,
)

LOAD_TASK = LoadModel(
task_id=LOAD_TASK_ID,
instance_id=INSTANCE_1_ID,
)

WARMUP_TASK = StartWarmup(
task_id=WARMUP_TASK_ID,
instance_id=INSTANCE_1_ID,
)

SHUTDOWN_TASK = Shutdown(
task_id=SHUTDOWN_TASK_ID,
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)

CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=4,
temperature=0.0,
)

CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
instance_id=INSTANCE_1_ID,
)


def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
for test_event, true_event in zip(test_events, true_events, strict=True):
test_event.event_id = true_event.event_id
assert test_event == true_event, f"{test_event} != {true_event}"


@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")

monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)


def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)

task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()

with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)

# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin

mlx_runner.main(bound_instance, event_sender, task_receiver)

return event_receiver.collect()


def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
finish_reason="stop",
),
)

assert_events_equal(
events,
[
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)

@@ -1 +0,0 @@
# TODO:

@@ -1,4 +1,5 @@
import socket
from ipaddress import ip_address

from anyio import create_task_group, to_thread

@@ -27,13 +28,23 @@ async def check_reachability(
out[target_node_id].add(target_ip)


async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(our_node_id: NodeId, topology: Topology) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = topology.get_node_profile(our_node_id)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
if node.node_id == our_node_id or node.node_profile is None:
continue
for iface in node.node_profile.network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
)

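check_reachable now takes the local node id so it can skip probing itself: loopback addresses and interfaces shared with the local node would otherwise report false reachability. A small sketch of the filtering rule (ipaddress is standard library; the predicate is distilled from the diff):

```python
from ipaddress import ip_address


def should_probe(target_ip: str, our_ips: set[str]) -> bool:
    """Skip loopback targets and addresses we also hold locally."""
    if ip_address(target_ip).is_loopback:
        return False
    return target_ip not in our_ips


assert not should_probe("127.0.0.1", set())
assert not should_probe("192.168.1.5", {"192.168.1.5"})
assert should_probe("192.168.1.6", {"192.168.1.5"})
```
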
uv.lock (generated): 18 changes
@@ -334,7 +334,6 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -375,8 +374,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", marker = "sys_platform != 'linux'", specifier = ">=0.30.1" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "protobuf", specifier = ">=6.32.0" },
@@ -803,20 +801,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]

[package.optional-dependencies]
cpu = [
{ name = "mlx-cpu", marker = "sys_platform == 'linux'" },
]

[[package]]
name = "mlx-cpu"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
{ url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
]

[[package]]
name = "mlx-lm"
version = "0.28.3"