Compare commits

...

33 Commits

Author SHA1 Message Date
Alex Cheema
acb97127bf Normalize TextGenerationTaskParams.input to list[InputMessage] (#1360)
## Motivation

With the addition of the Responses API, we introduced `str |
list[InputMessage]` as the type for `TextGenerationTaskParams.input`
since the Responses API supports sending input as a plain string. But
there was no reason to leak that flexibility past the API adapter
boundary — it just meant every downstream consumer had to do `if
isinstance(messages, str):` checks, adding complexity for no benefit.

## Changes

- Changed `TextGenerationTaskParams.input` from `str |
list[InputMessage]` to `list[InputMessage]`
- Each API adapter (Chat Completions, Claude Messages, Responses) now
normalizes to `list[InputMessage]` at the boundary
- Removed `isinstance(task_params.input, str)` branches in
`utils_mlx.py` and `runner.py`
- Wrapped string inputs in `[InputMessage(role="user", content=...)]` in
the warmup path and all test files

## Why It Works

The API adapters are the only place where we deal with raw user input
formats. By normalizing there, all downstream code (worker, runner, MLX
engine) can just assume `list[InputMessage]` and skip the type-checking
branches. The type system (`basedpyright`) catches any missed call sites
at compile time.

## Test Plan

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — applied
- `uv run pytest` — 174 passed, 1 skipped

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 06:01:56 -08:00
Evan Quiney
d90605f198 migrate model cards to .toml files (#1354) 2026-02-03 12:32:06 +00:00
Evan Quiney
f400b4d7c5 fix InstanceViewModel.swift (#1359)
wasn't caught when we merged the API changes
2026-02-02 18:43:27 +00:00
Evan Quiney
d97bca88e6 improve distributed testing (#1300)
Our distributed test now does a full query cycle for every model loaded
onto the relevant machine. This will help find bugs early, as it already
has found one with Qwen3 Next! I didn't write down what the error was
though. Gooooooood luck with that!

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-02 18:25:39 +00:00
Alex Cheema
dfce188d99 fix: handle unclosed tool calls and GLM arg parsing edge cases (#1344)
## Motivation

Tool-call requests can hang indefinitely when `max_tokens` truncates
generation mid-tool-call.

## Reproduction

1. Send a chat completion with `tools` and a low `max_tokens` (e.g. 65)
to Qwen3-0.6B
2. Model generates `<think>...</think>` then starts `<tool_call>` but
`max_tokens` cuts it off before `</tool_call>`
3. **Before this fix:** `parse_tool_calls` buffers tokens after
`<tool_call>`, generator exhausts, buffered tokens (including
`finish_reason`) are silently dropped → stream hangs forever
4. **After this fix:** buffered tokens are flushed as regular text with
`finish_reason` propagated → response returns normally with
`finish_reason: "length"`

Confirmed with fresh local testing: 4 unclosed tool call flushes
triggered in a single session. Also confirmed via production logs from
Jan 29 (2 occurrences).

## Changes

1. **`parse_tool_calls` unclosed tool call flush** — when the generator
exhausts inside an open `<tool_call>` block, flush buffered tokens as
regular text and propagate `finish_reason`
2. **GLM regex fix** — match literal `\n` (not escaped `\\n`) between
arg tags; handle missing `</arg_value>` via lookahead
3. **7 new unit tests** for `parse_tool_calls` covering unclosed,
closed, passthrough, and failed-parse scenarios

## Why It Works

- `parse_tool_calls` now has a post-loop check: if `in_tool_call` is
still true, it yields the buffered text with the tracked `finish_reason`
instead of silently dropping it
- The GLM regex now matches real-world output where newlines appear
between tags and `</arg_value>` may be absent

## Test Plan

### Manual Testing
- Qwen3-0.6B-4bit with `tools` + various `max_tokens` values (61-75)
- Confirmed responses return with `finish_reason: "length"` instead of
hanging
- Log output shows `"generator exhausted inside unclosed tool call,
flushing buffered text"`

### Automated Testing
- 7 new tests in `test_parse_tool_calls.py`
- Full test suite passes (`uv run pytest`)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-02 17:45:51 +00:00
Evan Quiney
54b19879a0 create config home when checking for config file (#1353)
we didn't check before, raising a critical exception.
now we create ~/.config/exo on linux systems before touching config.toml.

this wasn't caught before since everything lives in ~/.exo on macos, and we no longer write the keypair to CONFIG_HOME, so config.toml has to do init work it avoided before.
2026-02-02 17:36:51 +00:00
ciaranbor
19965c7ba5 Ciaran/profiling (#1345)
## Motivation

Know what the hell is going on

## Changes

- Tracing library (src/exo/shared/tracing.py): trace() context manager,
Chrome Trace Format export, statistics computation
- Runner instrumentation
(src/exo/worker/engines/image/pipeline/runner.py): Wrapped sync/async
steps, compute blocks, and send/recv operations
- Trace collection: Workers send traces to master after task completion;
merged into ~/.exo/traces/trace_{task_id}.json
  - API endpoints: List, fetch, stats, and raw download at /v1/traces/*
  - Dashboard: Trace list and detail pages with Perfetto integration

## Why It Works

<img width="1236" height="767" alt="Screenshot 2026-01-30 at 19 00 09"
src="https://github.com/user-attachments/assets/73e6e46d-ba10-4e83-ba99-ff1c3f62ab05"
/>
<img width="1659" height="89" alt="Screenshot 2026-01-30 at 19 00 58"
src="https://github.com/user-attachments/assets/c0fd0e65-e4fc-4fd5-920d-b43b2887d591"
/>
2026-02-02 17:19:45 +00:00
Evan Quiney
3e27ead705 remove mdns discovered peers from appearing in state (#1312)
## motivation
eagerly discovered peers through gossipsub were added to state. this left things looking broken from one-sided connections

## changes
the worker no longer writes topology edges from these gossipsub messages
we now strictly rely on http-discovered topology, which tends to be more reflective of the actual state of the systems connectivity
2026-02-02 16:58:53 +00:00
Alex Cheema
d826d309b3 chore: gitignore hosts_*.json files (#1343)
## Motivation

`hosts_*.json` files are local host configuration snapshots that
shouldn't be tracked in version control.

## Changes

Added `hosts_*.json` pattern to `.gitignore`.

## Why It Works

The glob pattern `hosts_*.json` matches any file starting with `hosts_`
and ending with `.json` in the repo root.

## Test Plan

### Manual Testing
- Verified that `hosts_*.json` files are ignored by git after this
change.

### Automated Testing
- No automated tests needed for a `.gitignore` change.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 16:14:11 +00:00
Alex Cheema
c3537980bd feat: add Claude Messages API and OpenAI Responses API support (#1167)
## Motivation

Add support for Claude Messages API and OpenAI Responses API to allow
users to interact with exo using these popular API formats. This enables
broader compatibility with existing tooling and SDKs that expect these
API formats.

## Architecture

Adapter logic lives exclusively in the API layer
(`src/exo/master/adapters/`). On the way in, each adapter converts its
API-specific request type (`ChatCompletionRequest`,
`ClaudeMessagesRequest`, `ResponsesRequest`) into
`TextGenerationTaskParams`. On the way out, each adapter converts the
`TokenChunk` stream back into its API-specific response format.
Everything inside the application — commands, worker, runner, event
sourcing — only sees `TextGenerationTaskParams` and `TokenChunk`. No
API-specific types cross the boundary.

```
                          API layer                         │  Application internals
                                                            │
Chat Completions → [adapter] → TextGenerationTaskParams ──→ │ ──→ TextGeneration command → Runner → TokenChunk ──→ │ ──→ [adapter] → ChatCompletionResponse
Claude Messages  → [adapter] → TextGenerationTaskParams ──→ │ ──→ TextGeneration command → Runner → TokenChunk ──→ │ ──→ [adapter] → ClaudeMessagesResponse
Responses API    → [adapter] → TextGenerationTaskParams ──→ │ ──→ TextGeneration command → Runner → TokenChunk ──→ │ ──→ [adapter] → ResponsesResponse
```

## Changes

### New Files
- `src/exo/shared/types/claude_api.py` - Pydantic types for Claude
Messages API
- `src/exo/shared/types/openai_responses.py` - Pydantic types for OpenAI
Responses API
- `src/exo/shared/types/text_generation.py` - Shared
`TextGenerationTaskParams` internal type
- `src/exo/master/adapters/chat_completions.py` - Chat Completions
adapter (streaming/non-streaming)
- `src/exo/master/adapters/claude.py` - Claude Messages adapter
(streaming/non-streaming)
- `src/exo/master/adapters/responses.py` - OpenAI Responses adapter
(streaming/non-streaming)

### Modified Files
- `src/exo/master/api.py` - Refactored to use adapters uniformly for all
endpoints; extracted `_resolve_and_validate_text_model` helper to
deduplicate model validation across all text endpoints; removed ad-hoc
`try/except ValueError` blocks from non-streaming paths

### New Endpoints
- `POST /v1/messages` - Claude Messages API (streaming and
non-streaming)
- `POST /v1/responses` - OpenAI Responses API (streaming and
non-streaming)

## Why It Works

All APIs are implemented as pure conversion adapters at the edge of the
application:
1. Adapter functions in `src/exo/master/adapters/` convert incoming
requests to `TextGenerationTaskParams`
2. `api.py` wraps the params in a `TextGeneration` command and sends it
through the existing command/event flow
3. The worker, runner, and event sourcing layers only handle
`TextGenerationTaskParams` and `TokenChunk` — they have no awareness of
Chat Completions, Claude, or Responses API formats
4. On response, adapter functions convert the `TokenChunk` stream back
to the caller's expected format
5. Model validation is handled by a single shared helper
(`_resolve_and_validate_text_model`), mirroring the existing
`_validate_image_model` pattern for image endpoints

No changes to core inference logic were needed.

### Streaming Formats
- **Chat Completions**: Uses `data: {...}\n\n` with `[DONE]` terminator
- **Claude**: Uses event types `message_start`, `content_block_start`,
`content_block_delta`, `content_block_stop`, `message_delta`,
`message_stop`
- **OpenAI Responses**: Uses event types `response.created`,
`response.in_progress`, `response.output_item.added`,
`response.content_part.added`, `response.output_text.delta`,
`response.output_text.done`, `response.content_part.done`,
`response.output_item.done`, `response.completed`

## Test Plan

### Manual Testing
Hardware: MacBook Pro M3 Max

**Non-streaming tests:**
```bash
# Chat Completions API
curl -X POST http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 20}'

# Claude Messages API
curl -X POST http://localhost:52415/v1/messages \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "max_tokens": 50, "messages": [{"role": "user", "content": "Hello"}]}'

# OpenAI Responses API
curl -X POST http://localhost:52415/v1/responses \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "input": "Hello", "max_output_tokens": 20}'
```

**Streaming tests:**
```bash
# Chat Completions API (streaming)
curl -N -X POST http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}], "stream": true, "max_tokens": 20}'

# Claude Messages API (streaming)
curl -N -X POST http://localhost:52415/v1/messages \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "max_tokens": 50, "messages": [{"role": "user", "content": "Hello"}], "stream": true}'

# OpenAI Responses API (streaming)
curl -N -X POST http://localhost:52415/v1/responses \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "input": "Hello", "stream": true, "max_output_tokens": 20}'
```

All endpoints tested successfully with proper response formats and
streaming events.

### Automated Testing
- Tests in `src/exo/master/tests/` all pass (85 tests)
- Type checker (basedpyright) passes with 0 errors
- Linter (ruff) passes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
2026-02-02 15:58:37 +00:00
rltakashige
21d477f1cb Update exo bench (#1357)
## Motivation

Make exo bench faster for longer prompts, lengthen default timeouts and
use pairs for pp and tg.

## Changes

- Uses binary search to find correct prompt
- Flag to force all combinations if that is desired
2026-02-02 15:46:15 +00:00
Jake Hillion
b2579c78fe nix: add macmon to PATH in wrapper scripts on Darwin
`nix run .#exo` couldn't find `macmon` because the Nix wrapper scripts
didn't include it in PATH, causing `shutil.which("macmon")` to fail.

Added `--prefix PATH : ${pkgs.macmon}/bin` to the `makeWrapper` call,
conditional on Darwin via `lib.optionalString`, so macmon's binary is
available at runtime without modifying the user's system PATH.

Test plan:
- Verified `nix build .#exo` succeeds
- Checked wrapper script contains macmon store path in PATH prefix
2026-02-02 13:42:36 +00:00
Evan Quiney
cd946742f7 fix skipping logic in worker plan (#1342)
the worker plan function had some skipping logic missing, leading to
double-submitting tasks.
2026-01-30 14:31:40 +00:00
rltakashige
a5bc38ad1f Check all nodes to evict (#1341)
## Motivation

If nodes have uneven memory, one node may evict cache that remains on
another node. This will break prefill on some setups.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-30 13:42:09 +00:00
Evan Quiney
2a4e0d4629 make node-ids unique per-session (#1338)
we currently have no strict reuqirements that node ids persist across
sessions, so we can generate fresh nodeids each time

this avoids issues like #1332, but prevents further features such as
caching downloads or node-id dialling

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-30 13:33:31 +00:00
Evan Quiney
46a14153dd switch to ModelCard.load outside of download log (#1339)
some attempts to load model cards (i.e. build_base_shard) always went
through networking rather than using downloaded model cards. we should
always default to ModelCard.load in these scenarios
2026-01-30 11:20:20 +00:00
Evan
9ba61f3733 improve log message in shard downloader
closes #1336
2026-01-30 10:35:01 +00:00
rltakashige
d9eca75895 Add usage stats (#1333)
## Motivation

(Probably) the final missing piece of the Chat Completions API 

## Changes

Add UsageStats 

## Why It Works

OpenCode reviewed my PR and gave me stats:

<img width="1150" height="802" alt="image"
src="https://github.com/user-attachments/assets/ebc06bae-797f-4087-87d5-2f26cf60fc48"
/>


## Test Plan

### Automated Testing
No tests were broken.
2026-01-30 10:23:08 +00:00
rltakashige
9dabde7e57 Fix bench after recent updates (#1331)
## Motivation

A lot of changes happened without much attention to the state of exo
bench.

## Changes

Use TaggedModel for BenchChatCompletion so it serialises properly.
Don't break after gpt oss tool call to preserve parity with the rest of
the codebase.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<img width="2856" height="678" alt="image"
src="https://github.com/user-attachments/assets/2e18cf0d-c0f8-467c-9763-1a6a59c8a327"
/>

Also tested GPT OSS tool calling in OpenCode
2026-01-29 19:14:40 +00:00
ciaranbor
a31942ce12 Ciaran/image non streaming (#1328)
## Motivation

The dashboard UI attempted to parse all image generation responses as
SSE streams, even when streaming was disabled. This broke non-streaming
image generation.

## Changes

- Parse JSON responses directly when not streaming, use SSE parser only
when stream=true AND partialImages > 0
- explicitly disable partial images when not streaming

## Why It Works

Both API and dashboard now use the same condition (stream &&
partialImages > 0) to determine response format, ensuring correct
parsing.

## Test Plan

### Manual Testing

Non-streamed image generation results appear in the UI. Streamed image
generation still works
2026-01-29 17:24:32 +00:00
Alex Cheema
7cc313b22a Treat Swift/Xcode build warnings as errors (#1322)
## Motivation

Warnings that go unchecked tend to accumulate and hide real issues.
Treating them as errors ensures they are addressed immediately, both
locally during development and in CI.

## Changes

Added `SWIFT_TREAT_WARNINGS_AS_ERRORS = YES` and
`GCC_TREAT_WARNINGS_AS_ERRORS = YES` to the **project-level** Debug and
Release build configurations in `project.pbxproj`. This applies to all
targets (EXO, EXOTests, EXOUITests).

## Why It Works

Xcode's `SWIFT_TREAT_WARNINGS_AS_ERRORS` and
`GCC_TREAT_WARNINGS_AS_ERRORS` build settings promote Swift and C/ObjC
warnings to errors at compile time. Setting them at the project level
means all targets inherit the policy without needing per-target or
CI-level overrides.

## Test Plan

### Manual Testing
- Built the EXO scheme in Release configuration with `xcodebuild` — no
warning-as-error failures from Swift or C/ObjC sources.

### Automated Testing
- CI already builds with `-configuration Release`, so it will
automatically enforce warnings-as-errors via the inherited project
settings — no CI changes needed.
2026-01-29 17:15:49 +00:00
rltakashige
2837225dc7 Load pipeline layers sequentially (#1329)
## Motivation

Slightly annoyed by needing this change, but same story as for tensor
loading...
2026-01-29 17:08:38 +00:00
Jake Hillion
e4c6a7dbb4 nix: add Python packaging with uv2nix
Add uv2nix to build Python packages from uv.lock. This creates a fully
Nix-managed Python environment with the Rust bindings injected via overlay.

Changes:
- Add pyproject-nix, uv2nix, and pyproject-build-systems flake inputs
- Create python/parts.nix with overlays to inject Nix-built Rust wheel
- Export packages.exo on macOS (wraps exo/exo-master/exo-worker with dashboard)
- Add checks.lint (ruff, all platforms) and checks.pytest (macOS only)
- Simplify CI typecheck job using nicknovitski/nix-develop action
- Delete .github/actions/typecheck composite action (no longer needed)
- Add no-build-package for MLX packages in pyproject.toml (use wheels)

The Python build is currently macOS-only since MLX requires Metal. Linux
support will be added once the pyproject dependencies are simplified.

Test plan:
- Run `nix flake check` on macOS to verify pytest and lint pass
- Build exo package on macOS: `nix build .#exo`
- Verify CI pipeline passes with simplified typecheck job
2026-01-29 16:35:58 +00:00
Evan
b1e88a3d06 shfmt
adds shfmt, a shell formatter, and formats the bash files
2026-01-29 15:24:36 +00:00
Jake Hillion
ebeddfb308 mlx: build with Nix (#1285)
In order to make testing and deployment simpler and more reproducible,
we want to provide a Nix derivation for our macOS .app build. We already
build the Rust and dashboard with Nix, but so far the Python has been
blocked because we haven't had an MLX build.

This change adds a Metal compiler derivation that uses `requireFile` to
be provided a NAR of the unfree macOS Metal compiler. It is documented
how to get this file, but effectively you have to trigger the download,
mount the DMG, and NAR the result. Once this is added to the store by
hash we can build MLX using it. The MLX build itself is quite self
explanatory.

Test plan:
- CI. We follow the instructions to grab the Metal compiler. Once this
is in Cachix we should really never do this again, and I can pin the
path too to ensure it doesn't leave.
- MLX tests run as part of the MLX derivation's build. They pass.
- `NIXPKGS_ALLOW_UNFREE=1 nix build .#mlx.passthru.tests.mlxTest
--impure --option sandbox false`

---------

Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-29 14:07:00 +00:00
Alex Cheema
9111575997 Add startup delay and update network setup message (#1309)
## Summary
- Add 20-second startup delay to wait for macOS to finish network setup
after boot
- Update user-facing message to clarify the service configures local
networking, disables Thunderbolt Bridge (preventing packet storms), and
installs a Network Location

## Test plan
- [ ] Manual verification of Swift syntax
- [ ] Test network setup on macOS device after reboot

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-29 13:05:50 +00:00
Sami Khan
ffacabe7e4 Fix uninstall button error (#1306)
## Motivation

Fix "Network setup script failed" error when clicking uninstall button
and resolve Xcode compiler warnings.

## Changes

- NetworkSetupHelper.swift: Add || true guards and explicit return 0 in
find_and_enable_thunderbolt_bridge to prevent script failures with set
-euo pipefail
- ThunderboltBridgeService.swift: Use withCString and
withUnsafeMutablePointer for Authorization API calls to fix pointer
lifetime warnings
- EXOApp.swift: Mark showNotification as nonisolated to fix main actor
isolation warning

## Why It Works

- The uninstall script's Thunderbolt re-enable function could exit
non-zero in edge cases (no bridges, no matches). Since this is a cleanup
step, failures should not abort uninstall.
- Swift requires explicit pointer lifetime management when passing
strings/structs to C APIs.
- showNotification is called from a nonisolated delegate method and uses
thread-safe APIs.

## Test Plan

### Manual Testing
Hardware: MacBook Pro

- Clicked Uninstall button, verified it completes without error
- Built in Xcode, verified no warnings   

### Automated Testing
N/A
2026-01-29 12:57:48 +00:00
rltakashige
9e58a57599 Add RDMA caveats to README.md (#1316)
## Motivation

Running RDMA from source is not well documented as is. Several
surprising things that took time to debug internally too.

App should be updated to detect MacOS versions in future.
2026-01-28 18:44:00 +00:00
Evan Quiney
748a026071 fix configdata validation for kimi-k2 (#1314)
## motivation
our shard downloader could not correctly fetch data for kimi-k2, as it
deferred some values to a text_config field.
## changes
config_data now prioritizes this field if it exists in information like
layer_count
2026-01-28 14:29:36 +00:00
Alex Cheema
f1a2d054ec Update tagline to "Run frontier AI locally" (#1313)
- Update README tagline from "Run your own AI cluster at home with
everyday devices" to "Run frontier AI locally"
2026-01-28 12:38:14 +00:00
Alex Cheema
b3c8f85fc8 Update MLX to 0.30.4 (#1311)
## Summary
- Bump mlx from 0.30.3 to 0.30.4

## Test plan
- [x] `uv lock` succeeds
- [x] Type checking passes (`uv run basedpyright`)
- [x] Run inference tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 04:30:21 -08:00
rltakashige
a562114ba5 Add Kimi K2.5 support (#1302)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-28 05:44:19 +00:00
Evan Quiney
991d278119 replace nix fmt with treefmt in just lint (#1301)
man evaluating the nix flake is so slow. treefmt speeeedy
2026-01-27 17:03:01 +00:00
133 changed files with 7803 additions and 3060 deletions

View File

@@ -1,12 +0,0 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,73 +26,14 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
nix:
name: Build and check (${{ matrix.system }})
@@ -123,6 +64,63 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -134,3 +132,14 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

3
.gitignore vendored
View File

@@ -28,3 +28,6 @@ target/
dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# host config snapshots
hosts_*.json

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
@@ -107,6 +107,10 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
### Run from Source (Linux)
**Prerequisites:**
@@ -230,7 +234,7 @@ This removes:
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
Please refer to the caveats for immediate troubleshooting.
To enable RDMA on macOS, follow these steps:
@@ -247,6 +251,14 @@ To enable RDMA on macOS, follow these steps:
After that, RDMA will be enabled in macOS and exo will take care of the rest.
**Important Caveats**
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
2. The cables must support TB5.
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
---
### Using the API

View File

@@ -342,6 +342,8 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -397,6 +399,8 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
private func showNotification(title: String, body: String) {
nonisolated private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title

View File

@@ -293,7 +293,7 @@ struct ClusterTask {
let modelName: String?
let promptPreview: String?
let errorMessage: String?
let parameters: ChatCompletionTaskParameters?
let parameters: TextGenerationTaskParameters?
var sortPriority: Int {
switch status {
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
let taskStatus: TaskStatus?
let instanceId: String?
let commandId: String?
let taskParams: ChatCompletionTaskParameters?
let taskParams: TextGenerationTaskParameters?
let errorType: String?
let errorMessage: String?
}
struct ChatCompletionTaskParameters: Decodable, Equatable {
struct TextGenerationTaskParameters: Decodable, Equatable {
let model: String?
let messages: [ChatCompletionMessage]?
let maxTokens: Int?
@@ -374,7 +374,7 @@ extension ClusterTask {
guard let id = payload.taskId else { return nil }
let status = payload.taskStatus ?? .unknown
switch kindKey {
case "ChatCompletion":
case "TextGeneration":
self.init(
id: id,
status: status,

View File

@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
@@ -241,11 +244,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic 2>/dev/null || true
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
# Delete the exo network location if it exists
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -255,12 +258,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
')
') || true
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -269,7 +272,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
')
') || true
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -277,8 +280,9 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge
find_and_enable_thunderbolt_bridge || true
echo "EXO network components removed successfully"
"""

View File

@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled

View File

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: ChatCompletionTaskParameters?
let parameters: TextGenerationTaskParameters?
var title: String {
switch kind {

View File

@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,4 +151,3 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -5,10 +5,13 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
@@ -16,6 +19,84 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
def load_tokenizer_for_bench(model_id: str) -> Any:
"""
Load tokenizer for benchmarking, with special handling for Kimi models.
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
This function replicates the logic from utils_mlx.py for bench compatibility.
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
import importlib.util
import types
from huggingface_hub import snapshot_download
# Download/get the model path
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken"],
)
)
sys.path.insert(0, str(model_path))
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
hf_tokenizer.encode = _patched_encode
return hf_tokenizer
# Default: use AutoTokenizer
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -24,7 +105,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -180,14 +261,7 @@ def parse_int_list(values: list[str]) -> list[int]:
part = part.strip()
if part:
items.append(int(part))
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
@@ -240,7 +314,11 @@ def run_one_completion(
stats = out.get("generation_stats")
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
# Extract preview, handling None content (common for thinking models)
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
return {
"elapsed_s": elapsed,
@@ -277,12 +355,29 @@ class PromptSizer:
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
content = ""
tok = self.count_fn(content)
# Estimate tokens per atom using a sample
sample_count = 100
sample_content = self.atom * sample_count
sample_tokens = self.count_fn(sample_content) - self.base_tokens
tokens_per_atom = sample_tokens / sample_count
while tok < target:
content += self.atom
tok = self.count_fn(content)
# Estimate starting point
needed_tokens = target - self.base_tokens
estimated_atoms = int(needed_tokens / tokens_per_atom)
# Binary search to find exact atom count
low, high = 0, estimated_atoms * 2 + 100
while low < high:
mid = (low + high) // 2
tok = self.count_fn(self.atom * mid)
if tok < target:
low = mid + 1
else:
high = mid
content = self.atom * low
tok = self.count_fn(content)
logger.info(f"{tok=}")
if tok != target:
raise RuntimeError(
@@ -348,7 +443,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -358,6 +453,11 @@ def main() -> int:
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
ap.add_argument(
"--all-combinations",
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -369,6 +469,15 @@ def main() -> int:
logger.error("--repeat must be >= 1")
return 2
# Log pairing mode
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
if use_combinations:
logger.info(
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
)
else:
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
@@ -377,10 +486,7 @@ def main() -> int:
)
previews = previews_resp.get("previews") or []
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
tokenizer = load_tokenizer_for_bench(full_model_id)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -486,60 +592,55 @@ def main() -> int:
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
# If pp and tg lists have same length, run in tandem (zip)
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
if use_combinations:
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
else:
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
for pp, tg in pp_tg_pairs:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")

View File

@@ -865,7 +865,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -905,7 +904,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1522,7 +1520,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1532,7 +1529,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1945,7 +1941,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2653,7 +2648,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2696,7 +2690,6 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2869,7 +2862,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3014,7 +3006,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3036,7 +3027,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -173,6 +173,41 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
// Trace API response types
export interface TraceCategoryStats {
totalUs: number;
count: number;
minUs: number;
maxUs: number;
avgUs: number;
}
export interface TraceRankStats {
byCategory: Record<string, TraceCategoryStats>;
}
export interface TraceStatsResponse {
taskId: string;
totalWallTimeUs: number;
byCategory: Record<string, TraceCategoryStats>;
byRank: Record<number, TraceRankStats>;
}
export interface TraceListItem {
taskId: string;
createdAt: string;
fileSize: number;
}
export interface TraceListResponse {
traces: TraceListItem[];
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2095,107 +2130,137 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const numImages = params.numImages;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const numImages = params.numImages;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
},
);
}
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2343,69 +2408,98 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
},
);
}
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(
@@ -2491,6 +2585,49 @@ class AppStore {
throw error;
}
}
/**
* List all available traces
*/
async listTraces(): Promise<TraceListResponse> {
const response = await fetch("/v1/traces");
if (!response.ok) {
throw new Error(`Failed to list traces: ${response.status}`);
}
return (await response.json()) as TraceListResponse;
}
/**
* Check if a trace exists for a given task ID
*/
async checkTraceExists(taskId: string): Promise<boolean> {
try {
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
return response.ok;
} catch {
return false;
}
}
/**
* Get computed statistics for a task's trace
*/
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
const response = await fetch(
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
);
if (!response.ok) {
throw new Error(`Failed to fetch trace stats: ${response.status}`);
}
return (await response.json()) as TraceStatsResponse;
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
getTraceRawUrl(taskId: string): string {
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
}
}
export const appStore = new AppStore();
@@ -2602,3 +2739,12 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);
// Trace actions
export const listTraces = () => appStore.listTraces();
export const checkTraceExists = (taskId: string) =>
appStore.checkTraceExists(taskId);
export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
export const getTraceRawUrl = (taskId: string) =>
appStore.getTraceRawUrl(taskId);

View File

@@ -0,0 +1,190 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function downloadTrace(taskId: string) {
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={() => downloadTrace(trace.taskId)}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -0,0 +1,367 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function downloadTrace() {
if (!taskId) return;
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
onclick={downloadTrace}
disabled={loading || !!error}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

65
flake.lock generated
View File

@@ -21,7 +21,9 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1765953015,
@@ -149,19 +151,44 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"type": "github"
},
"original": {
@@ -178,7 +205,10 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
}
},
"rust-analyzer-src": {
@@ -239,6 +269,29 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,6 +24,26 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -48,6 +68,7 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -58,6 +79,11 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -79,14 +105,24 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,7 +1,7 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt
treefmt || nix fmt
lint:
uv run ruff check --fix

View File

@@ -0,0 +1,79 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

56
nix/metal-toolchain.nix Normal file
View File

@@ -0,0 +1,56 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

158
nix/mlx.nix Normal file
View File

@@ -0,0 +1,158 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm==0.30.5",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -63,6 +63,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

94
python/parts.nix Normal file
View File

@@ -0,0 +1,94 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 405874409472

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 765577920512

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 122406567936

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 229780750336

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 198556925568

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 286737579648

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 396963397248

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 19327352832

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 22548578304

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 26843545600

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 620622774272

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 706522120192

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 662498705408

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 729808896

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 1863319552

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 3501195264

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 76799803392

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 4637851648

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 8954839040

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 16882073600

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 100086644736

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 342884352

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 698351616

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 141733920768

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 268435456000

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 17612931072

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 33279705088

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 289910292480

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 579820584960

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 46976204800

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 47080074240

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 70652212224

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 12025908224

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 144383672320

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -160,15 +160,14 @@ class ResumableShardDownloader(ShardDownloader):
# Kick off download status coroutines concurrently
tasks = [
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
for model_card in await get_model_cards()
]
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
logger.warning(f"Error downloading shard: {type(e).__name__}")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -90,7 +90,6 @@ class Node:
worker = Worker(
node_id,
session_id,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
@@ -227,9 +226,6 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,214 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionRequest,
ChatCompletionResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def chat_request_to_text_generation(
request: ChatCompletionRequest,
) -> TextGenerationTaskParams:
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
continue
if msg.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build full message dict for chat template (preserves tool_calls etc.)
# Normalize content for model_dump
msg_copy = msg.model_copy(update={"content": content})
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
chat_template_messages.append(dumped)
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
finish_reason=finish_reason,
)
],
)

View File

@@ -0,0 +1,323 @@
"""Claude Messages API adapter for converting requests/responses."""
import json
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeInputJsonDelta,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
"""Extract plain text from a tool_result content field."""
if block.content is None:
return ""
if isinstance(block.content, str):
return block.content
return "".join(sub_block.text for sub_block in block.content)
def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
# Handle system message
instructions: str | None = None
chat_template_messages: list[dict[str, Any]] = []
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)
chat_template_messages.append({"role": "system", "content": instructions})
# Convert messages to input
input_messages: list[InputMessage] = []
for msg in request.messages:
if isinstance(msg.content, str):
input_messages.append(InputMessage(role=msg.role, content=msg.content))
chat_template_messages.append({"role": msg.role, "content": msg.content})
continue
# Process structured content blocks
text_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(block.input),
},
}
)
elif isinstance(block, ClaudeToolResultBlock):
tool_results.append(block)
content = "".join(text_parts)
# Build InputMessage from text content
if msg.role in ("user", "assistant"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_use_id,
"content": _extract_tool_result_text(tr),
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
if request.tools:
tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": tool.input_schema,
},
}
for tool in request.tools
]
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
ClaudeToolUseBlock(
id=f"toolu_{uuid4().hex[:24]}",
name=tool.name,
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
# If no content at all, include empty text block
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=content,
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
for tool in chunk.tool_calls:
tool_id = f"toolu_{uuid4().hex[:24]}"
tool_input_json = tool.arguments
# content_block_start for tool_use
tool_block_start = ClaudeContentBlockStartEvent(
index=next_block_index,
content_block=ClaudeToolUseBlock(
id=tool_id, name=tool.name, input={}
),
)
yield f"event: content_block_start\ndata: {tool_block_start.model_dump_json()}\n\n"
# content_block_delta with input_json_delta
tool_delta_event = ClaudeContentBlockDeltaEvent(
index=next_block_index,
delta=ClaudeInputJsonDelta(partial_json=tool_input_json),
)
yield f"event: content_block_delta\ndata: {tool_delta_event.model_dump_json()}\n\n"
# content_block_stop
tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)
yield f"event: content_block_stop\ndata: {tool_block_stop.model_dump_json()}\n\n"
next_block_index += 1
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,373 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from uuid import uuid4
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
ResponseCompletedEvent,
ResponseContentPart,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionCallItem,
ResponseInProgressEvent,
ResponseInputMessage,
ResponseItem,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
return content
return "".join(part.text for part in content)
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = [InputMessage(role="user", content=request.input)]
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
if request.instructions is not None:
chat_template_messages.append(
{"role": "system", "content": request.instructions}
)
for item in request.input:
if isinstance(item, ResponseInputMessage):
content = _extract_content(item.content)
if item.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=item.role, content=content))
if item.role == "system":
chat_template_messages.append(
{"role": "system", "content": content}
)
else:
chat_template_messages.append(
{"role": item.role, "content": content}
)
elif isinstance(item, FunctionCallInputItem):
chat_template_messages.append(
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": item.call_id,
"type": "function",
"function": {
"name": item.name,
"arguments": item.arguments,
},
}
],
}
)
else:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": item.call_id,
"content": item.output,
}
)
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(
model=request.model,
input=input_value,
instructions=request.instructions,
max_output_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
tools=request.tools,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
chat_template_messages=built_chat_template or request.chat_template_messages,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{uuid4().hex[:24]}",
call_id=f"call_{uuid4().hex[:24]}",
name=tool.name,
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output: list[ResponseItem] = [
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
output.extend(function_call_items)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
seq = count(1)
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{uuid4().hex[:24]}"
call_id = f"call_{uuid4().hex[:24]}"
# response.output_item.added for function_call
fc_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments="",
status="in_progress",
)
fc_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments=tool.arguments,
status="completed",
)
fc_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
function_call_items.append(fc_done_item)
next_output_index += 1
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_message_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
output: list[ResponseItem] = [final_message_item]
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -2,8 +2,10 @@ import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import datetime, timezone
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, cast
from uuid import uuid4
@@ -19,28 +21,47 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.chat_completions import (
chat_request_to_text_generation,
collect_chat_response,
generate_chat_stream,
)
from exo.master.adapters.claude import (
claude_request_to_text_generation,
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
responses_request_to_text_generation,
)
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
)
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
get_model_cards,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionRequest,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
BenchImageGenerationTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
@@ -51,7 +72,7 @@ from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageEditsTaskParams,
ImageGenerationResponse,
ImageGenerationStats,
ImageGenerationTaskParams,
@@ -64,8 +85,14 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
TraceCategoryStats,
TraceEventResponse,
TraceListItem,
TraceListResponse,
TraceRankStats,
TraceResponse,
TraceStatsResponse,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -74,8 +101,11 @@ from exo.shared.types.chunks import (
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
@@ -89,6 +119,7 @@ from exo.shared.types.commands import (
SendInputChunk,
StartDownload,
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
@@ -96,15 +127,18 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -112,47 +146,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
if isinstance(chunk, TokenChunk)
else ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
finish_reason=chunk.finish_reason,
)
],
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -183,6 +176,15 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
@self.app.middleware("http")
async def _log_requests( # pyright: ignore[reportUnusedFunction]
request: Request,
call_next: Callable[[Request], Awaitable[StreamingResponse]],
) -> StreamingResponse:
logger.debug(f"API request: {request.method} {request.url.path}")
return await call_next(request)
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -190,13 +192,13 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=find_dashboard(),
directory=DASHBOARD_DIR,
html=True,
),
name="dashboard",
)
self._chat_completion_queues: dict[
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
@@ -210,7 +212,7 @@ class API:
self.state = State()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._text_generation_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
@@ -267,14 +269,20 @@ class API:
self.app.post("/bench/images/edits")(self.bench_image_edits)
self.app.get("/images")(self.list_images)
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -291,7 +299,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
model_card = await ModelCard.load(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -319,7 +327,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_card = await ModelCard.load(model_id)
try:
placements = get_instance_placements(
@@ -361,10 +369,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
model_card = await ModelCard.load(model_id)
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -379,96 +384,93 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
instance=None,
error=str(exc),
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
@@ -491,13 +493,15 @@ class API:
instance_id=instance_id,
)
async def _chat_chunk_stream(
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._chat_completion_queues[command_id], recv = channel[
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
]()
@@ -518,105 +522,20 @@ class API:
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
],
)
async def _collect_chat_completion_with_stats(
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -655,7 +574,9 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text, tool_calls=tool_calls
role="assistant",
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
finish_reason=finish_reason,
)
@@ -664,68 +585,77 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
async def chat_completions(
self, payload: ChatCompletionTaskParams
self, payload: ChatCompletionRequest
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)
command = ChatCompletion(
request_params=payload,
"""OpenAI Chat Completions API - adapter."""
task_params = chat_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
self, payload: BenchChatCompletionRequest
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
task_params = chat_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)
task_params = task_params.model_copy(update={"stream": False, "bench": True})
payload.stream = False
command = ChatCompletion(request_params=payload)
command = TextGeneration(task_params=task_params)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
if not any(
instance.shard_assignments.model_id == model_id
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(model_id)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {model_id}",
)
return model_id
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await ModelCard.load(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -771,10 +701,10 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
command = ImageGeneration(
request_params=payload,
task_params=payload,
)
await self._send(command)
@@ -1016,13 +946,13 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
command = ImageGeneration(
request_params=payload,
task_params=payload,
)
await self._send(command)
@@ -1037,7 +967,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: str,
model: ModelId,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1064,7 +994,7 @@ class API:
total_chunks = len(data_chunks)
command = ImageEdits(
request_params=ImageEditsInternalParams(
task_params=ImageEditsTaskParams(
image_data="",
total_input_chunks=total_chunks,
prompt=prompt,
@@ -1132,7 +1062,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1188,7 +1118,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1208,6 +1138,62 @@ class API:
response_format=response_format,
)
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Claude Messages API - adapter."""
task_params = claude_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""OpenAI Responses API."""
task_params = responses_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(task_params.model)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1231,7 +1217,7 @@ class API:
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
for card in await get_model_cards()
]
)
@@ -1280,14 +1266,32 @@ class API:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._chat_completion_queues.get(
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
def _save_merged_trace(self, event: TracesMerged) -> None:
traces = [
TraceEvent(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in event.traces
]
output_path = EXO_TRACING_CACHE_DIR / f"trace_{event.task_id}.json"
export_trace(traces, output_path)
logger.debug(f"Saved merged trace to {output_path}")
async def _pause_on_new_election(self):
with self.election_receiver as ems:
@@ -1335,3 +1339,103 @@ class API:
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
def _get_trace_path(self, task_id: str) -> Path:
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
async def list_traces(self) -> TraceListResponse:
traces: list[TraceListItem] = []
for trace_file in sorted(
EXO_TRACING_CACHE_DIR.glob("trace_*.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
# Extract task_id from filename (trace_{task_id}.json)
task_id = trace_file.stem.removeprefix("trace_")
stat = trace_file.stat()
created_at = datetime.fromtimestamp(
stat.st_mtime, tz=timezone.utc
).isoformat()
traces.append(
TraceListItem(
task_id=task_id,
created_at=created_at,
file_size=stat.st_size,
)
)
return TraceListResponse(traces=traces)
async def get_trace(self, task_id: str) -> TraceResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
return TraceResponse(
task_id=task_id,
traces=[
TraceEventResponse(
name=event.name,
start_us=event.start_us,
duration_us=event.duration_us,
rank=event.rank,
category=event.category,
)
for event in trace_events
],
)
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
stats = compute_stats(trace_events)
return TraceStatsResponse(
task_id=task_id,
total_wall_time_us=stats.total_wall_time_us,
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in stats.by_category.items()
},
by_rank={
rank: TraceRankStats(
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in rank_stats.items()
}
)
for rank, rank_stats in stats.by_rank.items()
},
)
async def get_trace_raw(self, task_id: str) -> FileResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
return FileResponse(
path=trace_path,
media_type="application/json",
filename=f"trace_{task_id}.json",
)

View File

@@ -11,8 +11,8 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
@@ -23,6 +23,7 @@ from exo.shared.types.commands import (
SendInputChunk,
TaskFinished,
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
@@ -35,11 +36,11 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
TracesCollected,
TracesMerged,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
@@ -50,6 +51,9 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -86,6 +90,8 @@ class Master:
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
async def run(self):
logger.info("Starting Master")
@@ -117,11 +123,11 @@ class Master:
match command:
case TestCommand():
pass
case ChatCompletion():
case TextGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
== command.task_params.model
):
task_count = sum(
1
@@ -134,7 +140,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
f"No instance found for model {command.task_params.model}"
)
available_instance_ids = sorted(
@@ -148,12 +154,12 @@ class Master:
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task=TextGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
task_params=command.task_params,
),
)
)
@@ -163,7 +169,7 @@ class Master:
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
== command.task_params.model
):
task_count = sum(
1
@@ -176,7 +182,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
f"No instance found for model {command.task_params.model}"
)
available_instance_ids = sorted(
@@ -187,25 +193,37 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
task_params=command.task_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
== command.task_params.model
):
task_count = sum(
1
@@ -218,7 +236,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
f"No instance found for model {command.task_params.model}"
)
available_instance_ids = sorted(
@@ -229,20 +247,32 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
task_params=command.task_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
@@ -335,6 +365,10 @@ class Master:
local_event.origin,
)
for event in self._multi_buffer.drain():
if isinstance(event, TracesCollected):
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
@@ -373,3 +407,29 @@ class Master:
event=event.event,
)
)
async def _handle_traces_collected(self, event: TracesCollected) -> None:
task_id = event.task_id
if task_id not in self._pending_traces:
self._pending_traces[task_id] = {}
self._pending_traces[task_id][event.rank] = event.traces
if (
task_id in self._expected_ranks
and set(self._pending_traces[task_id].keys())
>= self._expected_ranks[task_id]
):
await self._merge_and_save_traces(task_id)
async def _merge_and_save_traces(self, task_id: TaskId) -> None:
all_trace_data: list[TraceEventData] = []
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]

