Compare commits

..

45 Commits

Author SHA1 Message Date
Alex Cheema
9cacef39f0 feat: add --bootstrap-peer flag to bypass mDNS for peer discovery
macOS TCC blocks mDNS multicast from SSH sessions, preventing
SSH-spawned exo processes from discovering peers. This adds a
--bootstrap-peer IP:PORT flag that directly dials a known peer,
bypassing mDNS. The flag is additive — mDNS still runs alongside.

Changes across Rust networking, PyO3 bindings, and Python CLI:
- discovery.rs: store bootstrap addresses, dial on startup and retry
- swarm.rs: thread bootstrap_peers through to Behaviour
- networking.rs: accept optional bootstrap_peers in PyO3 constructor
- router.py: pass bootstrap_peers to NetworkingHandle
- main.py: parse --bootstrap-peer CLI arg into multiaddr format

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:24:00 -08:00
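
A minimal sketch of the IP:PORT-to-multiaddr conversion this commit adds to main.py (it mirrors the hunk further down; `rsplit` keeps the IPv4 address intact):

```python
def to_multiaddr(bootstrap_peer: str) -> str:
    # Split on the last ":" so the address part survives untouched.
    ip, port = bootstrap_peer.rsplit(":", 1)
    return f"/ip4/{ip}/tcp/{port}"

assert to_multiaddr("192.168.1.10:4001") == "/ip4/192.168.1.10/tcp/4001"
```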
Alex Cheema
3df896dec1 fix: cancel active tasks on meta-instance cascade delete
Previously, DeleteMetaInstance cascade-deleted backing instances without
cancelling their active tasks, leaving orphaned task references. Now emits
TaskStatusUpdated(Cancelled) for Pending/Running tasks before InstanceDeleted.

Also adds lifecycle logging for meta-instance operations, a GET /meta_instances
endpoint, and 2 regression tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 09:51:55 -08:00
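
A hedged sketch of the corrected ordering. The event and status names come from the message; the helper shape and field access are illustrative, not the real handler:

```python
def cascade_delete_events(meta_instance_id, instances, tasks):
    events = []
    backing = [
        inst for inst in instances.values()
        if inst.meta_instance_id == meta_instance_id
    ]
    for inst in backing:
        # Cancel active tasks first so no orphaned references remain...
        for task in tasks.values():
            if (task.instance_id == inst.instance_id
                    and task.task_status in (TaskStatus.Pending, TaskStatus.Running)):
                events.append(TaskStatusUpdated(
                    task_id=task.task_id, task_status=TaskStatus.Cancelled))
        # ...then delete the backing instance.
        events.append(InstanceDeleted(instance_id=inst.instance_id))
    return events
```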
Alex Cheema
29c3489f3e Merge remote-tracking branch 'origin/main' into alexcheema/mlx-distributed-transfer
Resolve merge conflicts between meta-instance feature and task cancellation:
- commands.py: Include both CreateMetaInstance/DeleteMetaInstance and TaskCancelled
- plan.py: Add _cancel_tasks from main, keep all_runners param in _create_runner
- runner_supervisor.py: Merge cancel_sender/cancelled fields with pipe relay fields
- reconcile.py: Pass empty tasks dict to get_transition_events (new param from main)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:38:47 -08:00
Alex Cheema
19f606a379 fix: dashboard derived reactivity bug and tautological check in meta-instance UI
1. DERIVED REACTIVITY BUG: `unifiedDisplayItems` used `$derived(fn)` which
   made the derived value the function itself instead of its result. Svelte
   never tracked reactive dependencies in the function body, so the instance
   list didn't update when metaInstances or instances changed. Fixed by using
   `$derived.by(fn)` and removing the `()` call-sites in the template.

2. TAUTOLOGICAL CHECK: In `getMetaInstancePlacingStatus`, the `lastError ? ...
   : null` guard inside the `failures > 0` branch was always true because
   `lastFailureError` and `consecutiveFailures` are always set together in
   `apply_instance_retrying` and `apply_instance_deleted`. Removed the dead
   `: null` branch.

Also fixes pyright errors in test file by using proper pytest.MonkeyPatch type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:24:49 -08:00
Alex Cheema
c74b46dca0 fix: add timeout and error handling for ModelCard.load in MetaInstanceReconciler
ModelCard.load() does async I/O inside the 1-second reconcile loop. A slow
or failing load blocked all reconciliation (health checks, node timeouts,
other meta-instances). Adds a 10-second timeout, per-meta-instance error
handling with MetaInstancePlacementFailed events, and documents the
intentional early return in apply_instance_retrying.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:11:29 -08:00
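
A sketch of bounding the load inside the reconcile tick. The 10-second timeout and the MetaInstancePlacementFailed event come from the commit message (and the codebase already uses anyio); the helper shape around them is an assumption:

```python
import anyio

async def load_card_bounded(meta_instance):
    try:
        with anyio.fail_after(10):  # don't let one slow load stall the loop
            return await ModelCard.load(meta_instance.model_id), None
    except Exception as e:  # TimeoutError or a failing load
        return None, MetaInstancePlacementFailed(
            meta_instance_id=meta_instance.meta_instance_id,
            error=str(e),
        )
```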
Alex Cheema
7028f99d7e fix: validate placement before sending command to prevent silent failures (BUG-001c)
The place_instance API endpoint used fire-and-forget: it sent the command
and returned HTTP 200 immediately. On a fresh cluster start, the master's
state often lacks topology/memory data, so placement raises ValueError
which was silently caught and logged. The caller never learned it failed.

Two fixes:
- API: validate placement locally before sending, return HTTP 400 on
  failure instead of silently accepting an unprocessable command
- Master: emit MetaInstancePlacementFailed on immediate placement error
  in CreateMetaInstance handler so the error surfaces in state right away

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:13:38 -08:00
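
The validate-before-send pattern, sketched; handler and helper names here (`validate_placement`, `error_response`, `ok_response`) are assumptions, not the real endpoint code:

```python
async def place_instance(request):
    try:
        # Run the same placement checks the master would, against local state.
        validate_placement(state, request)
    except ValueError as e:
        return error_response(400, str(e))  # surface the failure to the caller
    await command_sender.send(CreateMetaInstance(meta_instance=request.meta_instance))
    return ok_response()
```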
Alex Cheema
68d9debcb6 fix: place_instance now uses all available nodes to prevent OOM (BUG-001d)
The placement algorithm previously selected the smallest viable cycle,
causing large models to be distributed across too few nodes and running
out of memory. Changed get_smallest_cycles to get_largest_cycles so that
all healthy nodes are utilized, spreading layers more evenly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:13:03 -08:00
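
One plausible shape of the selection change, assuming cycles are plain lists of node IDs — prefer the cycle covering the most nodes so layers spread across all available memory:

```python
def get_largest_cycles(cycles):
    # Largest first, so placement tries the widest viable cycle.
    return sorted(cycles, key=len, reverse=True)
```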
Alex Cheema
9bf5f86b9e Revert "feat: add task cancellation for client disconnect handling (BUG-001)"
This reverts commit 2a75672a4a.
2026-02-15 13:48:01 -08:00
Alex Cheema
2a75672a4a feat: add task cancellation for client disconnect handling (BUG-001)
- Add TaskCancelled command and Cancelled task status
- Detect API client disconnects in master/api.py
- Handle TaskCancelled in master state machine
- Add _cancel_tasks to worker for graceful task cleanup
- Add cancel_receiver to runner for inference abort
- Add mx_any helper in MLX utils for cancellable operations
- Guard instance lookup in worker to prevent KeyError
- Update tests for cancellation flow

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 13:43:15 -08:00
Alex Cheema
8a6344fba3 Fix model discovery reporting DownloadPending for fully-downloaded models
On startup, _emit_existing_download_progress() used
downloaded_bytes_this_session to decide between DownloadPending and
DownloadOngoing. Since downloaded_bytes_this_session is always 0 on
startup (it tracks the current session only), fully-downloaded models
were incorrectly reported as DownloadPending.

Now checks actual disk state: if downloaded_bytes >= total_bytes, emit
DownloadCompleted regardless of session bytes. This fixes the UI showing
models as pending when they're already available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 09:39:06 -08:00
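
The decision after the fix, sketched with field names taken from the message — disk state wins over session counters:

```python
def startup_status(progress):
    if progress.downloaded_bytes >= progress.total_bytes:
        return "DownloadCompleted"  # fully on disk, session bytes irrelevant
    if progress.downloaded_bytes_this_session == 0:
        return "DownloadPending"
    return "DownloadOngoing"
```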
Alex Cheema
316a7a344e fix: use force=True for multiprocessing set_start_method
Prevents RuntimeError when the context has already been set,
e.g. when Terminal.app reuses a tab or the process restarts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 07:02:32 -08:00
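
The commit doesn't name the start method; "spawn" is shown here as macOS's default since Python 3.8:

```python
import multiprocessing

# force=True replaces an already-set start method instead of raising
# RuntimeError("context has already been set").
multiprocessing.set_start_method("spawn", force=True)
```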
Alex Cheema
16b58d6946 fix: eliminate command/reconciler interleaving race in meta-instance
Two race conditions existed in the meta-instance lifecycle:

1. CreateMetaInstance buffered MetaInstanceCreated but didn't apply it
   before awaiting ModelCard.load(). The reconciler could interleave
   during the await, leading to duplicate placements.

   Fix: apply MetaInstanceCreated eagerly via _apply_and_broadcast,
   then re-check satisfaction after the await so placement uses fresh
   state and skips if the reconciler already handled it.

2. delete_meta_instance (API handler) sent DeleteMetaInstance then
   read self.state.instances for cascade deletion. State was stale,
   so backing instances created between the send and the read were
   missed — permanently orphaning them.

   Fix: move cascade delete into the command processor's
   DeleteMetaInstance handler, where InstanceDeleted events are
   generated atomically with MetaInstanceDeleted.

Reproduced on 4-node Mac Mini cluster: 28K anomalies in stress test
including 21 permanently orphaned instances. After fix, the cascade
delete and placement are race-free.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:50:57 -08:00
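
The shape of the first fix, per the message: apply the creation event before any await, then re-check after the await so a concurrent reconciler tick can't double-place. `_apply_and_broadcast` is named in this branch's commits; the other names are illustrative:

```python
async def handle_create_meta_instance(self, command):
    # Apply eagerly — state reflects the MetaInstance before we yield.
    self._apply_and_broadcast(MetaInstanceCreated(meta_instance=command.meta_instance))
    model_card = await ModelCard.load(command.meta_instance.model_id)  # reconciler may run here
    if self._is_satisfied(command.meta_instance):
        return  # the reconciler already placed it during the await
    self._place(command.meta_instance, model_card)
```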
Alex Cheema
39741907c7 test: add 25 edge-case tests for MetaInstance lifecycle
Cover retry logic, error handling, backward compatibility,
concurrent scenarios, placement error tracking, and serialization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:17:41 -08:00
Alex Cheema
e5c73c564c Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance 2026-02-13 10:10:38 -08:00
Alex Cheema
77799a170a Fix JACCL SideChannel bytes serialization for JSON round-trip
TaggedModel's wrap validator converts JSON→Python validation context,
which breaks strict-mode bytes deserialization from JSON strings.
Use Base64Bytes type to encode/decode bytes as base64 strings in JSON.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 14:23:44 -08:00
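
A runnable illustration of the Base64Bytes round-trip on a plain BaseModel (exo's TaggedModel wrap validator is stricter, but the encoding behavior is the same):

```python
import base64
from pydantic import Base64Bytes, BaseModel

class SideChannelData(BaseModel):  # illustrative model, not the real one
    payload: Base64Bytes  # raw bytes in Python, base64 text in JSON

m = SideChannelData(payload=base64.b64encode(b"\x00\xffdata"))
assert m.payload == b"\x00\xffdata"
json_str = m.model_dump_json()  # {"payload":"AP9kYXRh"}
assert SideChannelData.model_validate_json(json_str).payload == b"\x00\xffdata"
```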
Alex Cheema
5dddd9cd2b Use named pipes (FIFOs) for JACCL SideChannel relay
Anonymous pipes from os.pipe() don't survive multiprocessing.Process
spawn on macOS (default since Python 3.8). The FD numbers are passed
but the actual file descriptors don't exist in the child process,
causing EBADF errors.

Switch to named pipes (FIFOs) which the child opens by path in the
spawned process, getting valid FDs for the C++ SideChannel.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 12:06:37 -08:00
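
Why FIFOs survive spawn: the child re-opens the pipe by *path*, so no inherited file descriptor is needed. A minimal demonstration, not exo code:

```python
import multiprocessing
import os
import tempfile

def child(path: str) -> None:
    with open(path, "rb") as r:  # open by path in the spawned process
        assert r.read() == b"hello"

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    path = os.path.join(tempfile.mkdtemp(), "side_channel")
    os.mkfifo(path)
    p = multiprocessing.Process(target=child, args=(path,))
    p.start()
    with open(path, "wb") as w:  # blocks until the child opens the read end
        w.write(b"hello")
    p.join()
```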
Alex Cheema
0e08c2bfd3 Add pipe-based JACCL SideChannel relay via exo control plane
Replace fragile TCP SideChannel with anonymous pipes relayed through
exo's event-sourced control plane. RunnerSupervisor creates pipe pairs
for MlxJaccl instances and relays all_gather rounds via JacclSideChannelData/
JacclSideChannelGathered events through the master, eliminating errno=57
crashes from Thunderbolt RDMA driver instability.

Also includes dashboard RDMA warning improvements and instance retry fixes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 08:55:40 -08:00
Alex Cheema
f9ffdaef5f Preserve last_failure_error across instance recreation, fix RDMA banner wording
- apply_instance_created no longer clears last_failure_error so the
  error context persists while the new instance starts up
- Dashboard retryError shows the error without (N/3) prefix when
  consecutiveFailures is 0 (instance was recreated)
- Jaccl warning tooltip now says "experimental RDMA driver in macOS"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:48:30 -08:00
Alex Cheema
8c2416c9ea chore: remove temporary screenshot files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:40:16 -08:00
Alex Cheema
e5007f619a temp: add jaccl warning screenshots for PR comment
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:53 -08:00
Alex Cheema
a627f67253 dashboard: show warning banner for [jaccl] RDMA driver errors
Detect errors containing "[jaccl]" in MetaInstance failure errors and
display a red dismissible alert banner. The tooltip explains this is a
macOS RDMA driver issue and that the affected machine needs to be
restarted. Alert re-appears if a new error arrives after dismissal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:42 -08:00
Alex Cheema
f189222bfc Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance
# Conflicts:
#	dashboard/src/lib/stores/app.svelte.ts
#	dashboard/src/routes/+page.svelte
2026-02-11 15:59:50 -08:00
Alex Cheema
ad6d35d68a Retry runners within the same Instance instead of recreating
When runners fail for a MetaInstance-backed Instance, retry up to 3
times by restarting runners within the same Instance rather than
deleting and recreating it each time. After 3 failures, delete the
Instance so MetaInstanceReconciler can create a fresh one.

- Add InstanceRetrying event that removes runners from state (signaling
  workers to restart) and increments consecutive_failures on MetaInstance
- InstanceHealthReconciler emits InstanceRetrying when under retry limit,
  InstanceDeleted when exhausted or no MetaInstance
- Worker _kill_runner detects retry signal (runner deleted from state +
  terminal supervisor) and cleans up for _create_runner to recreate
- Worker _create_runner guards against oscillation by blocking creation
  while any peer runner has explicit terminal status
- InstanceCreated resets consecutive_failures for fresh starts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 14:21:11 -08:00
Alex Cheema
c236d62caf Remove timestamp-based retry cooldown
Remove last_failure_at field and RETRY_COOLDOWN_SECONDS logic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:59:39 -08:00
Alex Cheema
a8069e8a30 Consolidate failure state onto MetaInstance, add 5s retry cooldown
Move placement_error, consecutive_failures, last_failure_error, and
last_failure_at directly onto the MetaInstance model instead of keeping
them as separate State mappings (meta_instance_errors, InstanceFailureInfo,
meta_instance_failure_info). Adds a 5-second cooldown between retry attempts
to prevent rapid instance churn when runners fail instantly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:55:47 -08:00
Alex Cheema
84ce555d55 Show retry attempt count with error message, e.g. (2/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:43:20 -08:00
Alex Cheema
b78ea438bc Include node friendly names in runner error messages
Each error in the combined message is now prefixed with the node's friendly
name (e.g. "MacBook Pro: OOM; Mac Studio: connection reset") so the root
cause node is easily identifiable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:41:10 -08:00
Alex Cheema
1960b16f9f Remove permanent retry blocking, allow continuous retry batches
The dashboard % 3 logic already handles displaying retry progress in batches
(RETRYING 1/3, 2/3, 3/3, then PLACING with error, repeat). No need to
permanently block placement after 3 failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:35:03 -08:00
Alex Cheema
c6838c8fd8 Show retry count in exceeded retry limit message (3/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:28:17 -08:00
Alex Cheema
420d9b9e76 Collect all runner error messages instead of just the last one
When multiple runners fail, concatenate all error messages with "; " so the
real error isn't hidden by generic side-effect failures from other runners.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:27:49 -08:00
Alex Cheema
13f1e9c489 Stop infinite retries after 3 failures, show errors persistently in dashboard
MetaInstanceReconciler now checks failure count before placement — after 3
consecutive failures it emits MetaInstancePlacementFailed instead of retrying
forever. Dashboard shows "Retrying after error: <msg>" in orange throughout
the retry cycle, not just during the brief window with no backing instance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:21:11 -08:00
Alex Cheema
451a06b3d8 Add instance retry logic with max 3 retries and failure tracking
- Extend InstanceDeleted with failure_error field for runner crash info
- Add InstanceFailureInfo model tracking consecutive failures per MetaInstance
- InstanceHealthReconciler now detects runner failures (all terminal with
  at least one RunnerFailed) in addition to connection failures
- apply_instance_deleted increments failure counter for meta-bound instances
- Dashboard shows RETRYING (N/3) status with error messages, and
  "Instance re-created due to failure" after 3 consecutive failures
- Extract and display RunnerFailed error messages in instance status

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:09:42 -08:00
Alex Cheema
94b55d66f4 Fix MetaInstance.node_ids frozenset failing JSON deserialization
frozenset serializes to a JSON array but cannot be deserialized back
in strict mode through the TaggedModel wrap validator (list → frozenset
coercion is rejected). Changed to list[NodeId] since the model is
already frozen/immutable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:54:56 -08:00
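
The fix side of this, illustrated on a plain frozen BaseModel (the failing strict-mode path goes through exo's TaggedModel wrap validator, which is not reproduced here):

```python
from pydantic import BaseModel, ConfigDict

class WithList(BaseModel):
    model_config = ConfigDict(frozen=True)  # model stays immutable
    node_ids: list[str]  # list round-trips through JSON; frozenset did not

w = WithList(node_ids=["a", "b"])
assert WithList.model_validate_json(w.model_dump_json()) == w
```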
Alex Cheema
2b68b931c5 Send node_ids from placement preview when launching instances
The dashboard now extracts node IDs from the selected preview's
memory_delta_by_node, ensuring the backend places on exactly the
nodes the user was shown. Also reverts incorrect RDMA min_nodes >= 2
enforcement since single-node RDMA is valid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:49:37 -08:00
Alex Cheema
4aecaa7748 Enforce min_nodes >= 2 for RDMA (MlxJaccl) instances
RDMA requires at least 2 nodes — a single-node RDMA instance is
nonsensical. Enforce this in both the dashboard (when building the
launch request) and the backend placement (when filtering cycles).
Previously, selecting RDMA would still place on 1 node because
min_nodes defaulted to 1 and the placement silently switched to Ring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:42:21 -08:00
Alex Cheema
25e2891c30 Ensure min_nodes >= node filter size when launching
When user selects specific nodes via the filter, min_nodes should be at
least the number of filtered nodes to prevent placement from picking a
smaller cycle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:36:51 -08:00
Alex Cheema
16345e0ffa Send node_ids from dashboard, error on RDMA when unavailable
Dashboard was not including the user's node filter in the POST to
/meta_instance, so placement ignored which nodes the user selected.
Also, placement silently fell back to Ring when RDMA was requested but
no RDMA-connected cycles were available — now raises an error that
surfaces via MetaInstancePlacementFailed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:26:29 -08:00
Alex Cheema
3a845f90b0 Fix use_default validator silently ignoring sharding/instance_meta
The mode="plain" validator bypassed Pydantic's string-to-enum coercion,
so JSON strings like "Tensor" and "MlxJaccl" from the dashboard failed
the isinstance check and silently fell back to Pipeline/MlxRing defaults.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:05:00 -08:00
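
A minimal reproduction of the pitfall: a mode="plain" validator replaces Pydantic's own validation, so no string-to-enum coercion happens before the isinstance check. Names here are illustrative:

```python
from enum import Enum
from pydantic import BaseModel, field_validator

class Sharding(str, Enum):
    Pipeline = "Pipeline"
    Tensor = "Tensor"

class Request(BaseModel):
    sharding: Sharding = Sharding.Pipeline

    @field_validator("sharding", mode="plain")
    @classmethod
    def use_default(cls, v):
        # BUG: JSON gives v="Tensor" (a str); isinstance fails, default wins.
        return v if isinstance(v, Sharding) else Sharding.Pipeline

# Silently wrong: the user asked for Tensor and got the Pipeline default.
assert Request.model_validate_json('{"sharding": "Tensor"}').sharding == Sharding.Pipeline
```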
Alex Cheema
dccf2440ba Add placement error feedback and per-node loading status
Show why MetaInstance placement fails instead of stuck "PLACING", and
show per-node runner status during loading for multi-node instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:01:07 -08:00
Alex Cheema
f96f3f2c0f Show MetaInstance sharding/type while PLACING, fix MlxIbv references
When a MetaInstance has no backing instance yet, derive the strategy
display from the MetaInstance's own sharding and instanceMeta fields
rather than showing "Unknown (Unknown)".

Also clean up all stale MlxIbv references across the dashboard —
the backend enum is MlxJaccl.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:23:44 -08:00
Alex Cheema
7d54e468d5 Extract reconciler into ProcessManager protocol, fix RDMA instance type
- Replace inline _plan() with ProcessManager loop (_reconcile), tick
  every 1s instead of 10s — safe because all PMs are idempotent
- Fix dashboard sending "MlxIbv" instead of "MlxJaccl" for RDMA
  instance type, which silently fell back to MlxRing default
- Remove all stale MlxIbv references from dashboard

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:19:13 -08:00
Alex Cheema
124d504f95 Extract reconciler into ProcessManager protocol
Replace inline _plan() steps with a list of ProcessManagers, each
implementing async reconcile(State) -> Sequence[Event]. Tick every
1s instead of 10s — safe because all PMs are idempotent against state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:06:41 -08:00
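
The protocol as described in the message, with `Any` standing in for exo's real State and Event types:

```python
from collections.abc import Sequence
from typing import Any, Protocol

import anyio

State = Any   # stand-ins for exo's real types
Event = Any

class ProcessManager(Protocol):
    async def reconcile(self, state: State) -> Sequence[Event]: ...

async def reconcile_loop(pms: Sequence[ProcessManager], get_state, emit) -> None:
    while True:
        for pm in pms:  # every PM is idempotent against state, so a 1s tick is safe
            for event in await pm.reconcile(get_state()):
                await emit(event)
        await anyio.sleep(1)
```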
Alex Cheema
9ab4a40989 Simplify MetaInstance binding: put meta_instance_id on Instance
The separate MetaInstanceBound event + meta_instance_backing map
introduced two bugs: stale exclusion sets in the reconciler loop and
a delete ordering race. Embedding meta_instance_id directly on
BaseInstance eliminates the binding mechanism entirely — when an
instance is created for a MetaInstance it carries the ID, when
deleted the binding is gone. No separate map, no cleanup, no races.

Also fixes delete_meta_instance to cascade-delete backing instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 16:15:29 -08:00
Alex Cheema
f4329c72c2 Add explicit MetaInstance binding, slim MetaInstance to use ModelId
- Add MetaInstanceBound event and meta_instance_backing State field
  for explicit MetaInstance → Instance binding (prevents ambiguous
  linking when two MetaInstances have identical constraints)
- Replace model_card: ModelCard with model_id: ModelId on MetaInstance
  (load ModelCard on-demand at placement time)
- Add MetaInstance API endpoints (POST /meta_instance, DELETE)
- Update dashboard to use MetaInstances as primary primitive with
  unified display items merging MetaInstances and orphan instances
- Dashboard launches via MetaInstance instead of direct Instance creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 15:53:07 -08:00
Alex Cheema
ceb76b8f6c Add MetaInstance declarative layer with connection health checking
Introduces MetaInstance as a declarative constraint ensuring an instance
matching given parameters (model, sharding, min_nodes) always exists.
The master's reconciliation loop continuously checks for unsatisfied
meta-instances and attempts placement. Connection health checking
verifies that specific IPs (MlxRing) and RDMA interfaces (MlxJaccl)
stored on instances still exist as topology edges, enabling automatic
recovery when cables are swapped or interfaces change.

Also eliminates the master's loopback event path, unifying all event
emission through _apply_and_broadcast for simpler control flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 13:43:50 -08:00
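
A hedged sketch of the declarative check: a MetaInstance is satisfied when a live instance backs it. `try_place_for_meta_instance` appears in the diffs below; the satisfaction test shown is simplified:

```python
def is_satisfied(meta, instances) -> bool:
    # Later commits embed meta_instance_id directly on the instance.
    return any(
        inst.meta_instance_id == meta.meta_instance_id
        for inst in instances.values()
    )

def reconcile_meta_instances(state):
    for meta in state.meta_instances.values():
        if not is_satisfied(meta, state.instances):
            yield from try_place_for_meta_instance(meta, state)
```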
45 changed files with 908 additions and 1170 deletions

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {

View File

@@ -1,7 +0,0 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

View File

@@ -1,189 +0,0 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

View File

@@ -549,8 +549,8 @@ class AppStore {
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
metaInstances = $state<Record<string, MetaInstanceData>>({});
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
Record<
@@ -1274,6 +1274,8 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
if (data.nodeDisk) {
this.nodeDisk = data.nodeDisk;
}
@@ -1283,8 +1285,6 @@ class AppStore {
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
// RDMA ctl status per node
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node

View File

@@ -42,9 +42,9 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
metaInstances,
nodeThunderbolt,
nodeRdmaCtl,
metaInstances,
thunderboltBridgeCycles,
nodeThunderboltBridge,
nodeIdentities,

View File

File diff suppressed because it is too large

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-4bit"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 418621403136

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 767273926656

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM 5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405480321024

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -130,7 +130,7 @@ class Multiaddr:
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
def __new__(cls, identity: Keypair, bootstrap_peers: builtins.list[builtins.str] = ...) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.

View File

@@ -342,8 +342,10 @@ impl PyNetworkingHandle {
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
#[pyo3(signature = (identity, bootstrap_peers=vec![]))]
fn py_new(identity: Bound<'_, PyKeypair>, bootstrap_peers: Vec<String>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
use std::str::FromStr;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
@@ -353,9 +355,20 @@ impl PyNetworkingHandle {
// get identity
let identity = identity.borrow().0.clone();
// parse bootstrap peer multiaddrs
let parsed_peers: Vec<libp2p::Multiaddr> = bootstrap_peers
.into_iter()
.map(|s| libp2p::Multiaddr::from_str(&s))
.collect::<Result<_, _>>()
.pyerr()?;
if !parsed_peers.is_empty() {
log::info!("RUST: bootstrap peers: {:?}", parsed_peers);
}
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.block_on(async { create_swarm(identity, parsed_peers) })
.pyerr()?;
// spawn tokio task running the networking logic

View File

@@ -106,6 +106,9 @@ pub struct Behaviour {
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
/// Addresses provided via --bootstrap-peer, re-dialed on the retry loop.
bootstrap_addrs: Vec<Multiaddr>,
retry_delay: Delay, // retry interval
// pending events to emit => waker-backed Deque to control polling
@@ -113,13 +116,21 @@ pub struct Behaviour {
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
pub fn new(keypair: &identity::Keypair, bootstrap_peers: Vec<Multiaddr>) -> io::Result<Self> {
let mut behaviour = Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
bootstrap_addrs: bootstrap_peers,
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
};
// Immediately dial all bootstrap peers
for addr in &behaviour.bootstrap_addrs.clone() {
behaviour.dial_addr(addr.clone());
}
Ok(behaviour)
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
@@ -128,6 +139,14 @@ impl Behaviour {
})
}
/// Dial by address only — PeerId is resolved via Noise handshake on connect.
/// Used for bootstrap peers where only IP:port is known.
fn dial_addr(&mut self, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::unknown_peer_id().address(addr).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
@@ -362,13 +381,16 @@ impl NetworkBehaviour for Behaviour {
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
// retry connecting to all mDNS + bootstrap peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
for addr in self.bootstrap_addrs.clone() {
self.dial_addr(addr)
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}

View File

@@ -1,7 +1,7 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
use libp2p::{Multiaddr, SwarmBuilder, identity};
pub type Swarm = libp2p::Swarm<Behaviour>;
@@ -16,11 +16,14 @@ pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
pub fn create_swarm(
keypair: identity::Keypair,
bootstrap_peers: Vec<Multiaddr>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.with_behaviour(|kp| Behaviour::new(kp, bootstrap_peers))?
.build();
// Listen on all interfaces and whatever port the OS assigns
@@ -105,7 +108,7 @@ mod transport {
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
use libp2p::{Multiaddr, gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
@@ -116,9 +119,12 @@ mod behaviour {
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
pub fn new(
keypair: &identity::Keypair,
bootstrap_peers: Vec<Multiaddr>,
) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
discovery: discovery::Behaviour::new(keypair, bootstrap_peers)?,
gossipsub: gossipsub_behaviour(keypair),
})
}

View File

@@ -14,7 +14,6 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -64,9 +63,6 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -78,7 +74,6 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -98,7 +93,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -176,11 +170,7 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -194,7 +184,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -217,7 +206,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -231,7 +219,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -266,7 +253,6 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -309,9 +295,6 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if (
@@ -326,11 +309,7 @@ class DownloadCoordinator:
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
@@ -339,9 +318,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -46,7 +46,12 @@ class Node:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
bootstrap_peers: list[str] = []
if args.bootstrap_peer:
ip, port = args.bootstrap_peer.rsplit(":", 1)
bootstrap_peers.append(f"/ip4/{ip}/tcp/{port}")
logger.info(f"Bootstrap peer: {args.bootstrap_peer}")
router = Router.create(keypair, bootstrap_peers=bootstrap_peers)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)
@@ -136,8 +141,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -149,6 +152,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -283,6 +288,7 @@ class Args(CamelCaseModel):
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
bootstrap_peer: str | None = None
@classmethod
def parse(cls) -> Self:
@@ -343,6 +349,13 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
parser.add_argument(
"--bootstrap-peer",
type=str,
dest="bootstrap_peer",
default=None,
help="IP:PORT of an existing node to connect to (bypasses mDNS)",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -144,8 +144,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=tool.id,
call_id=tool.id,
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
name=tool.name,
arguments=tool.arguments,
)

View File

@@ -603,10 +603,10 @@ class API:
break
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -946,10 +946,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1032,10 +1032,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:

View File

@@ -339,7 +339,6 @@ class Master:
self.state.instances,
self.state.node_memory,
self.state.node_network,
self.state.tasks,
)
generated_events.extend(result.events)
if result.error is not None:
@@ -417,19 +416,16 @@ class Master:
)
case TaskCancelled():
if (
command.cancelled_command_id
in self.command_task_mapping
):
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.cancelled_command_id
]
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
del self.command_task_mapping[
command.cancelled_command_id
]
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -438,10 +434,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time

View File

@@ -61,7 +61,6 @@ class MetaInstanceReconciler:
current_instances,
state.node_memory,
state.node_network,
state.tasks,
)
# Update local instance map so next placement sees this one
for event in result.events:

View File

@@ -11,7 +11,6 @@ from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
@@ -200,7 +199,6 @@ def try_place_for_meta_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
tasks: Mapping[TaskId, Task],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
@@ -232,9 +230,7 @@ def try_place_for_meta_instance(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(
get_transition_events(current_instances, target_instances, tasks)
),
events=list(get_transition_events(current_instances, target_instances, {})),
error=None,
)
except ValueError as e:

View File

@@ -101,8 +101,12 @@ class TopicRouter[T: CamelCaseModel]:
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity))
def create(
cls,
identity: Keypair,
bootstrap_peers: list[str] | None = None,
) -> "Router":
return cls(handle=NetworkingHandle(identity, bootstrap_peers or []))
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}

View File

@@ -338,6 +338,11 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -378,6 +383,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -182,7 +182,6 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],

View File

@@ -49,10 +49,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class CreateMetaInstance(BaseCommand):
meta_instance: MetaInstance
@@ -61,6 +57,10 @@ class DeleteMetaInstance(BaseCommand):
meta_instance_id: MetaInstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -102,9 +102,9 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| CreateMetaInstance
| DeleteMetaInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId

View File

@@ -26,7 +26,6 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -1,7 +1,5 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -29,4 +27,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner, file=sys.stderr)
print(banner)

View File

@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)
async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Starting prefill")
logger.info("Ready to prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(

View File

@@ -285,7 +285,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
@@ -353,13 +353,7 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
@@ -574,11 +568,6 @@ def mlx_cleanup(
def mx_any(bool_: bool, group: Group | None) -> bool:
"""Synchronize a boolean across all distributed nodes.
Returns True if any node has bool_=True. Uses all_sum so every
node participates in the collective — preventing GPU deadlocks.
"""
if group is None:
return bool_
num_true = mx.distributed.all_sum(
@@ -596,41 +585,3 @@ def mx_barrier(group: Group | None):
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]

View File

@@ -34,6 +34,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -234,15 +235,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -280,18 +288,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
await self._start_runner_task(modified_task)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
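Since removed and added lines are interleaved above, a condensed, hypothetical restatement of the new dispatch shape (names follow the diff; the surrounding Worker state is assumed):

match task:
    case Shutdown(runner_id=runner_id):
        runner = self.runners.pop(runner_id)
        try:
            with fail_after(3):  # anyio deadline; raises TimeoutError
                await runner.start_task(task)
        except TimeoutError:
            await self.event_sender.send(TaskStatusUpdated(
                task_id=task.task_id, task_status=TaskStatus.TimedOut))
        finally:
            runner.shutdown()  # reap the runner even if start_task timed out
    case CancelTask(cancelled_task_id=cancelled_task_id, runner_id=runner_id):
        await self.runners[runner_id].cancel_task(cancelled_task_id)
    case task:
        # _start_runner_task looks the instance up with .get(), so a task
        # whose instance was already deleted is dropped, not a KeyError.
        await self._start_runner_task(task)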
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.

View File

@@ -328,8 +328,7 @@ def _pending_tasks(
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> CancelTask | None:
"""Find a cancelled task that hasn't been sent to the runner yet."""
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
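The hunk truncates before the return; a minimal sketch of the rest of the scan, with the per-runner bookkeeping elided since it isn't shown here:

def _cancel_tasks_sketch(tasks: "Mapping[TaskId, Task]") -> "Task | None":
    # Return the first Cancelled task that still needs a CancelTask sent.
    for task in tasks.values():
        if task.task_status != TaskStatus.Cancelled:
            continue
        # The real implementation also skips tasks the runner has already
        # finished or acknowledged; that state lives in RunnerSupervisor.
        return task
    return None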

View File

@@ -67,9 +67,7 @@ def entrypoint(
try:
event_sender.close()
task_receiver.close()
cancel_receiver.close()
finally:
event_sender.join()
task_receiver.join()
cancel_receiver.join()
logger.info("bye from the runner")

View File

@@ -1,10 +1,11 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,6 +16,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -91,8 +93,6 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -138,7 +138,6 @@ def main(
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
@@ -204,17 +203,8 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
@@ -243,7 +233,7 @@ def main(
assert inference_model
assert tokenizer
t = time.perf_counter()
t = time.monotonic()
toks = warmup_inference(
model=inference_model,
tokenizer=tokenizer,
@@ -251,7 +241,7 @@ def main(
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
math.ceil(toks / max(time.monotonic() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(
@@ -320,11 +310,31 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
if isinstance(inference_model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
@@ -577,8 +587,21 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -635,9 +658,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -758,55 +781,221 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse], tool_parser: ToolParser
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response.text.startswith(tool_parser.start_parsing):
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested Python dictionary.
# We don't need the whole thing: we only validate the top-level
# { name: ..., arguments: ... } structure, since it is handed
# straight back to the API anyway.
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: the model was talking about tool calls, not making one
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
)
continue
# fallthrough
yield response
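A hypothetical trace of the rewritten generator, showing how tokens are buffered between the markers:

# Input GenerationResponse texts, one per token:
#   "Sure."  "<tool_call>"  '{"name": "fn",'  ' "arguments": {}}'  "</tool_call>"
# Yields:
#   GenerationResponse(text="Sure.")                  # outside a call: passthrough
#   ToolCallResponse(tool_calls=[ToolCallItem(...)])  # one response per parsed call
# On a parse failure the buffered text is re-yielded verbatim, wrapped back in
# tool_call_start/tool_call_end, so clients still see the raw model output.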
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
A version of the to-be-upstreamed kimi-k2 tool parser.
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
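A hypothetical call against the patched kimi parser:

parse_tool_call('functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}')
# -> {"id": "functions.multiply:0", "name": "multiply",
#     "arguments": {"a": 2, "b": 3}}
# Without the "functions." prefix, find(".") returns -1, the +1 slice starts
# at 0, and the whole name is kept, so "multiply:0<|...|>" parses the same.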
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
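And a hypothetical call against the patched GLM parser, showing why _is_string_type matters:

parse_tool_call('get_weather<arg_key>city</arg_key><arg_value>Lisbon</arg_value>')
# -> {"name": "get_weather", "arguments": {"city": "Lisbon"}}
# With tools=None every value runs through _deserialize, so an
# <arg_value>3</arg_value> would come back as int 3; passing a schema whose
# "type" is "string" keeps it as the string "3".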
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
# pydantic's ValidationError can't be raised bare; ValueError is
# caught by the same except clause in parse_tool_calls.
raise ValueError(f"tool call missing name or arguments: {obj!r}")
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -72,8 +72,8 @@ class RunnerSupervisor:
initialize_timeout: float
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_cancel_sender: MpSender[TaskId]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_pipe_read_fd: int | None = None # Python reads runner's pipe output
_pipe_write_fd: int | None = None # Python writes gathered data to runner
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
@@ -185,9 +185,9 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self._event_sender.close()
self._close_pipe_fds()
self.runner_process.join(1)
if not self.runner_process.is_alive():
@@ -226,7 +226,6 @@ class RunnerSupervisor:
await event.wait()
async def cancel_task(self, task_id: TaskId):
"""Send a cancellation signal to the runner process."""
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
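The hunk ends at the early return; the method presumably then forwards the id over the cancel channel, roughly:

# Hypothetical continuation; the real method may do more bookkeeping.
self._cancel_sender.send(task_id)
logger.info(f"Requested cancellation of {task_id}")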

View File

@@ -1,72 +0,0 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
start_parsing: str
end_parsing: str
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix(tool_call_start)
text = text.removesuffix(tool_call_end)
parsed = tool_parser(text)
if isinstance(parsed, list):
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
else:
return [ToolCallItem.model_validate(_flatten(parsed))]
except Exception:
return None
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
parse_tool_calls=parse_tool_calls,
)
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix("<tool_call>")
text = text.removesuffix("</tool_call>")
top_level = {
k: json.dumps(v) if isinstance(v, (dict, list)) else v
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
}
return [ToolCallItem.model_validate(top_level)]
except Exception:
return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
json_tool_parser = ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
"""Attempt to auto-infer a tool parser from the chat template."""
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
return json_tool_parser
return None
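For contrast with the new direct-argument style, this is how the deleted wrapper was used (the dummy parser is hypothetical):

def raw_parser(text: str) -> dict:
    return {"name": "fn", "arguments": {"q": text}}

parser = make_mlx_parser("<tool_call>", "</tool_call>", raw_parser)
parser.parse_tool_calls("<tool_call>hello</tool_call>")
# -> [ToolCallItem(name="fn", arguments='{"q": "hello"}')]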

View File

@@ -1,7 +1,9 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -115,12 +117,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings.
def _mock_all_gather(x: object, **_kw: object) -> object:
return x
monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -182,15 +178,16 @@ def _run(tasks: Iterable[Task]):
# this is some c++ nonsense
task_receiver.close = nothin
task_receiver.join = nothin
cancel_receiver.close = nothin
cancel_receiver.join = nothin
mlx_runner.main(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
with unittest.mock.patch(
"exo.worker.runner.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
return event_sender.events

View File

@@ -5,13 +5,12 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -23,13 +22,10 @@ def _make_responses(
)
def _dummier_parser(text: str) -> dict[str, Any]:
def _dummy_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -39,6 +35,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -52,6 +50,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,7 +76,9 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
"<tool_call>",
"</tool_call>",
_failing_parser,
)
)

uv.lock generated
View File

@@ -377,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,7 +416,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1066,14 +1072,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1104,7 +1102,7 @@ version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1116,6 +1114,16 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"