Compare commits

...

21 Commits

Author SHA1 Message Date
Alex Cheema
ce45af58e3 fix: --no-downloads no longer blocks model loading (#1510)
When --no-downloads is passed, skip download checks in plan() so
pre-staged models can be loaded without DownloadCompleted events.
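
A rough sketch of that behavior (class, field, and function names here are assumptions for illustration, not exo's actual API): when the flag is set, the planner simply omits the download step and goes straight to loading from disk.

```python
from dataclasses import dataclass, field


@dataclass
class Planner:
    """Hedged sketch of the planning step; names are illustrative, not exo's API."""
    no_downloads: bool = False
    missing_models: list[str] = field(default_factory=list)
    staged_models: list[str] = field(default_factory=list)

    def plan(self) -> list[str]:
        actions: list[str] = []
        if not self.no_downloads:
            # Normal path: schedule downloads for anything missing on disk.
            actions += [f"download {m}" for m in self.missing_models]
        # With --no-downloads, loading proceeds directly from pre-staged files
        # instead of waiting for DownloadCompleted events.
        actions += [f"load {m}" for m in self.staged_models]
        return actions


print(Planner(no_downloads=True, staged_models=["llama-3.2-1b"]).plan())
```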

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:40:12 -08:00
rltakashige
f2be929211 Leo/address rdma gpu locks 2 (#1515)
Same as #1489. Had to revert and redo thanks to Claude.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:00:52 -08:00
rltakashige
83af8c63fa Revert "Use custom fork that resolves GPU locks" (#1502)
Reverts exo-explore/exo#1489

Goddammit Claude...
2026-02-17 18:18:54 +00:00
Evan Quiney
eccc6298d1 Revert "Add MetaInstance declarative layer (#1447)"
This reverts commit a962a28afc.
2026-02-17 18:11:47 +00:00
Evan Quiney
c8997217cf Revert "feat: better onboarding UX for new users (#1479)"
This reverts commit 490d2e46ba.
2026-02-17 18:02:32 +00:00
Alex Cheema
490d2e46ba feat: better onboarding UX for new users (#1479)
## Summary

- **Auto-open dashboard** in browser on first launch (uses
`~/.exo/.dashboard_opened` marker)
- **Welcome overlay** with "Choose a Model" CTA button when no model
instance is running
- **Tutorial progress messages** during model download → loading → ready
lifecycle stages
- **Fix conversation sidebar** text contrast — bumped to white text,
added active state background
- **Simplify technical jargon** — sharding/instance type/min nodes
hidden behind collapsible "Advanced Options" toggle; strategy display
hidden behind debug mode
- **Polished DMG installer** with drag-to-Applications layout, custom
branded background, and AppleScript-configured window positioning

## Test plan

- [ ] Launch exo for the first time (delete `~/.exo/.dashboard_opened`
to simulate) — browser should auto-open
- [ ] Verify welcome overlay appears on topology when no model is loaded
- [ ] Launch a model and verify download/loading/ready messages appear
in instance cards
- [ ] Check conversation sidebar text is readable (white on dark, yellow
when active)
- [ ] Verify "Advanced Options" toggle hides/shows sharding controls
- [ ] Build DMG with `packaging/dmg/create-dmg.sh` and verify
drag-to-Applications layout

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 17:52:49 +00:00
rltakashige
facf2d4d03 Use custom fork that resolves GPU locks (#1489)
## Motivation

There is an issue on Macs where an explicit synchronization is
necessary for memory to be updated from the L1 cache. This means that GPU
locks can occur when a spin wait never sees the updated timestamp.

## Changes

Updated the MLX dependency to point at my own personal fork.

## Why It Works

https://github.com/ARM-software/acle/releases

## Test Plan

### Manual Testing
Tested manually that no GPU locks occur (even with multiple simultaneous
instances running) and that the performance differential is negligible
(267 vs 269 tps on Llama 3.2 1B at approximately 10k context).


------------------------------------------------------
I have seen a GPU lock, specifically when sending a particularly large
chat completion while the model was loading. However, I have since been
unable to reproduce it, and it may have been something I did wrong. Please
do create an issue and tag me if any GPU locks do occur.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 17:48:43 +00:00
Alex Cheema
a962a28afc Add MetaInstance declarative layer (#1447)
## Motivation

Users currently manage instances directly, which means if a node
disconnects or connections break, the instance dies and nothing
recreates it. MetaInstance is a declarative primitive: "ensure an
instance matching these parameters always exists." The reconciler
watches for unhealthy or missing backing instances and re-places them
automatically.

## Changes

- **MetaInstance type** (`meta_instance.py`): declarative constraint
with `model_id`, `min_nodes`, optional `node_ids`, and `sharding`
- **Reconciler** (`reconcile.py`): `find_unsatisfied_meta_instances`
checks which MetaInstances lack a healthy backing instance,
`try_place_for_meta_instance` creates one
- **Master loop** (`main.py`): periodically reconciles unsatisfied
MetaInstances; immediate placement on `CreateMetaInstance` command
- **API** (`api.py`): `create_meta_instance` / `delete_meta_instance` /
`GET /meta_instances` endpoints; delete cascades to backing instances
with task cancellation
- **Binding via `meta_instance_id` on Instance** (`instances.py`): no
separate binding event or backing map — the instance carries its parent
MetaInstance ID directly, eliminating race conditions in the reconciler
- **Dashboard**: sidebar shows MetaInstances with their backing instance
status; orphan instances (created directly) still shown separately
- **Tests**: constraint matching, connection health, unsatisfied
detection, exclusive binding, cascade delete with task cancellation

### Recent improvements

- **fix: cancel active tasks on cascade delete** — `DeleteMetaInstance`
now emits `TaskStatusUpdated(Cancelled)` for any Pending/Running tasks
on backing instances before emitting `InstanceDeleted`. Previously,
cascade-deleting backing instances left orphaned task references in
state.
- **Lifecycle logging** — added `logger.info`/`logger.warning` for:
`CreateMetaInstance` (model, min_nodes, sharding), `DeleteMetaInstance`
(with cascade count), reconciler placement success/failure, and retry
decisions with attempt counts in `InstanceHealthReconciler`.
- **GET `/meta_instances` endpoint** — lists all meta-instances without
needing to fetch full state.
- **2 regression tests** — `test_cascade_delete_cancels_active_tasks`
and `test_cascade_delete_skips_completed_tasks` verify the
cascade-delete event sequence.

## Why It Works

Putting `meta_instance_id` on `BaseInstance` makes binding inherent to
instance creation. When the reconciler creates an instance for a
MetaInstance, it tags it via `model_copy`. When the instance is deleted,
the binding disappears with it. This avoids the two bugs that a separate
binding mechanism would introduce:
1. Stale exclusion sets — the reconciler loop can't accidentally bind
two MetaInstances to the same instance
2. Delete ordering race — no window between deleting an instance and its
binding where the reconciler could re-place
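
A minimal sketch of that binding idea, assuming simplified Pydantic models (the field and function names below are illustrative, not the exact exo types):

```python
from pydantic import BaseModel


class Instance(BaseModel):
    """Hedged sketch: the instance carries its parent MetaInstance id directly."""
    instance_id: str
    model_id: str
    meta_instance_id: str | None = None


class MetaInstance(BaseModel):
    meta_instance_id: str
    model_id: str
    min_nodes: int = 1


def bind_to_meta_instance(meta: MetaInstance, placed: Instance) -> Instance:
    # Tag the freshly placed instance via model_copy, so the binding is
    # inherent to the instance: deleting the instance deletes the binding.
    return placed.model_copy(update={"meta_instance_id": meta.meta_instance_id})


def is_satisfied(meta: MetaInstance, instances: list[Instance]) -> bool:
    # The reconciler only needs to find a backing instance carrying this
    # MetaInstance's id; there is no separate binding map to go stale.
    return any(i.meta_instance_id == meta.meta_instance_id for i in instances)
```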

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
- Created MetaInstance via dashboard, verified instance placed
- Verified delete cascades (deleting MetaInstance removes backing
instance)
- Verified orphan instances still work independently

### Automated Testing
- 30 tests in `test_meta_instance_edge_cases.py`: lifecycle, retry
logic, error handling, concurrent operations, cascade delete with task
cancellation
- 24 tests in `test_reconcile.py`: constraint matching, connection
health (single/multi-node, edge removal, IP changes), unsatisfied
detection, exclusive binding, idempotency
- All 261 tests pass
- basedpyright 0 errors, ruff clean, dashboard builds

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:48:19 -08:00
Alex Cheema
db79c350c1 Fix graceful process shutdown in macOS app (#1372)
## Motivation

Fixes #1370

When the macOS app stops exo, GPU/system memory isn't released. This
happens because:

1. The macOS app calls `process.terminate()` (SIGTERM) but the Python
process only registers a graceful shutdown handler for SIGINT, not
SIGTERM. SIGTERM's default Python behavior raises `SystemExit` which
bypasses the cleanup cascade (runner subprocess MLX cleanup via
`mx.clear_cache()`, channel closing, etc.).
2. The app doesn't wait for the process to actually finish cleanup — it
immediately nils out the process reference.

## Changes

**`src/exo/main.py`**: Register SIGTERM handler alongside SIGINT so the
graceful shutdown cascade (`Node.shutdown()` → cancel task group →
worker/runner cleanup → `mx.clear_cache()` + `gc.collect()`) runs
regardless of which signal is received.
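
On the Python side this amounts to registering the same handler for both signals. A standalone sketch for illustration (the real handler calls `Node.shutdown()`, which is not reproduced here):

```python
import signal
import sys


def shutdown(signum: int, frame: object) -> None:
    # In exo this triggers the graceful cascade (cancel task group, runner
    # cleanup with mx.clear_cache() + gc.collect(), close channels);
    # here we just exit cleanly for illustration.
    print(f"received signal {signum}, shutting down gracefully", file=sys.stderr)
    sys.exit(0)


# Register the same graceful handler for SIGINT *and* SIGTERM so that
# process.terminate() from the macOS app no longer bypasses cleanup.
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)

if __name__ == "__main__":
    print("running; send SIGINT or SIGTERM to stop", file=sys.stderr)
    signal.pause()  # wait for a signal (Unix only)
```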

**`app/EXO/EXO/ExoProcessController.swift`**: Replace immediate
`process.terminate()` with escalating shutdown per @Evanev7's
suggestion:
1. Send SIGINT via `process.interrupt()` — triggers the registered
Python handler for graceful cleanup
2. Wait up to 5 seconds for the process to exit
3. If still running, escalate to SIGTERM via `process.terminate()`
4. Wait up to 3 seconds
5. If still running, force kill via SIGKILL

The escalation runs in a detached `Task` so the UI updates immediately
(status → stopped) without blocking.
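
The same escalation strategy, sketched in Python with `subprocess` purely for illustration (the actual implementation is the Swift code in `ExoProcessController.swift`; the timeouts mirror the steps above):

```python
import signal
import subprocess


def stop_gracefully(proc: subprocess.Popen,
                    sigint_timeout: float = 5.0,
                    sigterm_timeout: float = 3.0) -> None:
    """Hedged sketch of the escalating shutdown the macOS app performs."""
    proc.send_signal(signal.SIGINT)        # 1. graceful: triggers the Python handler
    try:
        proc.wait(timeout=sigint_timeout)  # 2. wait up to 5 seconds
        return
    except subprocess.TimeoutExpired:
        pass
    proc.terminate()                       # 3. escalate to SIGTERM
    try:
        proc.wait(timeout=sigterm_timeout)  # 4. wait up to 3 seconds
        return
    except subprocess.TimeoutExpired:
        pass
    proc.kill()                            # 5. force kill via SIGKILL
```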

## Why It Works

The root cause is that SIGTERM wasn't triggering the graceful shutdown
path. By registering a SIGTERM handler in Python and sending SIGINT
first from the macOS app, the process gets a chance to run the full
cleanup cascade: cancelling the task group, shutting down runners (which
call `del model; mx.clear_cache(); gc.collect()`), closing channels, and
flushing logs. The escalation to SIGTERM and SIGKILL ensures the process
always terminates even if graceful shutdown hangs.

## Test Plan

### Manual Testing
<!-- Hardware: Mac Studio M4 Max 128GB -->
- Start exo via macOS app, load a model, run inference
- Stop via the toggle switch, verify memory is released without
requiring a system restart
- Test rapid stop/start (restart) to ensure no race conditions

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — no changes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan Quiney <evanev7@gmail.com>
2026-02-17 09:03:54 -08:00
Alex Cheema
d6301ed593 dashboard: redesign downloads page as model×node table (#1465)
## Motivation

The current downloads page uses a node-centric card grid layout that is
messy and hard to read — the same model across different nodes appears
in separate cards, and deep nesting wastes space. This makes it
difficult to quickly see which models are on which nodes.

## Changes

Rewrote the downloads page
(`dashboard/src/routes/downloads/+page.svelte`) from a card grid to a
clean table layout:

- **Rows** = models (unique across all nodes)
- **Columns** = nodes (with disk free shown in header)
- **Cells** show status at a glance:
  - ✅ Green checkmark + size for completed downloads
  - 🟡 Yellow percentage + mini progress bar + speed for active downloads
  - `...` for pending downloads
  - ❌ Red X for failed downloads
  - `--` for models not present on a node
- Delete/download action buttons appear on row hover
- Model name column is sticky on horizontal scroll (for many-node
clusters)
- Models sorted by number of nodes with completed downloads
- Imported shared utilities from `$lib/utils/downloads` instead of
inline re-implementations

### Backend: model directory in download events

- Added `model_directory` field to `BaseDownloadProgress` so all
download status events include the on-disk path
- Added `_model_dir()` helper to `DownloadCoordinator` to compute the
path from `EXO_MODELS_DIR`
- Dashboard uses this to show file location and enable "open in Finder"
for completed downloads

### Info modal

- Clicking a model name opens an info modal showing card details
(family, quantization, capabilities, storage size, layer count, tensor
parallelism support)

### Other fixes

- Fixed model name truncation in the table
- Excluded `tests/start_distributed_test.py` from pytest collection (CLI
script that calls `sys.exit()` at import time)

## Test Plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all passed
- [x] `nix fmt` — clean
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:31:47 +00:00
Evan Quiney
6d1ca6689b don't time out node identities (#1493)
currently, nodes leaving and rejoining the cluster can lose their identity. We have no need to delete this data when a node times out, so let's just persist it.
2026-02-17 11:48:28 +00:00
Evan
c01b6fff21 eprint banner
our banner was being printed to stdout but should be printed to stderr,
as it's essentially a log message
2026-02-17 11:43:06 +00:00
Jake Hillion
8392e78afe bench: add spec for automatic canary benchmarks (#1483)
Adds all the models that can fit onto a single M3 Ultra for single
machine benchmarks. Fixes the macOS version, GPU spec, and chip type for
maximum reproducibility. Specifies the minimum memory accordingly for
each type of model, using the smallest machine available (the smallest
M3 Ultra is 96GiB).

Test plan:
- Running this with some code that makes machines of this spec available
and stores the results. It works.

This will become part of a larger testing/stability strategy once we've
collected more of the data.
2026-02-17 10:52:05 +00:00
Evan
86735ece78 begins
begins
2026-02-16 19:26:19 +00:00
Evan Quiney
2759e92334 api cancellation (#1276)
closing the http request to the api now
- sends a cancellation from the api
- writes that cancellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks

closes #61
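
A toy sketch of the runner-side check described in the list above (structure and names are illustrative, not exo's actual runner loop):

```python
import asyncio


async def generate_tokens(n: int):
    """Stand-in for the model's per-token generation loop."""
    for i in range(n):
        await asyncio.sleep(0)  # one "generation step"
        yield f"tok{i}"


async def run_generation(cancelled: asyncio.Event) -> None:
    # The runner observes the cancellation after every generation step
    # (one extra check per token), so the stop happens synchronously with
    # generation rather than interrupting work mid-step.
    async for tok in generate_tokens(1000):
        print(tok, end=" ")
        if cancelled.is_set():
            print("\n[cancelled]")
            break


async def main() -> None:
    cancelled = asyncio.Event()
    task = asyncio.create_task(run_generation(cancelled))
    await asyncio.sleep(0.01)
    cancelled.set()  # e.g. the client closed the HTTP request to the API
    await task


asyncio.run(main())
```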

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:30:07 -08:00
Jake Hillion
131fb141a6 bench: add --danger-delete-downloads flag with planning phase
exo bench previously relied on the worker's plan loop to download
models, which could fail silently or run into disk space issues during
benchmarking. This made it difficult to diagnose download failures.

Added a planning phase that runs before benchmarking to explicitly
handle downloads. It checks available disk space on each node via the
/state endpoint and starts downloads via POST /download/start. When
the --danger-delete-downloads flag is set and there's insufficient
space, it deletes existing models from smallest to largest until
there's room for the benchmark model.

Test plan:
- CI

```
jake@maverick:/data/users/jake/repos/exo/ > nix run .#exo-bench -- --pp 128,2048,4096 --tg 128 --stdout --settle-timeout 10 --host s1 --model mlx-community/gpt-oss-120b-MXFP4-Q8
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-16 12:12:11.807 | INFO     | __main__:main:710 - pp/tg mode: combinations (product) - 3 pairs
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
2026-02-16 12:12:13.455 | DEBUG    | __main__:main:725 - [exo-bench] loaded tokenizer: mlx-community/gpt-oss-120b-MXFP4-Q8 for prompt sizer
2026-02-16 12:12:13.473 | DEBUG    | __main__:main:761 - exo-bench model: short_id=gpt-oss-120b-MXFP4-Q8 full_id=mlx-community/gpt-oss-120b-MXFP4-Q8
2026-02-16 12:12:13.473 | INFO     | __main__:main:762 - placements: 1
2026-02-16 12:12:13.474 | INFO     | __main__:main:764 -   - Pipeline / MlxRing / nodes=1
2026-02-16 12:12:13.474 | INFO     | __main__:main:771 - Planning phase: checking downloads...
Traceback (most recent call last):
  File "/nix/store/q31kmbcfr5bf97290bvbnhrvpc3fv824-source/bench/exo_bench.py", line 885, in <module>
    raise SystemExit(main())
                     ~~~~^^
  File "/nix/store/q31kmbcfr5bf97290bvbnhrvpc3fv824-source/bench/exo_bench.py", line 772, in main
    run_planning_phase(
    ~~~~~~~~~~~~~~~~~~^
        client,
        ^^^^^^^
    ...<4 lines>...
        settle_deadline,
        ^^^^^^^^^^^^^^^^
    )
    ^
  File "/nix/store/q31kmbcfr5bf97290bvbnhrvpc3fv824-source/bench/exo_bench.py", line 367, in run_planning_phase
    raise RuntimeError(
    ...<2 lines>...
    )
RuntimeError: Insufficient disk on 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj: need 65GB, have 55GB. Use --danger-delete-downloads to free space.
jake@maverick:/data/users/jake/repos/exo/ > nix run .#exo-bench -- --pp 128,2048,4096 --tg 128 --stdout --settle-timeout 10 --host s1 --model mlx-community/gpt-oss-120b-MXFP4-Q8 --danger-delete-downloads
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-16 12:12:19.626 | INFO     | __main__:main:710 - pp/tg mode: combinations (product) - 3 pairs
2026-02-16 12:12:21.262 | DEBUG    | __main__:main:725 - [exo-bench] loaded tokenizer: mlx-community/gpt-oss-120b-MXFP4-Q8 for prompt sizer
2026-02-16 12:12:21.280 | DEBUG    | __main__:main:761 - exo-bench model: short_id=gpt-oss-120b-MXFP4-Q8 full_id=mlx-community/gpt-oss-120b-MXFP4-Q8
2026-02-16 12:12:21.280 | INFO     | __main__:main:762 - placements: 1
2026-02-16 12:12:21.280 | INFO     | __main__:main:764 -   - Pipeline / MlxRing / nodes=1
2026-02-16 12:12:21.280 | INFO     | __main__:main:771 - Planning phase: checking downloads...
2026-02-16 12:12:21.336 | INFO     | __main__:run_planning_phase:386 - Deleting mlx-community/Qwen3-0.6B-4bit from 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj (335MB)
2026-02-16 12:12:21.350 | INFO     | __main__:run_planning_phase:386 - Deleting mlx-community/Llama-3.2-1B-Instruct-4bit from 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj (679MB)
2026-02-16 12:12:21.363 | INFO     | __main__:run_planning_phase:386 - Deleting mlx-community/Llama-3.2-3B-Instruct-4bit from 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj (1740MB)
2026-02-16 12:12:21.373 | INFO     | __main__:run_planning_phase:386 - Deleting mlx-community/Llama-3.2-3B-Instruct-8bit from 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj (3264MB)
2026-02-16 12:12:21.384 | INFO     | __main__:run_planning_phase:386 - Deleting mlx-community/GLM-4.7-Flash-8bit from 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj (30366MB)
2026-02-16 12:12:21.413 | INFO     | __main__:run_planning_phase:407 - Started download on 12D3KooWE2C7dzC9d9YJMEfWK3g8og7JdZj3HHXZ8VmGrXYAEnEj
```

It's not pretty but it works!
2026-02-16 13:06:38 +00:00
Evan Quiney
2d8bfc2e3c fix: PlaceInstanceParams broken field validator
our field validator for PlaceInstance was wrong - we can just rely on default behaviour here anyway!
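
A hedged sketch of what "rely on default behaviour" can look like with Pydantic (the field names are assumptions, not the real `PlaceInstanceParams` schema):

```python
from pydantic import BaseModel, Field


class PlaceInstanceParams(BaseModel):
    """Illustrative only: built-in defaults and constraints replace a custom validator."""
    model_id: str
    min_nodes: int = Field(default=1, ge=1)
    node_ids: list[str] | None = None
    # A custom @field_validator previously tried to normalize these fields;
    # Pydantic's default validation covers the same cases without it.
```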
2026-02-16 03:58:43 -08:00
ciaranbor
042999f728 Ciaran/message deletion (#1409)
## Motivation

When a user deletes a message during an active streamed generation, it
can cause unexpected behavior. The delete confirmation text was also
misleading — it said "all responses after it" only for user messages,
which didn't accurately describe the behavior (all messages after the
deleted one are removed, regardless of role)

## Changes

- Prevent deletion during streaming: Disabled the delete button and
blocked handleDeleteClick when loading is true, with a visual indication
(dimmed button, cursor-not-allowed, tooltip change)
- Clarified delete confirmation text: Replaced role-specific wording
with a simpler, accurate message:
  - Last message: "Delete this message?"
  - Any other message: "Delete this message and all messages after it?"

## Why It Works

Guarding on the loading state at both the click handler and the button's
disabled attribute ensures no deletion can be triggered while a response
is being streamed

## Test Plan

### Manual Testing

- Verify the delete button is visually disabled and non-clickable while
a response is streaming
- Verify the tooltip shows "Cannot delete while generating" during
streaming
- Verify the last message shows "Delete this message?" confirmation
- Verify non-last messages show "Delete this message and all messages
after it?" confirmation
- Verify deletion works normally when not streaming
2026-02-16 11:46:41 +00:00
ciaranbor
b61dc2eb35 Prevent image editing without image input (#1410)
## Motivation

Models that only support image editing (ImageToImage but not
TextToImage) would silently attempt text-to-image generation when a user
submitted a text prompt without an attached image

## Changes

- Added an early return guard in handleSubmit() that prevents submission
when the selected model only supports image editing and no image is
attached (isEditOnlyWithoutImage)
- Fixed the text-to-image generation branch to use the more specific
modelSupportsTextToImage() check instead of the broad isImageModel(),
ensuring only models with TextToImage capability trigger generation from
text alone
- The existing isEditOnlyWithoutImage derived state (which was already
used for UI hints like placeholder text and button disabling) now also
blocks the actual submit path

## Why It Works

The text-to-image fallback now correctly checks
modelSupportsTextToImage() directly, so edit-only models no longer fall
through to the generation path


## Test Plan

### Manual Testing

- Select an edit-only image model (e.g., one with only ImageToImage
capability)
- Verify the send button is disabled and placeholder reads "Attach an
image to edit..." when no image is attached
- Attach an image and verify the form becomes submittable
- Select a text-to-image model and verify text-only prompts still
trigger generation normally
- Ensure pressing `enter` doesn't bypass the check
2026-02-16 11:39:59 +00:00
rltakashige
36a7115b6f Pass usage and generation stats through all adapters correctly (#1461)
## Motivation

Exo is not returning usage stats correctly at the moment.

## Changes

- Correctly pass usage stats instead of generation stats.
- Pass usage stats within tool calls.

## Test Plan

### Manual Testing
Needs manual testing.

### Automated Testing
Passes CI.
2026-02-16 11:20:04 +00:00
Jake Hillion
0b7d88b43b python: add hermetic basedpyright typecheck to nix flake check
The existing CI typecheck job used `uv run basedpyright` which depends
on a non-hermetic uv sync step. This replaces it with a fully hermetic
typecheck as a Nix flake check using the uv2nix virtual environment.

Added a `typecheckVenv` with dev dependencies, a `linuxOverlay` to
ignore native shared library deps (NVIDIA, torch, triton, mlx) that
aren't needed at type-check time, and `passthru` preservation plus
`.pyi` stub copying on the `exo-pyo3-bindings` overlay so basedpyright
can resolve the Rust bindings types. Also guarded the `mlx` Nix build
override to macOS only since it requires Metal. Removed the old
non-hermetic `typecheck` CI job since `nix flake check` now covers it.

The hermetic check ensures type checking uses exactly the locked
dependency versions and catches type errors without requiring a
working uv/pip environment.

Test plan:
- CI (`nix flake check` runs on x86_64-linux, aarch64-linux, aarch64-darwin)
- Verified `nix build ".#checks.x86_64-linux.typecheck"` passes with 0 errors
2026-02-16 11:09:23 +00:00
44 changed files with 1683 additions and 1003 deletions

View File

@@ -8,33 +8,6 @@ on:
- main
jobs:
typecheck:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Sync dependencies
run: uv sync --all-packages
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}

View File

@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in generate.py mlx_generate at the end.
[X] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[X] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).

View File

@@ -72,16 +72,23 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -126,11 +126,37 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
if process.isRunning {
process.terminate()
}
self.process = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
}
func restart() {

7
bench/bench.toml Normal file
View File

@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

View File

@@ -288,6 +288,151 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
@@ -535,6 +680,11 @@ def main() -> int:
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -569,13 +719,16 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and args.settle_timeout > 0:
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + args.settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
@@ -607,6 +760,16 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

189
bench/single-m3-ultra.toml Normal file
View File

@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

View File

@@ -265,6 +265,7 @@
function handleSubmit() {
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
if (isEditOnlyWithoutImage) return;
const content = message.trim();
const files = [...uploadedFiles];
@@ -289,7 +290,11 @@
if (imageFile.preview) {
editImage(content, imageFile.preview);
}
} else if (isImageModel() && content) {
} else if (
currentModel &&
modelSupportsTextToImage(currentModel) &&
content
) {
// Use image generation for text-to-image models
generateImage(content);
} else {

View File

@@ -225,6 +225,7 @@
}
function handleDeleteClick(messageId: string) {
if (loading) return;
deleteConfirmId = messageId;
}
@@ -255,7 +256,7 @@
</script>
<div class="flex flex-col gap-4 sm:gap-6 {className}">
{#each messageList as message (message.id)}
{#each messageList as message, i (message.id)}
<div
class="group flex {message.role === 'user'
? 'justify-end'
@@ -317,9 +318,11 @@
<!-- Delete confirmation -->
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
<p class="text-xs text-red-400 mb-3">
Delete this message{message.role === "user"
? " and all responses after it"
: ""}?
{#if i === messageList.length - 1}
Delete this message?
{:else}
Delete this message and all messages after it?
{/if}
</p>
<div class="flex gap-2 justify-end">
<button
@@ -751,8 +754,13 @@
<!-- Delete button -->
<button
onclick={() => handleDeleteClick(message.id)}
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
title="Delete message"
disabled={loading}
class="p-1.5 transition-colors rounded {loading
? 'text-exo-light-gray/30 cursor-not-allowed'
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
title={loading
? "Cannot delete while generating"
: "Delete message"}
>
<svg
class="w-3.5 h-3.5"

View File

File diff suppressed because it is too large.

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.6"; in
version = let v = "0.30.7.dev20260217+50487b41"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
};
patches = [

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,6 +64,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
@@ -132,7 +133,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -14,7 +14,9 @@
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
@@ -22,6 +24,12 @@
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
passthru = prev.exo-pyo3-bindings.passthru or { };
postInstall = ''
local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
touch $siteDir/py.typed
'';
};
};
@@ -29,17 +37,47 @@
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
} // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
# Use our pure Nix-built MLX with Metal support (macOS only)
mlx = self'.packages.mlx;
};
# Additional overlay for Linux-specific fixes (type checking env).
# Native wheels have shared lib dependencies we don't need at type-check time.
linuxOverlay = final: prev:
let
ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
in
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
);
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
@@ -48,16 +86,28 @@
overlay
exoOverlay
buildSystemsOverlay
linuxOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;
@@ -118,6 +168,21 @@
${pkgs.ruff}/bin/ruff check ${inputs.self}
touch $out
'';
# Hermetic basedpyright type checking
typecheck = pkgs.runCommand "typecheck"
{
nativeBuildInputs = [
testVenv
pkgs.basedpyright
];
}
''
cd ${inputs.self}
export HOME=$TMPDIR
basedpyright --pythonpath ${testVenv}/bin/python
touch $out
'';
};
};
}

View File

@@ -14,6 +14,7 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -63,6 +64,9 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -74,6 +78,7 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -93,6 +98,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -170,7 +176,11 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -184,6 +194,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -206,6 +217,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -219,6 +231,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -253,6 +266,7 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -295,11 +309,18 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
status = DownloadOngoing(
@@ -308,6 +329,9 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -94,6 +94,7 @@ class Node:
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
no_downloads=args.no_downloads,
)
else:
worker = None
@@ -136,6 +137,8 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -147,8 +150,6 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -225,6 +226,7 @@ class Node:
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
no_downloads = self.worker.no_downloads
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
@@ -239,6 +241,7 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
no_downloads=no_downloads,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -17,6 +17,7 @@ from exo.shared.types.api import (
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -125,6 +126,8 @@ async def generate_chat_stream(
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
@@ -138,6 +141,8 @@ async def generate_chat_stream(
yield "data: [DONE]\n\n"
return
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
@@ -161,12 +166,15 @@ async def generate_chat_stream(
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
@@ -176,7 +184,9 @@ async def generate_chat_stream(
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
@@ -184,6 +194,7 @@ async def collect_chat_response(
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
@@ -193,6 +204,8 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
@@ -223,7 +236,7 @@ async def collect_chat_response(
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
yield ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
@@ -241,4 +254,6 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
)
usage=last_usage,
).model_dump_json()
return

View File

@@ -4,7 +4,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
@@ -161,12 +161,14 @@ async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -174,6 +176,8 @@ async def collect_claude_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
@@ -183,12 +187,10 @@ async def collect_claude_response(
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -208,11 +210,11 @@ async def collect_claude_response(
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
# Use actual usage data if available
input_tokens = last_usage.prompt_tokens if last_usage else 0
output_tokens = last_usage.completion_tokens if last_usage else 0
return ClaudeMessagesResponse(
yield ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=content,
@@ -221,7 +223,8 @@ async def collect_claude_response(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
).model_dump_json()
return
async def generate_claude_stream(
@@ -249,7 +252,7 @@ async def generate_claude_stream(
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
@@ -257,8 +260,9 @@ async def generate_claude_stream(
# Close text block and bail
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
@@ -290,7 +294,6 @@ async def generate_claude_stream(
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
@@ -302,9 +305,9 @@ async def generate_claude_stream(
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# Use actual token count from usage if available
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)

View File

@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
@@ -121,13 +122,15 @@ async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ResponsesResponse:
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -135,32 +138,32 @@ async def collect_responses_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
id=tool.id,
call_id=tool.id,
name=tool.name,
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
output: list[ResponseItem] = [
@@ -172,14 +175,15 @@ async def collect_responses_response(
]
output.extend(function_call_items)
return ResponsesResponse(
yield ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
).model_dump_json()
return
async def generate_responses_stream(
@@ -235,15 +239,16 @@ async def generate_responses_stream(
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{tool.id}"
call_id = f"call_{tool.id}"
@@ -302,7 +307,6 @@ async def generate_responses_stream(
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
@@ -346,13 +350,13 @@ async def generate_responses_stream(
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
# response.completed

View File

@@ -125,6 +125,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
@@ -540,16 +541,14 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
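One note on the cancellation handling added above: inside an `except anyio.get_cancelled_exc_class()` block the task is already being torn down, so any further await would itself be cancelled unless it runs in a shielded scope; that is why the TaskCancelled command is sent under `anyio.CancelScope(shield=True)` before re-raising. A minimal sketch of the pattern, with `notify_cancelled` as an invented stand-in for sending the command:

import anyio

async def notify_cancelled() -> None:
    # Stand-in for sending a TaskCancelled command to the forwarder.
    await anyio.sleep(0.01)

async def stream_tokens() -> None:
    try:
        await anyio.sleep(10)  # stand-in for forwarding token chunks to the client
    except anyio.get_cancelled_exc_class():
        # Without the shield this await would be cancelled as well and the
        # notification would be lost.
        with anyio.CancelScope(shield=True):
            await notify_cancelled()
        raise  # always re-raise the cancellation exception

async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(stream_tokens)
        await anyio.sleep(0.1)
        tg.cancel_scope.cancel()

anyio.run(main)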
@@ -644,11 +643,14 @@ class API:
"X-Accel-Buffering": "no",
},
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -664,8 +666,7 @@ class API:
command = TextGeneration(task_params=task_params)
await self._send(command)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
return await self._collect_text_generation_with_stats(command.command_id)
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
@@ -883,6 +884,11 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -964,6 +970,11 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1221,12 +1232,15 @@ class API:
"X-Accel-Buffering": "no",
},
)
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def openai_responses(
self, payload: ResponsesRequest
@@ -1254,11 +1268,15 @@ class API:
},
)
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
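The non-streaming branches above now wrap the collectors in a StreamingResponse as well: each collector is an async generator that yields exactly one serialized JSON document, and, per the comment in the adapters, FastAPI cancels a streamed generator on client disconnect more cleanly than an awaited coroutine but will not auto-serialize the yielded model, hence the explicit model_dump_json(). A self-contained sketch of the shape using only FastAPI and pydantic; DemoResponse and the fake chunk source are invented for the example, not exo types:

from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class DemoResponse(BaseModel):
    id: str
    text: str

async def collect_demo_response(chunks: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    # Accumulate every chunk, then yield a single JSON document and stop.
    parts: list[str] = []
    async for chunk in chunks:
        parts.append(chunk)
    yield DemoResponse(id="demo", text="".join(parts)).model_dump_json()
    return

async def _fake_chunks() -> AsyncGenerator[str, None]:
    for piece in ("hel", "lo"):
        yield piece

@app.post("/demo")
async def demo() -> StreamingResponse:
    # The response body is one complete JSON object, not an SSE stream.
    return StreamingResponse(collect_demo_response(_fake_chunks()), media_type="application/json")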
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""

View File

@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
TextGeneration,
@@ -39,6 +40,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
TracesMerged,
@@ -279,7 +281,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
@@ -299,7 +301,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -309,7 +311,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -319,6 +321,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -327,10 +341,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time

View File

@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
@@ -186,6 +192,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -201,6 +208,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -4,7 +4,11 @@ import json
from collections.abc import AsyncGenerator
from typing import Any, cast
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.master.adapters.claude import (
ClaudeMessagesResponse,
collect_claude_response,
generate_claude_stream,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
@@ -17,6 +21,18 @@ async def _chunks_to_stream(
yield chunk
async def _collect_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Helper to consume the async generator and parse the JSON response."""
parts: list[str] = []
async for part in collect_claude_response(command_id, model, chunk_stream):
parts.append(part)
return ClaudeMessagesResponse.model_validate_json("".join(parts))
MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")
@@ -47,7 +63,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
@@ -77,7 +93,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
@@ -102,7 +118,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
@@ -116,7 +132,7 @@ class TestCollectClaudeResponseToolUse:
async def test_no_content_produces_empty_text_block(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert len(response.content) == 1

View File

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1

View File

@@ -218,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -263,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -3,8 +3,7 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -228,13 +227,6 @@ class PlaceInstanceParams(BaseModel):
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
def use_default(cls, v: object):
if not v or not isinstance(v, (Sharding, InstanceMeta)):
raise PydanticUseDefault()
return v
class CreateInstanceParams(BaseModel):
instance: Instance

View File

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -89,6 +93,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -1,5 +1,7 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner)
print(banner, file=sys.stderr)

View File

@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)
async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():
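The send_async change offloads the blocking queue put to a worker thread with abandon_on_cancel=True, so a cancelled caller returns immediately while the thread finishes the put in the background, which is what the bounded move_on_after window in RunnerSupervisor.cancel_task relies on. A tiny sketch of the anyio primitive in isolation, with a plain time.sleep standing in for the blocking put:

import time

import anyio
from anyio import to_thread

def blocking_put() -> None:
    # Stand-in for a blocking multiprocessing queue .put()
    time.sleep(0.5)

async def main() -> None:
    with anyio.move_on_after(0.1) as scope:
        # abandon_on_cancel=True lets the await return as soon as the scope is
        # cancelled; the worker thread still runs blocking_put() to completion.
        await to_thread.run_sync(blocking_put, abandon_on_cancel=True)
    print(f"timed out: {scope.cancelled_caught}")

anyio.run(main)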

View File

@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Ready to prefill")
logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
@@ -393,10 +393,11 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
total_prompt_tokens = len(all_prompt_tokens)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=total_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
total_tokens=total_prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),

View File

@@ -64,8 +64,6 @@ from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -83,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -379,7 +353,13 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
tokenizer = load_tokenizer(
model_path,
@@ -591,3 +571,61 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]
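For reference, the Kimi tool-call syntax handled by _parse_kimi_tool_calls above looks like functions.<name>:<index><|tool_call_argument_begin|><json args>. A standalone snippet (stdlib re only, sample tool name and arguments invented) applying the same two regexes to a single call:

import json
import re

FUNC_NAME_RE = re.compile(
    r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
FUNC_ARG_RE = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)

sample = 'functions.get_weather:0<|tool_call_argument_begin|>{"city": "London"}'

name_match = FUNC_NAME_RE.search(sample)
arg_match = FUNC_ARG_RE.search(sample)
assert name_match is not None and arg_match is not None

parsed = {
    "id": name_match.group(1),                    # "functions.get_weather:0"
    "name": name_match.group(2),                  # "get_weather"
    "arguments": json.loads(arg_match.group(1)),  # {"city": "London"}
}
print(parsed)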

View File

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -64,6 +65,7 @@ class Worker:
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
no_downloads: bool = False,
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -73,6 +75,7 @@ class Worker:
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.no_downloads = no_downloads
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
@@ -181,6 +184,7 @@ class Worker:
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
no_downloads=self.no_downloads,
)
if task is None:
continue
@@ -224,15 +228,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -270,18 +281,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
await self._start_runner_task(modified_task)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
@@ -320,8 +331,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -50,16 +51,22 @@ def plan(
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
no_downloads: bool = False,
) -> Task | None:
# Python's short-circuiting `or` evaluates these sequentially; the first non-None task wins.
return (
_kill_runner(runners, all_runners, instances)
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or (
None
if no_downloads
else _model_needs_download(node_id, runners, global_download_status)
)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _load_model(runners, all_runners, global_download_status, no_downloads)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
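Each helper in the chain above returns a Task or None, so the short-circuiting `or` acts as a priority list: cancellations are planned first, pending work last, and evaluation stops at the first helper that produces a task. A toy illustration of the idiom (step names invented):

def cancel_step() -> str | None:
    return None  # nothing to cancel right now

def load_step() -> str | None:
    return "LoadModel"  # first non-None result wins

def pending_step() -> str | None:
    raise AssertionError("not evaluated: an earlier step already returned a task")

# Evaluated left to right; later steps are skipped once a task is found.
next_task = cancel_step() or load_step() or pending_step()
assert next_task == "LoadModel"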
@@ -190,22 +197,25 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
no_downloads: bool = False,
) -> LoadModel | None:
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
if not no_downloads:
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id
== shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
)
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue
if not all_local_downloads_complete:
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
@@ -270,7 +280,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +294,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +302,33 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner_id, runner in runners.items():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id,
cancelled_task_id=task.task_id,
runner_id=runner_id,
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,10 +1,10 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, Literal
from typing import Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,7 +15,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -88,9 +87,12 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -112,6 +114,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
@@ -129,11 +132,16 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -146,6 +154,7 @@ def main(
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -191,19 +200,28 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -211,8 +229,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -224,16 +240,31 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
t = time.monotonic()
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
group=group,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
@@ -241,8 +272,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -262,9 +293,9 @@ def main(
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert check_for_cancel_every
try:
_check_for_debug_prompts(task_params)
@@ -274,7 +305,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -289,34 +320,25 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
if isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
completion_tokens += 1
@@ -364,6 +386,7 @@ def main(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
stats=response.stats,
),
)
)
@@ -388,7 +411,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -401,7 +424,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
@@ -451,7 +476,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -464,7 +489,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
@@ -523,14 +550,20 @@ def main(
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
was_cancelled = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if not was_cancelled:
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
@@ -544,21 +577,8 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -615,9 +635,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -738,218 +758,55 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
responses: Generator[GenerationResponse], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
if response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
"tool call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=None,
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
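Cancellation in the generation loop above is poll-based: warmup measures rough token throughput to derive check_for_cancel_every, the loop drains the cancel pipe every that many tokens, and mx_any (a distributed all_sum over a boolean flag) makes every rank agree to stop on the same step so the collective ops stay in lockstep. A single-process sketch of the polling cadence, with a plain set standing in for the cancel receiver and the distributed agreement omitted:

from collections.abc import Iterator

CHECK_FOR_CANCEL_EVERY = 100  # in the runner this is derived from warmup throughput

def generate_tokens(n: int) -> Iterator[int]:
    yield from range(n)

def run(task_id: str, cancel_requests: set[str]) -> int:
    produced = 0
    tokens_since_last_check = 0
    for _token in generate_tokens(10_000):
        produced += 1
        tokens_since_last_check += 1
        if tokens_since_last_check >= CHECK_FOR_CANCEL_EVERY:
            tokens_since_last_check = 0
            # In the runner this is cancel_receiver.collect() plus mx_any(...),
            # so all ranks break out of the loop on the same token.
            if task_id in cancel_requests:
                break
    return produced

# The cancel request arrives before generation starts; the loop notices it on
# the first check and stops after roughly one polling interval.
print(run("task-1", {"task-1"}))  # -> 100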

View File

@@ -47,9 +47,11 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -60,8 +62,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -69,6 +71,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -83,6 +86,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,6 +101,8 @@ class RunnerSupervisor:
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process successfully terminated")
@@ -112,14 +118,6 @@ class RunnerSupervisor:
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
@@ -141,6 +139,17 @@ class RunnerSupervisor:
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
with anyio.move_on_after(0.5) as scope:
await self._cancel_sender.send_async(task_id)
if scope.cancel_called:
logger.error("RunnerSupervisor cancel pipe blocked")
await self._check_runner(TimeoutError("cancel pipe blocked"))
async def _forward_events(self):
with self._ev_recv as events:
try:

View File

@@ -0,0 +1,72 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
start_parsing: str
end_parsing: str
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix(tool_call_start)
text = text.removesuffix(tool_call_end)
parsed = tool_parser(text)
if isinstance(parsed, list):
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
else:
return [ToolCallItem.model_validate(_flatten(parsed))]
except Exception:
return None
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
parse_tool_calls=parse_tool_calls,
)
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix("<tool_call>")
text = text.removesuffix("</tool_call>")
top_level = {
k: json.dumps(v) if isinstance(v, (dict, list)) else v
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
}
return [ToolCallItem.model_validate(top_level)]
except Exception:
return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
json_tool_parser = ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
"""Attempt to auto-infer a tool parser from the chat template."""
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
return json_tool_parser
return None
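A short usage sketch for the new tool_parsers module, assuming the exo package is importable: make_mlx_parser adapts an mlx-lm-style callback (text in, dict or list of dicts out) into a ToolParser whose parse_tool_calls strips the start/end markers, flattens nested arguments to JSON strings, and validates the result as ToolCallItem objects, returning None on any parse failure. The toy callback and its values are invented:

from typing import Any

from exo.worker.runner.tool_parsers import make_mlx_parser

def toy_callback(text: str) -> dict[str, Any]:
    # Pretend every call is `lookup` with the raw text as its only argument.
    return {"id": "call_0", "name": "lookup", "arguments": {"query": text}}

parser = make_mlx_parser("<tool_call>", "</tool_call>", toy_callback)

calls = parser.parse_tool_calls('<tool_call>{"q": 1}</tool_call>')
assert calls is not None
print(calls[0].name)       # lookup
print(calls[0].arguments)  # a JSON string of the nested arguments dict

In the runner, the parse_tool_calls generator buffers text between start_parsing and end_parsing and hands the whole span to this callback, so the callback never has to deal with streaming.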

View File

@@ -157,6 +157,41 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_loads_model_with_no_downloads_flag():
"""
When no_downloads=True, plan() should skip download checks and proceed
to load the model even with empty global_download_status.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
global_download_status={},
instances=instances,
all_runners=all_runners,
tasks={},
no_downloads=True,
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
"""
LoadModel should not be emitted while some shards are still missing from

View File

@@ -1,7 +1,9 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -19,6 +21,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -113,6 +116,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
event_sender = EventCollector()
with task_sender:
@@ -173,8 +178,16 @@ def _run(tasks: Iterable[Task]):
# this is some c++ nonsense
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
with unittest.mock.patch(
"exo.worker.runner.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
return event_sender.events

View File

@@ -5,12 +5,13 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -22,10 +23,13 @@ def _make_responses(
)
def _dummy_parser(text: str) -> dict[str, Any]:
def _dummier_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -35,8 +39,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -50,8 +52,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,9 +76,7 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_failing_parser,
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
)
)

uv.lock generated
View File

@@ -377,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,7 +416,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,18 +1048,12 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
resolution-markers = [
"sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1072,6 +1066,14 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1102,7 +1104,7 @@ version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1114,16 +1116,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"