Compare commits

..

6 Commits

Author SHA1 Message Date
Evan
cc09ba01e4 yay 2026-01-26 18:19:00 +00:00
Alex Cheema
44453c4c8b Remove change-detection checks from info gatherer monitors (#1283)
## Summary
- When a node times out, its info gets cleared from state. The monitor
functions only sent data when something changed, leaving no mechanism to
re-populate this info after a timeout.
- Removes change-detection checks from `_monitor_misc`,
`_monitor_system_profiler_thunderbolt_data`, `_watch_system_info`, and
`_monitor_thunderbolt_bridge_status` so data is sent periodically
regardless of whether it changed.

## Test plan
- [ ] Verify type checker passes: `uv run basedpyright`
- [ ] Verify linter passes: `uv run ruff check`
- [ ] Verify tests pass: `uv run pytest`
- [ ] Manually test that node info is re-populated after a timeout by
observing cluster behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 12:23:22 +00:00
Jake Hillion
1290e8ed9f dashboard: fix prettier-svelte rebuilding on every file change
The prettier-svelte package was rebuilding whenever any file in the
repository changed because dashboardStubSrc referenced inputs.self
directly. Since inputs.self's store path hash is computed from the
entire repository contents, any file modification invalidated the
derivation.

Added dashboardLockfileSrc using lib.cleanSourceWith to filter
inputs.self to only include package.json and package-lock.json from
the dashboard directory. Updated dashboardStubSrc to reference this
filtered source instead of inputs.self directly.

This ensures prettier-svelte only rebuilds when the lockfiles actually
change, significantly improving build caching for unrelated changes.

Test plan:
- Built prettier-svelte with nix build .#prettier-svelte
- Modified src/exo/main.py and rebuilt - same store path (no rebuild)
- Modified dashboard/package.json and rebuilt - different store path (rebuild triggered)
- Ran nix flake check successfully
2026-01-26 12:02:05 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
Alex Cheema
ff4a2022f7 Revert state compaction (#1259) (#1275)
## Summary

Reverts the state compaction feature (#1259) to investigate issues with
nodes staying as "unknown" after joining a cluster.

## Test plan

- [ ] Verify nodes properly show up after joining cluster
- [ ] Verify state catchup works correctly without compaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-01-23 16:29:48 -08:00
rltakashige
cee48f6f34 Parse GPT OSS tool calling (#1271)
## Motivation

<img width="3162" height="858" alt="image"
src="https://github.com/user-attachments/assets/e552f373-620a-4522-894b-6f93fd7f1e50"
/>

## Changes

OpenAI Harmony StreamableParser does parsing for us.

## Why It Works

<img width="3230" height="588" alt="image"
src="https://github.com/user-attachments/assets/81f8a43e-c04b-4bd0-9fd0-65e9b5f6ea1d"
/>
2026-01-23 20:43:53 +00:00
19 changed files with 266 additions and 243 deletions

View File

@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -3,12 +3,28 @@
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to ONLY include package.json and package-lock.json
# This ensures prettier-svelte only rebuilds when lockfiles change
dashboardLockfileSrc = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
isDashboardDir = baseName == "dashboard" && type == "directory";
isPackageFile =
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
&& (baseName == "package.json" || baseName == "package-lock.json");
in
isDashboardDir || isPackageFile;
};
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer

View File

@@ -53,7 +53,6 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.STATE_CATCHUP)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
@@ -83,7 +82,6 @@ class Node:
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
api = None
@@ -96,7 +94,6 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
@@ -110,7 +107,6 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
)
er_send, er_recv = channel[ElectionResult]()
@@ -193,7 +189,6 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
)
self._tg.start_soon(self.master.run)
elif (
@@ -240,9 +235,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
state_catchup_receiver=self.router.receiver(
topics.STATE_CATCHUP
),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),

View File

@@ -166,7 +166,6 @@ class API:
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
state_catchup_receiver: Receiver[State],
) -> None:
self.state = State()
self._event_log: list[Event] = []
@@ -174,7 +173,6 @@ class API:
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.state_catchup_receiver = state_catchup_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -1251,7 +1249,6 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._state_catchup)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1262,22 +1259,6 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _state_catchup(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"API catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -68,8 +68,6 @@ class Master:
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
# not a fan but - send the entire state to a node so it can catchup without the whole event log.
state_catchup_sender: Sender[State],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -79,7 +77,6 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.state_catchup_sender = state_catchup_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -87,6 +84,7 @@ class Master:
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
async def run(self):
@@ -293,17 +291,11 @@ class Master:
command.finished_command_id
]
case RequestEventLog():
if command.since_idx == 0:
# This is an optimization, and should not be relied upon in theory.
logger.info(
f"Master sending catchup state for index {self.state.last_event_applied_idx}"
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
await self.state_catchup_sender.send(self.state)
else:
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -27,7 +27,6 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -48,7 +47,6 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
st_s, _st_r = channel[State]()
all_events: list[IndexedEvent] = []
@@ -69,7 +67,6 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
state_catchup_sender=st_s,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -7,7 +7,6 @@ from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.shared.types.state import State
from exo.utils.pydantic_ext import CamelCaseModel
@@ -46,7 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -349,13 +349,8 @@ class InfoGatherer:
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await self.info_sender.send(await MiscData.gather())
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
@@ -365,15 +360,12 @@ class InfoGatherer:
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
@@ -398,22 +390,17 @@ class InfoGatherer:
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = await get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
if curr is not None:
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)