View File

@@ -0,0 +1,182 @@
"""Tests for Claude Messages API conversion functions and types."""
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_text_generation,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeMessage,
ClaudeMessagesRequest,
ClaudeTextBlock,
)
from exo.shared.types.common import ModelId
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to TextGenerationTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_text_generation(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)

View File

@@ -0,0 +1,265 @@
"""Tests for Claude Messages API tool_use support in the adapter."""
import json
from collections.abc import AsyncGenerator
from typing import Any, cast
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
async def _chunks_to_stream(
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
for chunk in chunks:
yield chunk
MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")
def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
"""Parse SSE event strings into JSON dicts."""
parsed: list[dict[str, Any]] = []
for event_str in events:
for line in event_str.strip().split("\n"):
if line.startswith("data: "):
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
return parsed
class TestCollectClaudeResponseToolUse:
"""Tests for non-streaming tool_use response collection."""
async def test_tool_call_chunk_produces_tool_use_blocks(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "San Francisco"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 1
block = tool_blocks[0]
assert block.type == "tool_use"
assert block.name == "get_weather"
assert block.input == {"location": "San Francisco"}
assert block.id.startswith("toolu_")
async def test_multiple_tool_calls(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
),
ToolCallItem(
name="get_time",
arguments='{"timezone": "PST"}',
),
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 2
assert tool_blocks[0].name == "get_weather"
assert tool_blocks[1].name == "get_time"
async def test_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Let me check ", token_id=1, usage=None),
TokenChunk(model=MODEL, text="the weather.", token_id=2, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "NYC"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
text_blocks = [b for b in response.content if b.type == "text"]
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(text_blocks) == 1
assert text_blocks[0].text == "Let me check the weather."
assert len(tool_blocks) == 1
assert tool_blocks[0].name == "get_weather"
async def test_no_content_produces_empty_text_block(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert len(response.content) == 1
assert response.content[0].type == "text"
class TestGenerateClaudeStreamToolUse:
"""Tests for streaming tool_use event generation."""
async def test_tool_call_emits_tool_use_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Find tool_use content_block_start
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
content_block = cast(dict[str, Any], tool_starts[0]["content_block"])
assert content_block["name"] == "get_weather"
assert content_block["input"] == {}
assert cast(str, content_block["id"]).startswith("toolu_")
# Find input_json_delta
json_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type")
== "input_json_delta"
]
assert len(json_deltas) == 1
delta = cast(dict[str, Any], json_deltas[0]["delta"])
assert json.loads(cast(str, delta["partial_json"])) == {"location": "SF"}
# Find message_delta with tool_use stop reason
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert len(msg_deltas) == 1
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Hello ", token_id=1, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="search",
arguments='{"query": "test"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Should have text delta at index 0
text_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type") == "text_delta"
]
assert len(text_deltas) == 1
assert text_deltas[0]["index"] == 0
assert cast(dict[str, Any], text_deltas[0]["delta"])["text"] == "Hello "
# Tool block at index 1
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
assert tool_starts[0]["index"] == 1
# Stop reason should be tool_use
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_tool_block_stop_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(name="fn1", arguments="{}"),
ToolCallItem(name="fn2", arguments='{"a": 1}'),
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Two tool block starts (at indices 1 and 2)
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

View File

@@ -7,15 +7,14 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
ForwarderCommand,
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -27,8 +26,9 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -127,19 +127,17 @@ async def test_master():
logger.info("wait for an instance")
while len(master.state.instances.keys()) == 0:
await anyio.sleep(0.001)
logger.info("inject a ChatCompletion Command")
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
command=(
ChatCompletion(
TextGeneration(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[
InputMessage(role="user", content="Hello, how are you?")
],
),
)
@@ -190,12 +188,10 @@ async def test_master():
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[InputMessage(role="user", content="Hello, how are you?")],
)
await master.shutdown()

View File

@@ -0,0 +1,48 @@
"""Tests for OpenAI Responses API wire types.
ResponsesRequest is the API wire type for the Responses endpoint.
The responses adapter converts it to TextGenerationTaskParams for the pipeline.
"""
import pydantic
import pytest
from exo.shared.types.common import ModelId
from exo.shared.types.openai_responses import (
ResponseInputMessage,
ResponsesRequest,
)
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1

View File

@@ -216,6 +216,8 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -25,6 +25,8 @@ from exo.shared.types.events import (
TestEvent,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TracesCollected,
TracesMerged,
)
from exo.shared.types.profiling import (
NodeIdentity,
@@ -55,7 +57,12 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
TestEvent()
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -2,6 +2,8 @@ import os
import sys
from pathlib import Path
from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
@@ -31,6 +33,14 @@ EXO_MODELS_DIR = (
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
@@ -49,7 +59,10 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_TRACING_CACHE_DIR = EXO_CACHE_HOME / "traces"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated
from typing import Annotated, Any
import aiofiles
import aiofiles.os as aios
@@ -7,14 +7,47 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from pydantic import (
AliasChoices,
BaseModel,
Field,
PositiveInt,
ValidationError,
field_validator,
model_validator,
)
from tomlkit.exceptions import TOMLKitError
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS, RESOURCES_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
# kinda ugly...
# TODO: load search path from config.toml
_csp = [Path(RESOURCES_DIR) / "inference_model_cards"]
if EXO_ENABLE_IMAGE_MODELS:
_csp.append(Path(RESOURCES_DIR) / "image_model_cards")
CARD_SEARCH_PATH = _csp
_card_cache: dict[ModelId, "ModelCard"] = {}
async def _refresh_card_cache():
for path in CARD_SEARCH_PATH:
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
async def get_model_cards() -> list["ModelCard"]:
if len(_card_cache) == 0:
await _refresh_card_cache()
return list(_card_cache.values())
class ModelTask(str, Enum):
@@ -48,28 +81,33 @@ class ModelCard(CamelCaseModel):
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
py = self.model_dump(exclude_none=True)
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
async def save_to_default_path(self):
await self.save(Path(RESOURCES_DIR) / (self.model_id.normalize() + ".toml"))
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
for card in MODEL_CARDS.values():
if card.model_id == model_id:
return card
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if model_id not in _card_cache:
await _refresh_card_cache()
if (mc := _card_cache.get(model_id)) is not None:
return mc
return await ModelCard.fetch_from_hf(model_id)
@staticmethod
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
# TODO: failure if files do not exist
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
@@ -82,536 +120,13 @@ class ModelCard(CamelCaseModel):
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
)
await mc.save_to_default_path()
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"deepseek-v3.1-8bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2-thinking": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-8bit": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-bf16": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-70b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.2
"llama-3.2-1b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.3
"llama-3.3-70b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-fp16": ModelCard(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# qwen3
"qwen3-0.6b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-0.6b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"glm-4.5-air-bf16": ModelCard(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
def _generate_image_model_quant_variants(
# TODO: quantizing and dynamically creating model cards
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
@@ -691,27 +206,21 @@ def _generate_image_model_quant_variants(
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
validation_alias=AliasChoices(
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
)
)
@property
def supports_tensor(self) -> bool:
@@ -726,25 +235,27 @@ class ConfigData(BaseModel):
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
@model_validator(mode="before")
@classmethod
def defer_to_text_config(cls, data: dict[str, Any]):
text_config = data.get("text_config")
if text_config is None:
return data
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
for field in [
"architectures",
"hidden_size",
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
]:
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
data[field] = val
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
return data
async def get_config_data(model_id: ModelId) -> ConfigData:

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

238
src/exo/shared/tracing.py Normal file
View File

@@ -0,0 +1,238 @@
from __future__ import annotations
import json
import time
from collections import defaultdict
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast, final
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.worker.runner.bootstrap import logger
# Context variable to track the current trace category for hierarchical nesting
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
@final
@dataclass(frozen=True)
class TraceEvent:
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
@dataclass
class CategoryStats:
total_us: int = 0
count: int = 0
min_us: int = 0
max_us: int = 0
def add(self, duration_us: int) -> None:
if self.count == 0:
self.min_us = duration_us
self.max_us = duration_us
else:
self.min_us = min(self.min_us, duration_us)
self.max_us = max(self.max_us, duration_us)
self.total_us += duration_us
self.count += 1
@property
def avg_us(self) -> float:
return self.total_us / self.count if self.count > 0 else 0.0
@final
@dataclass
class TraceStats:
total_wall_time_us: int = 0
by_category: dict[str, CategoryStats] = field(default_factory=dict)
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
# Global trace buffer - each rank accumulates traces here
_trace_buffer: list[TraceEvent] = []
def _record_span(
name: str, start_us: int, duration_us: int, rank: int, category: str
) -> None:
_trace_buffer.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
@contextmanager
def trace(
name: str,
rank: int,
category: str = "compute",
) -> Generator[None, None, None]:
"""Context manager to trace any operation.
Nested traces automatically inherit the parent category, creating hierarchical
categories like "sync/compute" or "async/comms".
Args:
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
rank: This rank's ID
category: Category for grouping in trace viewer ("comm", "compute", "step")
Example:
with trace(f"sync {t}", rank, "sync"):
with trace("joint_blocks", rank, "compute"):
# Recorded with category "sync/compute"
hidden_states = some_computation(...)
"""
if not EXO_TRACING_ENABLED:
yield
return
# Combine with parent category if nested
parent = _current_category.get()
full_category = f"{parent}/{category}" if parent else category
# Set as current for nested traces
token = _current_category.set(full_category)
try:
start_us = int(time.time() * 1_000_000)
start_perf = time.perf_counter()
yield
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
_record_span(name, start_us, duration_us, rank, full_category)
finally:
_current_category.reset(token)
def get_trace_buffer() -> list[TraceEvent]:
return list(_trace_buffer)
def clear_trace_buffer() -> None:
_trace_buffer.clear()
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
trace_events: list[dict[str, object]] = []
for event in traces:
# Chrome trace format uses "X" for complete events (with duration)
chrome_event: dict[str, object] = {
"name": event.name,
"cat": event.category,
"ph": "X",
"ts": event.start_us,
"dur": event.duration_us,
"pid": 0,
"tid": event.rank,
"args": {"rank": event.rank},
}
trace_events.append(chrome_event)
ranks_seen = set(t.rank for t in traces)
for rank in ranks_seen:
trace_events.append(
{
"name": "thread_name",
"ph": "M", # Metadata event
"pid": 0,
"tid": rank,
"args": {"name": f"Rank {rank}"},
}
)
chrome_trace = {"traceEvents": trace_events}
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(chrome_trace, f, indent=2)
except OSError as e:
logger.warning("Failed to export trace to %s: %s", output_path, e)
def load_trace_file(path: Path) -> list[TraceEvent]:
with open(path) as f:
data = cast(dict[str, list[dict[str, object]]], json.load(f))
events = data.get("traceEvents", [])
traces: list[TraceEvent] = []
for event in events:
# Skip metadata events
if event.get("ph") == "M":
continue
name = str(event.get("name", ""))
category = str(event.get("cat", ""))
ts_value = event.get("ts", 0)
dur_value = event.get("dur", 0)
tid_value = event.get("tid", 0)
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
# Get rank from tid or args
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
args = event.get("args")
if isinstance(args, dict):
args_dict = cast(dict[str, object], args)
rank_from_args = args_dict.get("rank")
if isinstance(rank_from_args, (int, float, str)):
rank = int(rank_from_args)
traces.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
return traces
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
stats = TraceStats()
if not traces:
return stats
# Calculate wall time from earliest start to latest end
min_start = min(t.start_us for t in traces)
max_end = max(t.start_us + t.duration_us for t in traces)
stats.total_wall_time_us = max_end - min_start
# Initialize nested dicts
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
lambda: defaultdict(CategoryStats)
)
for event in traces:
# By category
by_category[event.category].add(event.duration_us)
# By rank and category
by_rank[event.rank][event.category].add(event.duration_us)
stats.by_category = dict(by_category)
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
return stats

View File

@@ -2,7 +2,6 @@ import time
from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -116,8 +115,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
class StreamingChoiceResponse(BaseModel):
@@ -170,8 +169,12 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
class StreamOptions(BaseModel):
include_usage: bool = False
class ChatCompletionRequest(BaseModel):
model: ModelId
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
logit_bias: dict[str, int] | None = None
@@ -184,15 +187,17 @@ class ChatCompletionTaskParams(BaseModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class BenchChatCompletionRequest(ChatCompletionRequest):
pass
@@ -276,28 +281,7 @@ class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
background: str | None = None
input_fidelity: float | None = None
mask: UploadFile | None = None
model: str
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
"""Internal task params for image-editing requests."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
@@ -366,3 +350,45 @@ class StartDownloadResponse(CamelCaseModel):
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
class TraceResponse(CamelCaseModel):
task_id: str
traces: list[TraceEventResponse]
class TraceCategoryStats(CamelCaseModel):
total_us: int
count: int
min_us: int
max_us: int
avg_us: float
class TraceRankStats(CamelCaseModel):
by_category: dict[str, TraceCategoryStats]
class TraceStatsResponse(CamelCaseModel):
task_id: str
total_wall_time_us: int
by_category: dict[str, TraceCategoryStats]
by_rank: dict[int, TraceRankStats]
class TraceListItem(CamelCaseModel):
task_id: str
created_at: str
file_size: int
class TraceListResponse(CamelCaseModel):
traces: list[TraceListItem]

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None

Some files were not shown because too many files have changed in this diff Show More