View File

@@ -145,10 +145,6 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -256,6 +252,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

View File

@@ -60,8 +60,9 @@ class Worker:
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
state_catchup_receiver: Receiver[State],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
@@ -70,8 +71,6 @@ class Worker:
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.state_catchup_receiver = state_catchup_receiver
self.local_event_index = 0
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
@@ -111,7 +110,6 @@ class Worker:
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._check_catchup_state)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
@@ -131,22 +129,6 @@ class Worker:
)
)
async def _check_catchup_state(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"Worker catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -336,7 +318,10 @@ class Worker:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
assert since_idx >= 0
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope

View File

@@ -7,6 +7,7 @@ from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.tests.patches import load_null_model
logger: "loguru.Logger" = loguru.logger
@@ -16,6 +17,8 @@ def entrypoint(
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
*,
_load_null_models: bool = False,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -29,6 +32,13 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
p = None
if _load_null_models:
from unittest.mock import patch
p = patch("mlx_lm.utils.load_model", new=load_null_model)
p.start()
global logger
logger = _logger
@@ -52,6 +62,8 @@ def entrypoint(
)
)
finally:
if p is not None:
p.stop()
try:
event_sender.close()
task_receiver.close()

View File

@@ -240,10 +240,6 @@ def main(
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
@@ -257,10 +253,16 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
@@ -489,9 +491,10 @@ def get_gpt_oss_encoding():
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
@@ -501,17 +504,44 @@ def filter_kimi_tokens(
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
@@ -528,13 +558,12 @@ def parse_gpt_oss(
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -542,6 +571,9 @@ def parse_thinking_models(
"""
first = True
for response in responses:
if isinstance(response, ToolCallResponse):
yield response
continue
if first:
first = False
yield response.model_copy(
@@ -622,7 +654,7 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -630,6 +662,7 @@ def parse_tool_calls(
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True

View File

@@ -0,0 +1,50 @@
# type: ignore
import importlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from exo.worker.engines.mlx import Model
def load_null_model(path: Path, **_: object) -> "tuple[Model, dict[str, Any]]":
with open(path / "config.json", "r") as f:
cfg = json.load(f)
model, args = _get_classes(cfg)
model = model(args.from_dict(cfg))
return model, cfg
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
raise ValueError(msg) from None
return arch.Model, arch.ModelArgs
MODEL_REMAPPING = {
"mistral": "llama",
"llava": "mistral3",
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
"kimi_k2": "deepseek_v3",
"qwen2_5_vl": "qwen2_vl",
"minimax_m2": "minimax",
"iquestcoder": "llama",
}

View File

@@ -154,7 +154,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
tasks={},
)
assert result is None
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

View File

@@ -1,7 +1,6 @@
import multiprocessing as mp
import socket
import time
import typing
import anyio
from fastapi import FastAPI
@@ -11,16 +10,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,18 +31,17 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import RunnerFailed, RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
MODEL_CARDS = {"haha": MODEL_CARDS["qwen3-coder-480b-a35b-8bit"]}
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
mp.set_start_method("spawn", force=True)
@@ -56,16 +50,14 @@ logger_setup(None)
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:8000"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/run_test")(run_test)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
@@ -87,28 +79,7 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):
async def run_test(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
@@ -117,10 +88,30 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
for card in MODEL_CARDS.values():
for instance in (
ring_instance(test, card.model_id, iid, hn),
jaccl_instance(test, card.model_id, iid),
):
recv = await execute_test(test, instance, hn)
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, RunnerFailed
):
return
except anyio.ClosedResourceError:
pass
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
def ring_instance(test: Tests, model_id: ModelId, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
@@ -135,13 +126,13 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,7 +154,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
@@ -171,60 +162,33 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
for card in MODEL_CARDS.values():
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=card.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
return recv
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests, model_id: ModelId, iid: InstanceId):
card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
world_size = len(test.devs)
return MlxJacclInstance(
@@ -235,7 +199,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
@@ -270,6 +234,7 @@ def new_runner(
task_recv,
logger,
),
kwargs={"_load_null_models": True},
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()

View File

@@ -6,19 +6,8 @@ query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
if [[ $# -lt 1 ]]; then
echo "USAGE: $0 [host1] [host2] ..."
exit 1
fi
@@ -34,23 +23,12 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
for i in "${!ips[@]}"; do
{
curl -sN \
-X POST "http://${ips[$i]}:8000/run_test" \
-H "Content-Type: application/json" -d "{\"devs\": ${devs}}" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait

40
uv.lock generated
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -412,8 +412,8 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -994,8 +994,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1022,12 +1022,18 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1040,14 +1046,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1078,7 +1076,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1086,6 +1084,16 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"